diff --git a/identity-server/src/AspNetIdentity/DefaultSessionClaimsFilter.cs b/identity-server/src/AspNetIdentity/DefaultSessionClaimsFilter.cs index 0845f2e23..1856f6d8a 100644 --- a/identity-server/src/AspNetIdentity/DefaultSessionClaimsFilter.cs +++ b/identity-server/src/AspNetIdentity/DefaultSessionClaimsFilter.cs @@ -9,7 +9,7 @@ namespace Duende.IdentityServer.AspNetIdentity; public class DefaultSessionClaimsFilter : ISessionClaimsFilter { /// - public Task> FilterToSessionClaimsAsync(SecurityStampRefreshingPrincipalContext context) + public Task> FilterToSessionClaimsAsync(SecurityStampRefreshingPrincipalContext context, CT ct) { var newClaimTypes = context.NewPrincipal.Claims.Select(x => x.Type).ToArray(); var currentClaimsToKeep = context.CurrentPrincipal.Claims.Where(x => !newClaimTypes.Contains(x.Type)).ToArray(); diff --git a/identity-server/src/AspNetIdentity/ISessionClaimsFilter.cs b/identity-server/src/AspNetIdentity/ISessionClaimsFilter.cs index 898871a0a..43a2b76bd 100644 --- a/identity-server/src/AspNetIdentity/ISessionClaimsFilter.cs +++ b/identity-server/src/AspNetIdentity/ISessionClaimsFilter.cs @@ -16,6 +16,7 @@ public interface ISessionClaimsFilter /// /// The SecurityStampRefreshingPrincipalContext /// in the call to . + /// The cancellation token. /// The claims of the ClaimsPrincipal which should be persisted for the session. - public Task> FilterToSessionClaimsAsync(SecurityStampRefreshingPrincipalContext context); + public Task> FilterToSessionClaimsAsync(SecurityStampRefreshingPrincipalContext context, CT ct); } diff --git a/identity-server/src/AspNetIdentity/SecurityStampValidatorCallback.cs b/identity-server/src/AspNetIdentity/SecurityStampValidatorCallback.cs index aeb823bc3..54c5b8d81 100644 --- a/identity-server/src/AspNetIdentity/SecurityStampValidatorCallback.cs +++ b/identity-server/src/AspNetIdentity/SecurityStampValidatorCallback.cs @@ -26,7 +26,7 @@ public static class SecurityStampValidatorCallback return; } - var currentClaimsToKeep = await sessionClaimsFilter.FilterToSessionClaimsAsync(context); + var currentClaimsToKeep = await sessionClaimsFilter.FilterToSessionClaimsAsync(context, default); var id = context.NewPrincipal.Identities.First(); id.AddClaims(currentClaimsToKeep); diff --git a/identity-server/src/Configuration.EntityFramework/ClientConfigurationStore.cs b/identity-server/src/Configuration.EntityFramework/ClientConfigurationStore.cs index fd733c96d..6cadb30b7 100644 --- a/identity-server/src/Configuration.EntityFramework/ClientConfigurationStore.cs +++ b/identity-server/src/Configuration.EntityFramework/ClientConfigurationStore.cs @@ -46,10 +46,10 @@ public class ClientConfigurationStore : IClientConfigurationStore } /// - public async Task AddAsync(Client client) + public async Task AddAsync(Client client, CT ct) { Logger.LogDebug("Adding client {ClientId} to configuration store", client.ClientId); DbContext.Clients.Add(client.ToEntity()); - await DbContext.SaveChangesAsync(CancellationTokenProvider.CancellationToken); + await DbContext.SaveChangesAsync(ct); } } diff --git a/identity-server/src/Configuration/Endpoints/DynamicClientRegistrationEndpoint.cs b/identity-server/src/Configuration/Endpoints/DynamicClientRegistrationEndpoint.cs index 018954252..1da4ab7b2 100644 --- a/identity-server/src/Configuration/Endpoints/DynamicClientRegistrationEndpoint.cs +++ b/identity-server/src/Configuration/Endpoints/DynamicClientRegistrationEndpoint.cs @@ -47,7 +47,7 @@ public class DynamicClientRegistrationEndpoint // Check content type if (!HasCorrectContentType(httpContext.Request)) { - await _responseGenerator.WriteContentTypeError(httpContext); + await _responseGenerator.WriteContentTypeError(httpContext, httpContext.RequestAborted); return; } @@ -55,7 +55,7 @@ public class DynamicClientRegistrationEndpoint var request = await TryParseAsync(httpContext.Request); if (request == null) { - await _responseGenerator.WriteBadRequestError(httpContext); + await _responseGenerator.WriteBadRequestError(httpContext, httpContext.RequestAborted); return; } @@ -66,18 +66,18 @@ public class DynamicClientRegistrationEndpoint if (validationResult is DynamicClientRegistrationError validationError) { - await _responseGenerator.WriteError(httpContext, validationError); + await _responseGenerator.WriteError(httpContext, validationError, httpContext.RequestAborted); } else { var processingResult = await _processor.ProcessAsync(dcrContext, httpContext.RequestAborted); if (processingResult is DynamicClientRegistrationError processingFailure) { - await _responseGenerator.WriteError(httpContext, processingFailure); + await _responseGenerator.WriteError(httpContext, processingFailure, httpContext.RequestAborted); } else if (processingResult is DynamicClientRegistrationResponse success) { - await _responseGenerator.WriteSuccessResponse(httpContext, success); + await _responseGenerator.WriteSuccessResponse(httpContext, success, httpContext.RequestAborted); } else { diff --git a/identity-server/src/Configuration/RequestProcessing/DynamicClientRegistrationRequestProcessor.cs b/identity-server/src/Configuration/RequestProcessing/DynamicClientRegistrationRequestProcessor.cs index 14e7d8057..fe778bea8 100644 --- a/identity-server/src/Configuration/RequestProcessing/DynamicClientRegistrationRequestProcessor.cs +++ b/identity-server/src/Configuration/RequestProcessing/DynamicClientRegistrationRequestProcessor.cs @@ -64,7 +64,7 @@ public class DynamicClientRegistrationRequestProcessor : IDynamicClientRegistrat } } - await Store.AddAsync(context.Client); + await Store.AddAsync(context.Client, ct); return new DynamicClientRegistrationResponse(context.Request, context.Client) { diff --git a/identity-server/src/Configuration/ResponseGeneration/DynamicClientRegistrationResponseGenerator.cs b/identity-server/src/Configuration/ResponseGeneration/DynamicClientRegistrationResponseGenerator.cs index 96840b064..7091b3849 100644 --- a/identity-server/src/Configuration/ResponseGeneration/DynamicClientRegistrationResponseGenerator.cs +++ b/identity-server/src/Configuration/ResponseGeneration/DynamicClientRegistrationResponseGenerator.cs @@ -31,15 +31,15 @@ public class DynamicClientRegistrationResponseGenerator : IDynamicClientRegistra public DynamicClientRegistrationResponseGenerator(ILogger logger) => Logger = logger; /// - public virtual async Task WriteResponse(HttpContext context, int statusCode, T response) + public virtual async Task WriteResponse(HttpContext context, int statusCode, T response, CT ct) where T : IDynamicClientRegistrationResponse { context.Response.StatusCode = statusCode; - await context.Response.WriteAsJsonAsync(response, SerializerOptions); + await context.Response.WriteAsJsonAsync(response, SerializerOptions, ct); } /// - public virtual Task WriteContentTypeError(HttpContext context) + public virtual Task WriteContentTypeError(HttpContext context, CT ct) { Logger.LogDebug("Invalid content type in dynamic client registration request"); context.Response.StatusCode = StatusCodes.Status415UnsupportedMediaType; @@ -47,19 +47,20 @@ public class DynamicClientRegistrationResponseGenerator : IDynamicClientRegistra } /// - public virtual async Task WriteBadRequestError(HttpContext context) => + public virtual async Task WriteBadRequestError(HttpContext context, CT ct) => await WriteResponse(context, StatusCodes.Status400BadRequest, new DynamicClientRegistrationError( DynamicClientRegistrationErrors.InvalidClientMetadata, - "malformed metadata document") + "malformed metadata document"), + ct ); /// - public virtual async Task WriteError(HttpContext context, DynamicClientRegistrationError error) => - await WriteResponse(context, StatusCodes.Status400BadRequest, error); + public virtual async Task WriteError(HttpContext context, DynamicClientRegistrationError error, CT ct) => + await WriteResponse(context, StatusCodes.Status400BadRequest, error, ct); /// - public virtual async Task WriteSuccessResponse(HttpContext context, DynamicClientRegistrationResponse response) => - await WriteResponse(context, StatusCodes.Status201Created, response); + public virtual async Task WriteSuccessResponse(HttpContext context, DynamicClientRegistrationResponse response, CT ct) => + await WriteResponse(context, StatusCodes.Status201Created, response, ct); } diff --git a/identity-server/src/Configuration/ResponseGeneration/IDynamicClientRegistrationResponseGenerator.cs b/identity-server/src/Configuration/ResponseGeneration/IDynamicClientRegistrationResponseGenerator.cs index 9ff850264..cc1fd6cb8 100644 --- a/identity-server/src/Configuration/ResponseGeneration/IDynamicClientRegistrationResponseGenerator.cs +++ b/identity-server/src/Configuration/ResponseGeneration/IDynamicClientRegistrationResponseGenerator.cs @@ -20,32 +20,37 @@ public interface IDynamicClientRegistrationResponseGenerator /// The HTTP context to write the response to. /// The status code to set in the response. /// The response object to write to the response. - Task WriteResponse(HttpContext context, int statusCode, T response) + /// The cancellation token. + Task WriteResponse(HttpContext context, int statusCode, T response, CT ct) where T : IDynamicClientRegistrationResponse; /// /// Writes a content type error to the HTTP response. /// /// The HTTP context to write the error to. - Task WriteContentTypeError(HttpContext response); + /// The cancellation token. + Task WriteContentTypeError(HttpContext response, CT ct); /// /// Writes a bad request error to the HTTP context. /// /// The HTTP context to write the error to. - Task WriteBadRequestError(HttpContext context); + /// The cancellation token. + Task WriteBadRequestError(HttpContext context, CT ct); /// /// Writes a success response to the HTTP context. /// /// The HTTP context to write the response to. /// The dynamic client registration response. - Task WriteSuccessResponse(HttpContext context, DynamicClientRegistrationResponse response); + /// The cancellation token. + Task WriteSuccessResponse(HttpContext context, DynamicClientRegistrationResponse response, CT ct); /// /// Writes a validation or processing step's error to the HTTP context. /// /// The HTTP context to write the error to. /// The dynamic client registration validation error. - Task WriteError(HttpContext context, DynamicClientRegistrationError error); + /// The cancellation token. + Task WriteError(HttpContext context, DynamicClientRegistrationError error, CT ct); } diff --git a/identity-server/src/Configuration/Stores/IClientConfigurationStore.cs b/identity-server/src/Configuration/Stores/IClientConfigurationStore.cs index 38eb66768..3ab0b3536 100644 --- a/identity-server/src/Configuration/Stores/IClientConfigurationStore.cs +++ b/identity-server/src/Configuration/Stores/IClientConfigurationStore.cs @@ -15,5 +15,6 @@ public interface IClientConfigurationStore /// Adds a client to the configuration store. /// /// The client to add to the store - Task AddAsync(Client client); + /// The cancellation token. + Task AddAsync(Client client, CT ct); } diff --git a/identity-server/src/Configuration/Stores/InMemoryClientConfigurationStore.cs b/identity-server/src/Configuration/Stores/InMemoryClientConfigurationStore.cs index f1d6b0704..19c2feddf 100644 --- a/identity-server/src/Configuration/Stores/InMemoryClientConfigurationStore.cs +++ b/identity-server/src/Configuration/Stores/InMemoryClientConfigurationStore.cs @@ -24,7 +24,7 @@ public class InMemoryClientConfigurationStore : IClientConfigurationStore /// registered in the DI system as an ICollection. public InMemoryClientConfigurationStore(ICollection clients) => _clients = clients; /// - public Task AddAsync(Client client) + public Task AddAsync(Client client, CT ct) { if (_clients.Select(c => c.ClientId).Contains(client.ClientId)) { diff --git a/identity-server/test/IdentityServer.UnitTests/AspNetIdentity/DefaultSessionClaimsFilterTests.cs b/identity-server/test/IdentityServer.UnitTests/AspNetIdentity/DefaultSessionClaimsFilterTests.cs index 1303be50b..e5d3ba919 100644 --- a/identity-server/test/IdentityServer.UnitTests/AspNetIdentity/DefaultSessionClaimsFilterTests.cs +++ b/identity-server/test/IdentityServer.UnitTests/AspNetIdentity/DefaultSessionClaimsFilterTests.cs @@ -10,6 +10,8 @@ namespace IdentityServer.UnitTests.AspNetIdentity; public class DefaultSessionClaimsFilterTests { + private readonly CT _ct = TestContext.Current.CancellationToken; + [Fact] public async Task FilterToSessionClaimsAsync_with_session_and_non_session_claims_should_filter_to_only_session_claims() { @@ -26,7 +28,7 @@ public class DefaultSessionClaimsFilterTests var filter = new DefaultSessionClaimsFilter(); var context = new SecurityStampRefreshingPrincipalContext() { NewPrincipal = newPrincipal, CurrentPrincipal = currentPrincipal }; - var result = await filter.FilterToSessionClaimsAsync(context); + var result = await filter.FilterToSessionClaimsAsync(context, _ct); var resultTypes = result.Select(c => c.Type).ToList(); resultTypes.Count.ShouldBe(3); @@ -51,7 +53,7 @@ public class DefaultSessionClaimsFilterTests var filter = new DefaultSessionClaimsFilter(); var context = new SecurityStampRefreshingPrincipalContext { NewPrincipal = newPrincipal, CurrentPrincipal = currentPrincipal }; - var result = await filter.FilterToSessionClaimsAsync(context); + var result = await filter.FilterToSessionClaimsAsync(context, _ct); result.Count.ShouldBe(3); string[] expectClaimTypes = [ @@ -75,7 +77,7 @@ public class DefaultSessionClaimsFilterTests var filter = new DefaultSessionClaimsFilter(); var context = new SecurityStampRefreshingPrincipalContext { NewPrincipal = newPrincipal, CurrentPrincipal = currentPrincipal }; - var result = await filter.FilterToSessionClaimsAsync(context); + var result = await filter.FilterToSessionClaimsAsync(context, _ct); result.ShouldBeEmpty(); } @@ -88,7 +90,7 @@ public class DefaultSessionClaimsFilterTests var filter = new DefaultSessionClaimsFilter(); var context = new SecurityStampRefreshingPrincipalContext { NewPrincipal = newPrincipal, CurrentPrincipal = currentPrincipal }; - var result = await filter.FilterToSessionClaimsAsync(context); + var result = await filter.FilterToSessionClaimsAsync(context, _ct); result.ShouldBeEmpty(); }