diff --git a/bff/src/Bff/SessionManagement/TicketStore/ServerSideTicketStore.cs b/bff/src/Bff/SessionManagement/TicketStore/ServerSideTicketStore.cs index 51435b27a..93a0ea875 100644 --- a/bff/src/Bff/SessionManagement/TicketStore/ServerSideTicketStore.cs +++ b/bff/src/Bff/SessionManagement/TicketStore/ServerSideTicketStore.cs @@ -31,7 +31,7 @@ internal class ServerSideTicketStore( private readonly IDataProtector _protector = dataProtectionProvider.CreateProtector(DataProtectorPurpose); - private CT ct => accessor.HttpContext?.RequestAborted ?? CT.None; + private CT _ct => accessor.HttpContext?.RequestAborted ?? CT.None; /// public async Task StoreAsync(AuthenticationTicket ticket) @@ -43,7 +43,7 @@ internal class ServerSideTicketStore( { SubjectId = ticket.GetSubjectId(), SessionId = ticket.GetSessionId() - }, ct); + }, _ct); var key = CryptoRandom.CreateUniqueId(format: CryptoRandom.OutputFormat.Hex); @@ -68,7 +68,7 @@ internal class ServerSideTicketStore( Ticket = ticket.Serialize(_protector) }; - await store.CreateUserSessionAsync(session, ct); + await store.CreateUserSessionAsync(session, _ct); metrics.SessionStarted(); } @@ -78,7 +78,7 @@ internal class ServerSideTicketStore( logger.RetrieveAuthenticationTicket(LogLevel.Debug, key); var userSessionKey = BuildUserSessionKey(key); - var session = await store.GetUserSessionAsync(userSessionKey, ct); + var session = await store.GetUserSessionAsync(userSessionKey, _ct); if (session == null) { logger.NoAuthenticationTicketFoundForKey(LogLevel.Debug, key); @@ -111,7 +111,7 @@ internal class ServerSideTicketStore( public async Task RenewAsync(string key, AuthenticationTicket ticket) { var userSessionKey = BuildUserSessionKey(key); - var session = await store.GetUserSessionAsync(userSessionKey, ct); + var session = await store.GetUserSessionAsync(userSessionKey, _ct); if (session == null) { // https://github.com/dotnet/aspnetcore/issues/41516#issuecomment-1178076544 @@ -134,7 +134,7 @@ internal class ServerSideTicketStore( Renewed = ticket.GetIssued(timeProvider.GetUtcNow()), Expires = ticket.GetExpiration(), Ticket = ticket.Serialize(_protector) - }, ct); + }, _ct); } /// @@ -150,7 +150,7 @@ internal class ServerSideTicketStore( logger.RemovingAuthenticationTicket(LogLevel.Debug, userSessionKey.ToString()); metrics.SessionEnded(); - return store.DeleteUserSessionAsync(userSessionKey, ct); + return store.DeleteUserSessionAsync(userSessionKey, _ct); } /// diff --git a/identity-server/src/IdentityServer/Hosting/IdentityServerAuthenticationService.cs b/identity-server/src/IdentityServer/Hosting/IdentityServerAuthenticationService.cs index 3011f05ac..c2dfd87a0 100644 --- a/identity-server/src/IdentityServer/Hosting/IdentityServerAuthenticationService.cs +++ b/identity-server/src/IdentityServer/Hosting/IdentityServerAuthenticationService.cs @@ -61,7 +61,7 @@ internal class IdentityServerAuthenticationService : IAuthenticationService AugmentPrincipal(principal); properties ??= new AuthenticationProperties(); - await _session.CreateSessionIdAsync(principal, properties, default); + await _session.CreateSessionIdAsync(principal, properties, context.RequestAborted); } await _inner.SignInAsync(context, scheme, principal, properties); @@ -96,22 +96,22 @@ internal class IdentityServerAuthenticationService : IAuthenticationService _logger.LogDebug("SignOutCalled set; processing post-signout session cleanup."); // back channel logout - var user = await _session.GetUserAsync(default); + var user = await _session.GetUserAsync(context.RequestAborted); if (user != null) { var session = new UserSession { SubjectId = user.GetSubjectId(), - SessionId = await _session.GetSessionIdAsync(default), + SessionId = await _session.GetSessionIdAsync(context.RequestAborted), DisplayName = user.GetDisplayName(), - ClientIds = (await _session.GetClientListAsync(default)).ToList(), + ClientIds = (await _session.GetClientListAsync(context.RequestAborted)).ToList(), Issuer = await _issuerNameService.GetCurrentAsync(context.RequestAborted) }; await _sessionCoordinationService.ProcessLogoutAsync(session, context.RequestAborted); } // this clears our session id cookie so JS clients can detect the user has signed out - await _session.RemoveSessionIdCookieAsync(default); + await _session.RemoveSessionIdCookieAsync(context.RequestAborted); }); context.SetBackChannelLogoutTriggered(); diff --git a/identity-server/src/IdentityServer/IdentityServerTools.cs b/identity-server/src/IdentityServer/IdentityServerTools.cs index 755547c2f..eaf41cc23 100644 --- a/identity-server/src/IdentityServer/IdentityServerTools.cs +++ b/identity-server/src/IdentityServer/IdentityServerTools.cs @@ -27,6 +27,7 @@ public interface IIdentityServerTools /// the exp claim of the token. /// A collection of additional claims to include in the /// token. + /// The cancellation token. /// A JWT that expires after the specified lifetime and contains /// the given claims. /// Typical implementations depend on the @@ -34,7 +35,7 @@ public interface IIdentityServerTools /// of the token. Ensure that calls to this method will only occur if there /// is an incoming HTTP request or with the option set. /// - Task IssueJwtAsync(int lifetime, IEnumerable claims); + Task IssueJwtAsync(int lifetime, IEnumerable claims, CT ct); /// /// Issues a JWT with a specific lifetime, issuer, and set of claims. @@ -45,9 +46,10 @@ public interface IIdentityServerTools /// claim. /// A collection of additional claims to include in the /// token. + /// The cancellation token. /// A JWT with the specified lifetime, issuer and additional /// claims. - Task IssueJwtAsync(int lifetime, string issuer, IEnumerable claims); + Task IssueJwtAsync(int lifetime, string issuer, IEnumerable claims, CT ct); /// /// Issues a JWT with a specific lifetime, issuer, token type, and set of @@ -61,9 +63,10 @@ public interface IIdentityServerTools /// "id_token", set in the typ claim. /// A collection of additional claims to include in the /// token. + /// The cancellation token. /// A JWT with the specified lifetime, issuer, token type, and /// additional claims. - Task IssueJwtAsync(int lifetime, string issuer, string tokenType, IEnumerable claims); + Task IssueJwtAsync(int lifetime, string issuer, string tokenType, IEnumerable claims, CT ct); /// /// Issues a JWT access token for a particular client. @@ -72,6 +75,7 @@ public interface IIdentityServerTools /// claim. /// The lifetime, in seconds, which will determine /// the exp claim of the token. + /// The cancellation token. /// A collection of scopes, which will be added to the /// token as claims with the "scope" type. /// A collection of audiences, which will be added @@ -88,6 +92,7 @@ public interface IIdentityServerTools Task IssueClientJwtAsync( string clientId, int lifetime, + CT ct, IEnumerable? scopes = null, IEnumerable? audiences = null, IEnumerable? additionalClaims = null); @@ -113,21 +118,21 @@ public class IdentityServerTools : IIdentityServerTools } /// - public virtual async Task IssueJwtAsync(int lifetime, IEnumerable claims) + public virtual async Task IssueJwtAsync(int lifetime, IEnumerable claims, CT ct) { - var issuer = await _issuerNameService.GetCurrentAsync(default); - return await IssueJwtAsync(lifetime, issuer, claims); + var issuer = await _issuerNameService.GetCurrentAsync(ct); + return await IssueJwtAsync(lifetime, issuer, claims, ct); } /// - public virtual Task IssueJwtAsync(int lifetime, string issuer, IEnumerable claims) + public virtual Task IssueJwtAsync(int lifetime, string issuer, IEnumerable claims, CT ct) { var tokenType = OidcConstants.TokenTypes.AccessToken; - return IssueJwtAsync(lifetime, issuer, tokenType, claims); + return IssueJwtAsync(lifetime, issuer, tokenType, claims, ct); } /// - public virtual async Task IssueJwtAsync(int lifetime, string issuer, string tokenType, IEnumerable claims) + public virtual async Task IssueJwtAsync(int lifetime, string issuer, string tokenType, IEnumerable claims, CT ct) { ArgumentException.ThrowIfNullOrWhiteSpace(issuer); ArgumentException.ThrowIfNullOrWhiteSpace(tokenType); @@ -142,13 +147,14 @@ public class IdentityServerTools : IIdentityServerTools Claims = new HashSet(claims, new ClaimComparer()) }; - return await _tokenCreation.CreateTokenAsync(token, default); + return await _tokenCreation.CreateTokenAsync(token, ct); } /// public virtual async Task IssueClientJwtAsync( string clientId, int lifetime, + CT ct, IEnumerable? scopes = null, IEnumerable? audiences = null, IEnumerable? additionalClaims = null) @@ -178,7 +184,7 @@ public class IdentityServerTools : IIdentityServerTools claims.Add(new Claim( JwtClaimTypes.Audience, #pragma warning disable CA1863 // Would require changing a public const on a public class and be a breaking change - string.Format(CultureInfo.InvariantCulture, IdentityServerConstants.AccessTokenAudience, (await _issuerNameService.GetCurrentAsync(default)).EnsureTrailingSlash()))); + string.Format(CultureInfo.InvariantCulture, IdentityServerConstants.AccessTokenAudience, (await _issuerNameService.GetCurrentAsync(ct)).EnsureTrailingSlash()))); #pragma warning restore CA1863 } @@ -190,6 +196,6 @@ public class IdentityServerTools : IIdentityServerTools } } - return await IssueJwtAsync(lifetime, claims); + return await IssueJwtAsync(lifetime, claims, ct); } } diff --git a/identity-server/src/IdentityServer/Services/Default/DefaultBackChannelLogoutService.cs b/identity-server/src/IdentityServer/Services/Default/DefaultBackChannelLogoutService.cs index 5261d73fa..1e08c1768 100644 --- a/identity-server/src/IdentityServer/Services/Default/DefaultBackChannelLogoutService.cs +++ b/identity-server/src/IdentityServer/Services/Default/DefaultBackChannelLogoutService.cs @@ -103,7 +103,7 @@ public class DefaultBackChannelLogoutService : IBackChannelLogoutService // implementation doesn't make parallel use of a single DB context. // Since the signing key material should be cached, only the // first serial operation will call the db. - var payload = await CreateFormPostPayloadAsync(backChannelLogoutRequest); + var payload = await CreateFormPostPayloadAsync(backChannelLogoutRequest, ct); logoutRequestsWithPayload.Add((backChannelLogoutRequest, payload)); } @@ -124,10 +124,11 @@ public class DefaultBackChannelLogoutService : IBackChannelLogoutService /// Creates the form-url-encoded payload (as a dictionary) to send to the client. /// /// + /// The cancellation token. /// - protected async Task> CreateFormPostPayloadAsync(BackChannelLogoutRequest request) + protected async Task> CreateFormPostPayloadAsync(BackChannelLogoutRequest request, CT ct) { - var token = await CreateTokenAsync(request); + var token = await CreateTokenAsync(request, ct); var data = new Dictionary { @@ -140,8 +141,9 @@ public class DefaultBackChannelLogoutService : IBackChannelLogoutService /// Creates the JWT used for the back-channel logout notification. /// /// + /// The cancellation token. /// The token. - protected virtual async Task CreateTokenAsync(BackChannelLogoutRequest request) + protected virtual async Task CreateTokenAsync(BackChannelLogoutRequest request, CT ct) { var claims = await CreateClaimsForTokenAsync(request); if (claims.Any(x => x.Type == JwtClaimTypes.Nonce)) @@ -151,11 +153,11 @@ public class DefaultBackChannelLogoutService : IBackChannelLogoutService if (request.Issuer != null) { - return await Tools.IssueJwtAsync(DefaultLogoutTokenLifetime, request.Issuer, IdentityServerConstants.TokenTypes.LogoutToken, claims); + return await Tools.IssueJwtAsync(DefaultLogoutTokenLifetime, request.Issuer, IdentityServerConstants.TokenTypes.LogoutToken, claims, ct); } - var issuer = await IssuerNameService.GetCurrentAsync(default); - return await Tools.IssueJwtAsync(DefaultLogoutTokenLifetime, issuer, IdentityServerConstants.TokenTypes.LogoutToken, claims); + var issuer = await IssuerNameService.GetCurrentAsync(ct); + return await Tools.IssueJwtAsync(DefaultLogoutTokenLifetime, issuer, IdentityServerConstants.TokenTypes.LogoutToken, claims, ct); } /// diff --git a/identity-server/src/IdentityServer/Stores/Default/ServerSideTicketStore.cs b/identity-server/src/IdentityServer/Stores/Default/ServerSideTicketStore.cs index fbf9ede93..9f89c0476 100644 --- a/identity-server/src/IdentityServer/Stores/Default/ServerSideTicketStore.cs +++ b/identity-server/src/IdentityServer/Stores/Default/ServerSideTicketStore.cs @@ -58,7 +58,7 @@ public class ServerSideTicketStore : IServerSideTicketStore ArgumentNullException.ThrowIfNull(ticket); - ticket.SetIssuer(await _issuerNameService.GetCurrentAsync(default)); + ticket.SetIssuer(await _issuerNameService.GetCurrentAsync(_httpContextAccessor.HttpContext?.RequestAborted ?? default)); var key = CryptoRandom.CreateUniqueId(format: CryptoRandom.OutputFormat.Hex); @@ -149,7 +149,7 @@ public class ServerSideTicketStore : IServerSideTicketStore if (ticket.GetIssuer() == null) { // when issuing a new cookie on top of an existing cookie, the AuthenticationTicket passed above is new (and not the prior one loaded from the ticket store) - ticket.SetIssuer(await _issuerNameService.GetCurrentAsync(default)); + ticket.SetIssuer(await _issuerNameService.GetCurrentAsync(_httpContextAccessor.HttpContext?.RequestAborted ?? default)); } session.Renewed = ticket.GetIssued(); session.Expires = ticket.GetExpiration(); diff --git a/identity-server/test/IdentityServer.IntegrationTests/Endpoints/Ciba/CibaTests.cs b/identity-server/test/IdentityServer.IntegrationTests/Endpoints/Ciba/CibaTests.cs index 557909451..a77d1ad0b 100644 --- a/identity-server/test/IdentityServer.IntegrationTests/Endpoints/Ciba/CibaTests.cs +++ b/identity-server/test/IdentityServer.IntegrationTests/Endpoints/Ciba/CibaTests.cs @@ -21,6 +21,8 @@ public class CibaTests { private const string Category = "Backchannel Authentication (CIBA) endpoint"; + private readonly CT _ct = TestContext.Current.CancellationToken; + private IdentityServerPipeline _mockPipeline = new(); private MockCibaUserValidator _mockCibaUserValidator = new(); private MockCibaUserNotificationService _mockCibaUserNotificationService = new(); @@ -1513,7 +1515,7 @@ public class CibaTests var id_token = await tokenService.IssueJwtAsync(600, new Claim[] { new Claim("sub", _user.SubjectId), new Claim("aud", _cibaClient.ClientId), - }); + }, _ct); var bindingMessage = Guid.NewGuid().ToString("n"); var body = new Dictionary diff --git a/identity-server/test/IdentityServer.UnitTests/Services/Default/DefaultBackChannelLogoutServiceTests.cs b/identity-server/test/IdentityServer.UnitTests/Services/Default/DefaultBackChannelLogoutServiceTests.cs index d8007221f..cb57041d5 100644 --- a/identity-server/test/IdentityServer.UnitTests/Services/Default/DefaultBackChannelLogoutServiceTests.cs +++ b/identity-server/test/IdentityServer.UnitTests/Services/Default/DefaultBackChannelLogoutServiceTests.cs @@ -17,6 +17,7 @@ namespace UnitTests.Services.Default; public class DefaultBackChannelLogoutServiceTests { + private readonly CT _ct = TestContext.Current.CancellationToken; private class ServiceTestHarness : DefaultBackChannelLogoutService { public ServiceTestHarness( @@ -32,7 +33,7 @@ public class DefaultBackChannelLogoutServiceTests // CreateTokenAsync is protected, so we use this wrapper to exercise it in our tests - public async Task ExerciseCreateTokenAsync(BackChannelLogoutRequest request) => await CreateTokenAsync(request); + public async Task ExerciseCreateTokenAsync(BackChannelLogoutRequest request, CT ct) => await CreateTokenAsync(request, ct); } [Fact] @@ -59,7 +60,7 @@ public class DefaultBackChannelLogoutServiceTests { ClientId = "test_client", SubjectId = "test_sub", - }); + }, _ct); var payload = JsonSerializer.Deserialize>(Base64Url.DecodeFromChars(rawToken.Split('.')[1]));