Resolve default CT stopgaps — add CT to IIdentityServerTools, thread CT through DefaultBackChannelLogoutService, use context.RequestAborted in IdentityServerAuthenticationService and ServerSideTicketStore

This commit is contained in:
Damian Hickey 2026-02-21 13:42:08 +01:00
parent 2c0994b18b
commit d7ac8ee981
7 changed files with 47 additions and 36 deletions

View file

@ -31,7 +31,7 @@ internal class ServerSideTicketStore(
private readonly IDataProtector _protector = dataProtectionProvider.CreateProtector(DataProtectorPurpose);
private CT ct => accessor.HttpContext?.RequestAborted ?? CT.None;
private CT _ct => accessor.HttpContext?.RequestAborted ?? CT.None;
/// <inheritdoc />
public async Task<string> StoreAsync(AuthenticationTicket ticket)
@ -43,7 +43,7 @@ internal class ServerSideTicketStore(
{
SubjectId = ticket.GetSubjectId(),
SessionId = ticket.GetSessionId()
}, ct);
}, _ct);
var key = CryptoRandom.CreateUniqueId(format: CryptoRandom.OutputFormat.Hex);
@ -68,7 +68,7 @@ internal class ServerSideTicketStore(
Ticket = ticket.Serialize(_protector)
};
await store.CreateUserSessionAsync(session, ct);
await store.CreateUserSessionAsync(session, _ct);
metrics.SessionStarted();
}
@ -78,7 +78,7 @@ internal class ServerSideTicketStore(
logger.RetrieveAuthenticationTicket(LogLevel.Debug, key);
var userSessionKey = BuildUserSessionKey(key);
var session = await store.GetUserSessionAsync(userSessionKey, ct);
var session = await store.GetUserSessionAsync(userSessionKey, _ct);
if (session == null)
{
logger.NoAuthenticationTicketFoundForKey(LogLevel.Debug, key);
@ -111,7 +111,7 @@ internal class ServerSideTicketStore(
public async Task RenewAsync(string key, AuthenticationTicket ticket)
{
var userSessionKey = BuildUserSessionKey(key);
var session = await store.GetUserSessionAsync(userSessionKey, ct);
var session = await store.GetUserSessionAsync(userSessionKey, _ct);
if (session == null)
{
// https://github.com/dotnet/aspnetcore/issues/41516#issuecomment-1178076544
@ -134,7 +134,7 @@ internal class ServerSideTicketStore(
Renewed = ticket.GetIssued(timeProvider.GetUtcNow()),
Expires = ticket.GetExpiration(),
Ticket = ticket.Serialize(_protector)
}, ct);
}, _ct);
}
/// <inheritdoc />
@ -150,7 +150,7 @@ internal class ServerSideTicketStore(
logger.RemovingAuthenticationTicket(LogLevel.Debug, userSessionKey.ToString());
metrics.SessionEnded();
return store.DeleteUserSessionAsync(userSessionKey, ct);
return store.DeleteUserSessionAsync(userSessionKey, _ct);
}
/// <inheritdoc />

View file

@ -61,7 +61,7 @@ internal class IdentityServerAuthenticationService : IAuthenticationService
AugmentPrincipal(principal);
properties ??= new AuthenticationProperties();
await _session.CreateSessionIdAsync(principal, properties, default);
await _session.CreateSessionIdAsync(principal, properties, context.RequestAborted);
}
await _inner.SignInAsync(context, scheme, principal, properties);
@ -96,22 +96,22 @@ internal class IdentityServerAuthenticationService : IAuthenticationService
_logger.LogDebug("SignOutCalled set; processing post-signout session cleanup.");
// back channel logout
var user = await _session.GetUserAsync(default);
var user = await _session.GetUserAsync(context.RequestAborted);
if (user != null)
{
var session = new UserSession
{
SubjectId = user.GetSubjectId(),
SessionId = await _session.GetSessionIdAsync(default),
SessionId = await _session.GetSessionIdAsync(context.RequestAborted),
DisplayName = user.GetDisplayName(),
ClientIds = (await _session.GetClientListAsync(default)).ToList(),
ClientIds = (await _session.GetClientListAsync(context.RequestAborted)).ToList(),
Issuer = await _issuerNameService.GetCurrentAsync(context.RequestAborted)
};
await _sessionCoordinationService.ProcessLogoutAsync(session, context.RequestAborted);
}
// this clears our session id cookie so JS clients can detect the user has signed out
await _session.RemoveSessionIdCookieAsync(default);
await _session.RemoveSessionIdCookieAsync(context.RequestAborted);
});
context.SetBackChannelLogoutTriggered();

View file

@ -27,6 +27,7 @@ public interface IIdentityServerTools
/// the exp claim of the token.</param>
/// <param name="claims">A collection of additional claims to include in the
/// token.</param>
/// <param name="ct">The cancellation token.</param>
/// <returns>A JWT that expires after the specified lifetime and contains
/// the given claims.</returns>
/// <remarks>Typical implementations depend on the <see cref="HttpContext"/>
@ -34,7 +35,7 @@ public interface IIdentityServerTools
/// of the token. Ensure that calls to this method will only occur if there
/// is an incoming HTTP request or with the option set.
/// </remarks>
Task<string> IssueJwtAsync(int lifetime, IEnumerable<Claim> claims);
Task<string> IssueJwtAsync(int lifetime, IEnumerable<Claim> claims, CT ct);
/// <summary>
/// Issues a JWT with a specific lifetime, issuer, and set of claims.
@ -45,9 +46,10 @@ public interface IIdentityServerTools
/// claim.</param>
/// <param name="claims">A collection of additional claims to include in the
/// token.</param>
/// <param name="ct">The cancellation token.</param>
/// <returns>A JWT with the specified lifetime, issuer and additional
/// claims.</returns>
Task<string> IssueJwtAsync(int lifetime, string issuer, IEnumerable<Claim> claims);
Task<string> IssueJwtAsync(int lifetime, string issuer, IEnumerable<Claim> claims, CT ct);
/// <summary>
/// Issues a JWT with a specific lifetime, issuer, token type, and set of
@ -61,9 +63,10 @@ public interface IIdentityServerTools
/// "id_token", set in the typ claim.</param>
/// <param name="claims">A collection of additional claims to include in the
/// token.</param>
/// <param name="ct">The cancellation token.</param>
/// <returns>A JWT with the specified lifetime, issuer, token type, and
/// additional claims.</returns>
Task<string> IssueJwtAsync(int lifetime, string issuer, string tokenType, IEnumerable<Claim> claims);
Task<string> IssueJwtAsync(int lifetime, string issuer, string tokenType, IEnumerable<Claim> claims, CT ct);
/// <summary>
/// Issues a JWT access token for a particular client.
@ -72,6 +75,7 @@ public interface IIdentityServerTools
/// claim.</param>
/// <param name="lifetime">The lifetime, in seconds, which will determine
/// the exp claim of the token.</param>
/// <param name="ct">The cancellation token.</param>
/// <param name="scopes">A collection of scopes, which will be added to the
/// token as claims with the "scope" type.</param>
/// <param name="audiences">A collection of audiences, which will be added
@ -88,6 +92,7 @@ public interface IIdentityServerTools
Task<string> IssueClientJwtAsync(
string clientId,
int lifetime,
CT ct,
IEnumerable<string>? scopes = null,
IEnumerable<string>? audiences = null,
IEnumerable<Claim>? additionalClaims = null);
@ -113,21 +118,21 @@ public class IdentityServerTools : IIdentityServerTools
}
/// <inheritdoc/>
public virtual async Task<string> IssueJwtAsync(int lifetime, IEnumerable<Claim> claims)
public virtual async Task<string> IssueJwtAsync(int lifetime, IEnumerable<Claim> claims, CT ct)
{
var issuer = await _issuerNameService.GetCurrentAsync(default);
return await IssueJwtAsync(lifetime, issuer, claims);
var issuer = await _issuerNameService.GetCurrentAsync(ct);
return await IssueJwtAsync(lifetime, issuer, claims, ct);
}
/// <inheritdoc/>
public virtual Task<string> IssueJwtAsync(int lifetime, string issuer, IEnumerable<Claim> claims)
public virtual Task<string> IssueJwtAsync(int lifetime, string issuer, IEnumerable<Claim> claims, CT ct)
{
var tokenType = OidcConstants.TokenTypes.AccessToken;
return IssueJwtAsync(lifetime, issuer, tokenType, claims);
return IssueJwtAsync(lifetime, issuer, tokenType, claims, ct);
}
/// <inheritdoc/>
public virtual async Task<string> IssueJwtAsync(int lifetime, string issuer, string tokenType, IEnumerable<Claim> claims)
public virtual async Task<string> IssueJwtAsync(int lifetime, string issuer, string tokenType, IEnumerable<Claim> claims, CT ct)
{
ArgumentException.ThrowIfNullOrWhiteSpace(issuer);
ArgumentException.ThrowIfNullOrWhiteSpace(tokenType);
@ -142,13 +147,14 @@ public class IdentityServerTools : IIdentityServerTools
Claims = new HashSet<Claim>(claims, new ClaimComparer())
};
return await _tokenCreation.CreateTokenAsync(token, default);
return await _tokenCreation.CreateTokenAsync(token, ct);
}
/// <inheritdoc/>
public virtual async Task<string> IssueClientJwtAsync(
string clientId,
int lifetime,
CT ct,
IEnumerable<string>? scopes = null,
IEnumerable<string>? audiences = null,
IEnumerable<Claim>? additionalClaims = null)
@ -178,7 +184,7 @@ public class IdentityServerTools : IIdentityServerTools
claims.Add(new Claim(
JwtClaimTypes.Audience,
#pragma warning disable CA1863 // Would require changing a public const on a public class and be a breaking change
string.Format(CultureInfo.InvariantCulture, IdentityServerConstants.AccessTokenAudience, (await _issuerNameService.GetCurrentAsync(default)).EnsureTrailingSlash())));
string.Format(CultureInfo.InvariantCulture, IdentityServerConstants.AccessTokenAudience, (await _issuerNameService.GetCurrentAsync(ct)).EnsureTrailingSlash())));
#pragma warning restore CA1863
}
@ -190,6 +196,6 @@ public class IdentityServerTools : IIdentityServerTools
}
}
return await IssueJwtAsync(lifetime, claims);
return await IssueJwtAsync(lifetime, claims, ct);
}
}

View file

@ -103,7 +103,7 @@ public class DefaultBackChannelLogoutService : IBackChannelLogoutService
// implementation doesn't make parallel use of a single DB context.
// Since the signing key material should be cached, only the
// first serial operation will call the db.
var payload = await CreateFormPostPayloadAsync(backChannelLogoutRequest);
var payload = await CreateFormPostPayloadAsync(backChannelLogoutRequest, ct);
logoutRequestsWithPayload.Add((backChannelLogoutRequest, payload));
}
@ -124,10 +124,11 @@ public class DefaultBackChannelLogoutService : IBackChannelLogoutService
/// Creates the form-url-encoded payload (as a dictionary) to send to the client.
/// </summary>
/// <param name="request"></param>
/// <param name="ct">The cancellation token.</param>
/// <returns></returns>
protected async Task<Dictionary<string, string>> CreateFormPostPayloadAsync(BackChannelLogoutRequest request)
protected async Task<Dictionary<string, string>> CreateFormPostPayloadAsync(BackChannelLogoutRequest request, CT ct)
{
var token = await CreateTokenAsync(request);
var token = await CreateTokenAsync(request, ct);
var data = new Dictionary<string, string>
{
@ -140,8 +141,9 @@ public class DefaultBackChannelLogoutService : IBackChannelLogoutService
/// Creates the JWT used for the back-channel logout notification.
/// </summary>
/// <param name="request"></param>
/// <param name="ct">The cancellation token.</param>
/// <returns>The token.</returns>
protected virtual async Task<string> CreateTokenAsync(BackChannelLogoutRequest request)
protected virtual async Task<string> CreateTokenAsync(BackChannelLogoutRequest request, CT ct)
{
var claims = await CreateClaimsForTokenAsync(request);
if (claims.Any(x => x.Type == JwtClaimTypes.Nonce))
@ -151,11 +153,11 @@ public class DefaultBackChannelLogoutService : IBackChannelLogoutService
if (request.Issuer != null)
{
return await Tools.IssueJwtAsync(DefaultLogoutTokenLifetime, request.Issuer, IdentityServerConstants.TokenTypes.LogoutToken, claims);
return await Tools.IssueJwtAsync(DefaultLogoutTokenLifetime, request.Issuer, IdentityServerConstants.TokenTypes.LogoutToken, claims, ct);
}
var issuer = await IssuerNameService.GetCurrentAsync(default);
return await Tools.IssueJwtAsync(DefaultLogoutTokenLifetime, issuer, IdentityServerConstants.TokenTypes.LogoutToken, claims);
var issuer = await IssuerNameService.GetCurrentAsync(ct);
return await Tools.IssueJwtAsync(DefaultLogoutTokenLifetime, issuer, IdentityServerConstants.TokenTypes.LogoutToken, claims, ct);
}
/// <summary>

View file

@ -58,7 +58,7 @@ public class ServerSideTicketStore : IServerSideTicketStore
ArgumentNullException.ThrowIfNull(ticket);
ticket.SetIssuer(await _issuerNameService.GetCurrentAsync(default));
ticket.SetIssuer(await _issuerNameService.GetCurrentAsync(_httpContextAccessor.HttpContext?.RequestAborted ?? default));
var key = CryptoRandom.CreateUniqueId(format: CryptoRandom.OutputFormat.Hex);
@ -149,7 +149,7 @@ public class ServerSideTicketStore : IServerSideTicketStore
if (ticket.GetIssuer() == null)
{
// when issuing a new cookie on top of an existing cookie, the AuthenticationTicket passed above is new (and not the prior one loaded from the ticket store)
ticket.SetIssuer(await _issuerNameService.GetCurrentAsync(default));
ticket.SetIssuer(await _issuerNameService.GetCurrentAsync(_httpContextAccessor.HttpContext?.RequestAborted ?? default));
}
session.Renewed = ticket.GetIssued();
session.Expires = ticket.GetExpiration();

View file

@ -21,6 +21,8 @@ public class CibaTests
{
private const string Category = "Backchannel Authentication (CIBA) endpoint";
private readonly CT _ct = TestContext.Current.CancellationToken;
private IdentityServerPipeline _mockPipeline = new();
private MockCibaUserValidator _mockCibaUserValidator = new();
private MockCibaUserNotificationService _mockCibaUserNotificationService = new();
@ -1513,7 +1515,7 @@ public class CibaTests
var id_token = await tokenService.IssueJwtAsync(600, new Claim[] {
new Claim("sub", _user.SubjectId),
new Claim("aud", _cibaClient.ClientId),
});
}, _ct);
var bindingMessage = Guid.NewGuid().ToString("n");
var body = new Dictionary<string, string>

View file

@ -17,6 +17,7 @@ namespace UnitTests.Services.Default;
public class DefaultBackChannelLogoutServiceTests
{
private readonly CT _ct = TestContext.Current.CancellationToken;
private class ServiceTestHarness : DefaultBackChannelLogoutService
{
public ServiceTestHarness(
@ -32,7 +33,7 @@ public class DefaultBackChannelLogoutServiceTests
// CreateTokenAsync is protected, so we use this wrapper to exercise it in our tests
public async Task<string> ExerciseCreateTokenAsync(BackChannelLogoutRequest request) => await CreateTokenAsync(request);
public async Task<string> ExerciseCreateTokenAsync(BackChannelLogoutRequest request, CT ct) => await CreateTokenAsync(request, ct);
}
[Fact]
@ -59,7 +60,7 @@ public class DefaultBackChannelLogoutServiceTests
{
ClientId = "test_client",
SubjectId = "test_sub",
});
}, _ct);
var payload = JsonSerializer.Deserialize<Dictionary<string, JsonElement>>(Base64Url.DecodeFromChars(rawToken.Split('.')[1]));