Make CT required in IClientConfigurationStore, IDynamicClientRegistrationResponseGenerator, and ISessionClaimsFilter, flow through implementations, callers, and tests (Wave 30 — final wave)

This commit is contained in:
Damian Hickey 2026-02-21 09:11:27 +01:00
parent b669c9f62f
commit a115f5b81b
11 changed files with 41 additions and 31 deletions

View file

@ -9,7 +9,7 @@ namespace Duende.IdentityServer.AspNetIdentity;
public class DefaultSessionClaimsFilter : ISessionClaimsFilter
{
/// <inheritdoc/>
public Task<IReadOnlyCollection<Claim>> FilterToSessionClaimsAsync(SecurityStampRefreshingPrincipalContext context)
public Task<IReadOnlyCollection<Claim>> FilterToSessionClaimsAsync(SecurityStampRefreshingPrincipalContext context, CT ct)
{
var newClaimTypes = context.NewPrincipal.Claims.Select(x => x.Type).ToArray();
var currentClaimsToKeep = context.CurrentPrincipal.Claims.Where(x => !newClaimTypes.Contains(x.Type)).ToArray();

View file

@ -16,6 +16,7 @@ public interface ISessionClaimsFilter
/// </summary>
/// <param name="context">The SecurityStampRefreshingPrincipalContext <see cref="SecurityStampRefreshingPrincipalContext.SecurityStampRefreshingPrincipalContext"/>
/// in the call to <see cref="SecurityStampValidatorOptions.OnRefreshingPrincipal"/>.</param>
/// <param name="ct">The cancellation token.</param>
/// <returns>The claims of the ClaimsPrincipal which should be persisted for the session.</returns>
public Task<IReadOnlyCollection<Claim>> FilterToSessionClaimsAsync(SecurityStampRefreshingPrincipalContext context);
public Task<IReadOnlyCollection<Claim>> FilterToSessionClaimsAsync(SecurityStampRefreshingPrincipalContext context, CT ct);
}

View file

@ -26,7 +26,7 @@ public static class SecurityStampValidatorCallback
return;
}
var currentClaimsToKeep = await sessionClaimsFilter.FilterToSessionClaimsAsync(context);
var currentClaimsToKeep = await sessionClaimsFilter.FilterToSessionClaimsAsync(context, default);
var id = context.NewPrincipal.Identities.First();
id.AddClaims(currentClaimsToKeep);

View file

@ -46,10 +46,10 @@ public class ClientConfigurationStore : IClientConfigurationStore
}
/// <inheritdoc />
public async Task AddAsync(Client client)
public async Task AddAsync(Client client, CT ct)
{
Logger.LogDebug("Adding client {ClientId} to configuration store", client.ClientId);
DbContext.Clients.Add(client.ToEntity());
await DbContext.SaveChangesAsync(CancellationTokenProvider.CancellationToken);
await DbContext.SaveChangesAsync(ct);
}
}

View file

@ -47,7 +47,7 @@ public class DynamicClientRegistrationEndpoint
// Check content type
if (!HasCorrectContentType(httpContext.Request))
{
await _responseGenerator.WriteContentTypeError(httpContext);
await _responseGenerator.WriteContentTypeError(httpContext, httpContext.RequestAborted);
return;
}
@ -55,7 +55,7 @@ public class DynamicClientRegistrationEndpoint
var request = await TryParseAsync(httpContext.Request);
if (request == null)
{
await _responseGenerator.WriteBadRequestError(httpContext);
await _responseGenerator.WriteBadRequestError(httpContext, httpContext.RequestAborted);
return;
}
@ -66,18 +66,18 @@ public class DynamicClientRegistrationEndpoint
if (validationResult is DynamicClientRegistrationError validationError)
{
await _responseGenerator.WriteError(httpContext, validationError);
await _responseGenerator.WriteError(httpContext, validationError, httpContext.RequestAborted);
}
else
{
var processingResult = await _processor.ProcessAsync(dcrContext, httpContext.RequestAborted);
if (processingResult is DynamicClientRegistrationError processingFailure)
{
await _responseGenerator.WriteError(httpContext, processingFailure);
await _responseGenerator.WriteError(httpContext, processingFailure, httpContext.RequestAborted);
}
else if (processingResult is DynamicClientRegistrationResponse success)
{
await _responseGenerator.WriteSuccessResponse(httpContext, success);
await _responseGenerator.WriteSuccessResponse(httpContext, success, httpContext.RequestAborted);
}
else
{

View file

@ -64,7 +64,7 @@ public class DynamicClientRegistrationRequestProcessor : IDynamicClientRegistrat
}
}
await Store.AddAsync(context.Client);
await Store.AddAsync(context.Client, ct);
return new DynamicClientRegistrationResponse(context.Request, context.Client)
{

View file

@ -31,15 +31,15 @@ public class DynamicClientRegistrationResponseGenerator : IDynamicClientRegistra
public DynamicClientRegistrationResponseGenerator(ILogger<DynamicClientRegistrationResponseGenerator> logger) => Logger = logger;
/// <inheritdoc/>
public virtual async Task WriteResponse<T>(HttpContext context, int statusCode, T response)
public virtual async Task WriteResponse<T>(HttpContext context, int statusCode, T response, CT ct)
where T : IDynamicClientRegistrationResponse
{
context.Response.StatusCode = statusCode;
await context.Response.WriteAsJsonAsync(response, SerializerOptions);
await context.Response.WriteAsJsonAsync(response, SerializerOptions, ct);
}
/// <inheritdoc/>
public virtual Task WriteContentTypeError(HttpContext context)
public virtual Task WriteContentTypeError(HttpContext context, CT ct)
{
Logger.LogDebug("Invalid content type in dynamic client registration request");
context.Response.StatusCode = StatusCodes.Status415UnsupportedMediaType;
@ -47,19 +47,20 @@ public class DynamicClientRegistrationResponseGenerator : IDynamicClientRegistra
}
/// <inheritdoc/>
public virtual async Task WriteBadRequestError(HttpContext context) =>
public virtual async Task WriteBadRequestError(HttpContext context, CT ct) =>
await WriteResponse(context, StatusCodes.Status400BadRequest,
new DynamicClientRegistrationError(
DynamicClientRegistrationErrors.InvalidClientMetadata,
"malformed metadata document")
"malformed metadata document"),
ct
);
/// <inheritdoc/>
public virtual async Task WriteError(HttpContext context, DynamicClientRegistrationError error) =>
await WriteResponse(context, StatusCodes.Status400BadRequest, error);
public virtual async Task WriteError(HttpContext context, DynamicClientRegistrationError error, CT ct) =>
await WriteResponse(context, StatusCodes.Status400BadRequest, error, ct);
/// <inheritdoc/>
public virtual async Task WriteSuccessResponse(HttpContext context, DynamicClientRegistrationResponse response) =>
await WriteResponse(context, StatusCodes.Status201Created, response);
public virtual async Task WriteSuccessResponse(HttpContext context, DynamicClientRegistrationResponse response, CT ct) =>
await WriteResponse(context, StatusCodes.Status201Created, response, ct);
}

View file

@ -20,32 +20,37 @@ public interface IDynamicClientRegistrationResponseGenerator
/// <param name="context">The HTTP context to write the response to.</param>
/// <param name="statusCode">The status code to set in the response.</param>
/// <param name="response">The response object to write to the response.</param>
Task WriteResponse<T>(HttpContext context, int statusCode, T response)
/// <param name="ct">The cancellation token.</param>
Task WriteResponse<T>(HttpContext context, int statusCode, T response, CT ct)
where T : IDynamicClientRegistrationResponse;
/// <summary>
/// Writes a content type error to the HTTP response.
/// </summary>
/// <param name="response">The HTTP context to write the error to.</param>
Task WriteContentTypeError(HttpContext response);
/// <param name="ct">The cancellation token.</param>
Task WriteContentTypeError(HttpContext response, CT ct);
/// <summary>
/// Writes a bad request error to the HTTP context.
/// </summary>
/// <param name="context">The HTTP context to write the error to.</param>
Task WriteBadRequestError(HttpContext context);
/// <param name="ct">The cancellation token.</param>
Task WriteBadRequestError(HttpContext context, CT ct);
/// <summary>
/// Writes a success response to the HTTP context.
/// </summary>
/// <param name="context">The HTTP context to write the response to.</param>
/// <param name="response">The dynamic client registration response.</param>
Task WriteSuccessResponse(HttpContext context, DynamicClientRegistrationResponse response);
/// <param name="ct">The cancellation token.</param>
Task WriteSuccessResponse(HttpContext context, DynamicClientRegistrationResponse response, CT ct);
/// <summary>
/// Writes a validation or processing step's error to the HTTP context.
/// </summary>
/// <param name="context">The HTTP context to write the error to.</param>
/// <param name="error">The dynamic client registration validation error.</param>
Task WriteError(HttpContext context, DynamicClientRegistrationError error);
/// <param name="ct">The cancellation token.</param>
Task WriteError(HttpContext context, DynamicClientRegistrationError error, CT ct);
}

View file

@ -15,5 +15,6 @@ public interface IClientConfigurationStore
/// Adds a client to the configuration store.
/// </summary>
/// <param name="client">The client to add to the store</param>
Task AddAsync(Client client);
/// <param name="ct">The cancellation token.</param>
Task AddAsync(Client client, CT ct);
}

View file

@ -24,7 +24,7 @@ public class InMemoryClientConfigurationStore : IClientConfigurationStore
/// registered in the DI system as an ICollection.</param>
public InMemoryClientConfigurationStore(ICollection<Client> clients) => _clients = clients;
/// <inheritdoc/>
public Task AddAsync(Client client)
public Task AddAsync(Client client, CT ct)
{
if (_clients.Select(c => c.ClientId).Contains(client.ClientId))
{

View file

@ -10,6 +10,8 @@ namespace IdentityServer.UnitTests.AspNetIdentity;
public class DefaultSessionClaimsFilterTests
{
private readonly CT _ct = TestContext.Current.CancellationToken;
[Fact]
public async Task FilterToSessionClaimsAsync_with_session_and_non_session_claims_should_filter_to_only_session_claims()
{
@ -26,7 +28,7 @@ public class DefaultSessionClaimsFilterTests
var filter = new DefaultSessionClaimsFilter();
var context = new SecurityStampRefreshingPrincipalContext() { NewPrincipal = newPrincipal, CurrentPrincipal = currentPrincipal };
var result = await filter.FilterToSessionClaimsAsync(context);
var result = await filter.FilterToSessionClaimsAsync(context, _ct);
var resultTypes = result.Select(c => c.Type).ToList();
resultTypes.Count.ShouldBe(3);
@ -51,7 +53,7 @@ public class DefaultSessionClaimsFilterTests
var filter = new DefaultSessionClaimsFilter();
var context = new SecurityStampRefreshingPrincipalContext { NewPrincipal = newPrincipal, CurrentPrincipal = currentPrincipal };
var result = await filter.FilterToSessionClaimsAsync(context);
var result = await filter.FilterToSessionClaimsAsync(context, _ct);
result.Count.ShouldBe(3);
string[] expectClaimTypes = [
@ -75,7 +77,7 @@ public class DefaultSessionClaimsFilterTests
var filter = new DefaultSessionClaimsFilter();
var context = new SecurityStampRefreshingPrincipalContext { NewPrincipal = newPrincipal, CurrentPrincipal = currentPrincipal };
var result = await filter.FilterToSessionClaimsAsync(context);
var result = await filter.FilterToSessionClaimsAsync(context, _ct);
result.ShouldBeEmpty();
}
@ -88,7 +90,7 @@ public class DefaultSessionClaimsFilterTests
var filter = new DefaultSessionClaimsFilter();
var context = new SecurityStampRefreshingPrincipalContext { NewPrincipal = newPrincipal, CurrentPrincipal = currentPrincipal };
var result = await filter.FilterToSessionClaimsAsync(context);
var result = await filter.FilterToSessionClaimsAsync(context, _ct);
result.ShouldBeEmpty();
}