From fb05c70893a6b8bc4dccb50537813e9e4438abbf Mon Sep 17 00:00:00 2001 From: Damian Hickey Date: Fri, 20 Feb 2026 16:53:51 +0100 Subject: [PATCH] Add CT parameter to IReferenceTokenStore, flow through all implementations and tests --- .../TokenRevocationResponseGenerator.cs | 6 ++--- .../Services/Default/DefaultTokenService.cs | 2 +- .../Default/DefaultReferenceTokenStore.cs | 16 ++++++------- .../Validation/Default/TokenValidator.cs | 4 ++-- .../Storage/Stores/IReferenceTokenStore.cs | 12 ++++++---- .../Common/MockReferenceTokenStore.cs | 8 +++---- .../DefaultPersistedGrantServiceTests.cs | 18 +++++++------- .../DefaultPersistedGrantStoreTests.cs | 24 +++++++++---------- .../Validation/AccessTokenValidation.cs | 10 ++++---- .../IntrospectionRequestValidatorTests.cs | 4 ++-- 10 files changed, 54 insertions(+), 50 deletions(-) diff --git a/identity-server/src/IdentityServer/ResponseHandling/Default/TokenRevocationResponseGenerator.cs b/identity-server/src/IdentityServer/ResponseHandling/Default/TokenRevocationResponseGenerator.cs index faa1197da..76d892c13 100644 --- a/identity-server/src/IdentityServer/ResponseHandling/Default/TokenRevocationResponseGenerator.cs +++ b/identity-server/src/IdentityServer/ResponseHandling/Default/TokenRevocationResponseGenerator.cs @@ -102,14 +102,14 @@ public class TokenRevocationResponseGenerator : ITokenRevocationResponseGenerato /// protected virtual async Task RevokeAccessTokenAsync(TokenRevocationRequestValidationResult validationResult) { - var token = await ReferenceTokenStore.GetReferenceTokenAsync(validationResult.Token); + var token = await ReferenceTokenStore.GetReferenceTokenAsync(validationResult.Token, default); if (token != null) { if (token.ClientId == validationResult.Client.ClientId) { Logger.LogDebug("Access token revoked"); - await ReferenceTokenStore.RemoveReferenceTokenAsync(validationResult.Token); + await ReferenceTokenStore.RemoveReferenceTokenAsync(validationResult.Token, default); } else { @@ -135,7 +135,7 @@ public class TokenRevocationResponseGenerator : ITokenRevocationResponseGenerato { Logger.LogDebug("Refresh token revoked"); await RefreshTokenStore.RemoveRefreshTokenAsync(validationResult.Token); - await ReferenceTokenStore.RemoveReferenceTokensAsync(token.SubjectId, token.ClientId, token.SessionId); + await ReferenceTokenStore.RemoveReferenceTokensAsync(token.SubjectId, token.ClientId, token.SessionId, default); } else { diff --git a/identity-server/src/IdentityServer/Services/Default/DefaultTokenService.cs b/identity-server/src/IdentityServer/Services/Default/DefaultTokenService.cs index 0d7c299ae..6bc0be3fe 100644 --- a/identity-server/src/IdentityServer/Services/Default/DefaultTokenService.cs +++ b/identity-server/src/IdentityServer/Services/Default/DefaultTokenService.cs @@ -256,7 +256,7 @@ public class DefaultTokenService : ITokenService { Logger.LogTrace("Creating reference access token"); - var handle = await ReferenceTokenStore.StoreReferenceTokenAsync(token); + var handle = await ReferenceTokenStore.StoreReferenceTokenAsync(token, default); tokenResult = handle; } diff --git a/identity-server/src/IdentityServer/Stores/Default/DefaultReferenceTokenStore.cs b/identity-server/src/IdentityServer/Stores/Default/DefaultReferenceTokenStore.cs index f2d6434d2..c0715cf29 100644 --- a/identity-server/src/IdentityServer/Stores/Default/DefaultReferenceTokenStore.cs +++ b/identity-server/src/IdentityServer/Stores/Default/DefaultReferenceTokenStore.cs @@ -31,34 +31,34 @@ public class DefaultReferenceTokenStore : DefaultGrantStore, IReferenceTo } /// - public Task StoreReferenceTokenAsync(Token token) + public Task StoreReferenceTokenAsync(Token token, CT ct) { using var activity = Tracing.StoreActivitySource.StartActivity("DefaultReferenceTokenStore.StoreReferenceToken"); - return CreateItemAsync(token, token.ClientId, token.SubjectId, token.SessionId, token.Description, token.CreationTime, token.Lifetime, default); + return CreateItemAsync(token, token.ClientId, token.SubjectId, token.SessionId, token.Description, token.CreationTime, token.Lifetime, ct); } /// - public Task GetReferenceTokenAsync(string handle) + public Task GetReferenceTokenAsync(string handle, CT ct) { using var activity = Tracing.StoreActivitySource.StartActivity("DefaultReferenceTokenStore.GetReferenceToken"); - return GetItemAsync(handle, default); + return GetItemAsync(handle, ct); } /// - public Task RemoveReferenceTokenAsync(string handle) + public Task RemoveReferenceTokenAsync(string handle, CT ct) { using var activity = Tracing.StoreActivitySource.StartActivity("DefaultReferenceTokenStore.RemoveReferenceToken"); - return RemoveItemAsync(handle, default); + return RemoveItemAsync(handle, ct); } /// - public Task RemoveReferenceTokensAsync(string subjectId, string clientId, string sessionId = null) + public Task RemoveReferenceTokensAsync(string subjectId, string clientId, string sessionId, CT ct) { using var activity = Tracing.StoreActivitySource.StartActivity("DefaultReferenceTokenStore.RemoveReferenceTokens"); - return RemoveAllAsync(subjectId, clientId, sessionId); + return RemoveAllAsync(subjectId, clientId, sessionId, ct); } } diff --git a/identity-server/src/IdentityServer/Validation/Default/TokenValidator.cs b/identity-server/src/IdentityServer/Validation/Default/TokenValidator.cs index aa3f791e2..647fa606c 100644 --- a/identity-server/src/IdentityServer/Validation/Default/TokenValidator.cs +++ b/identity-server/src/IdentityServer/Validation/Default/TokenValidator.cs @@ -371,7 +371,7 @@ internal class TokenValidator : ITokenValidator using var activity = Tracing.BasicActivitySource.StartActivity("TokenValidator.ValidateReferenceAccessToken"); _log.TokenHandle = tokenHandle; - var token = await _referenceTokenStore.GetReferenceTokenAsync(tokenHandle); + var token = await _referenceTokenStore.GetReferenceTokenAsync(tokenHandle, ct); if (token == null) { @@ -383,7 +383,7 @@ internal class TokenValidator : ITokenValidator { LogError("Token expired."); - await _referenceTokenStore.RemoveReferenceTokenAsync(tokenHandle); + await _referenceTokenStore.RemoveReferenceTokenAsync(tokenHandle, ct); return Invalid(OidcConstants.ProtectedResourceErrors.ExpiredToken); } diff --git a/identity-server/src/Storage/Stores/IReferenceTokenStore.cs b/identity-server/src/Storage/Stores/IReferenceTokenStore.cs index 32015c398..32fae9af1 100644 --- a/identity-server/src/Storage/Stores/IReferenceTokenStore.cs +++ b/identity-server/src/Storage/Stores/IReferenceTokenStore.cs @@ -17,22 +17,25 @@ public interface IReferenceTokenStore /// Stores the reference token. /// /// The token. + /// The used to propagate notifications that the operation should be canceled. /// - Task StoreReferenceTokenAsync(Token token); + Task StoreReferenceTokenAsync(Token token, CT ct); /// /// Gets the reference token. /// /// The handle. + /// The used to propagate notifications that the operation should be canceled. /// - Task GetReferenceTokenAsync(string handle); + Task GetReferenceTokenAsync(string handle, CT ct); /// /// Removes the reference token. /// /// The handle. + /// The used to propagate notifications that the operation should be canceled. /// - Task RemoveReferenceTokenAsync(string handle); + Task RemoveReferenceTokenAsync(string handle, CT ct); /// /// Removes the reference tokens. @@ -40,6 +43,7 @@ public interface IReferenceTokenStore /// The subject identifier. /// The client identifier. /// The session identifier. + /// The used to propagate notifications that the operation should be canceled. /// - Task RemoveReferenceTokensAsync(string subjectId, string clientId, string? sessionId = null); + Task RemoveReferenceTokensAsync(string subjectId, string clientId, string? sessionId, CT ct); } diff --git a/identity-server/test/IdentityServer.UnitTests/Common/MockReferenceTokenStore.cs b/identity-server/test/IdentityServer.UnitTests/Common/MockReferenceTokenStore.cs index 3579fa01d..504226547 100644 --- a/identity-server/test/IdentityServer.UnitTests/Common/MockReferenceTokenStore.cs +++ b/identity-server/test/IdentityServer.UnitTests/Common/MockReferenceTokenStore.cs @@ -9,11 +9,11 @@ namespace UnitTests.Common; internal class MockReferenceTokenStore : IReferenceTokenStore { - public Task GetReferenceTokenAsync(string handle) => throw new NotImplementedException(); + public Task GetReferenceTokenAsync(string handle, CT ct) => throw new NotImplementedException(); - public Task RemoveReferenceTokenAsync(string handle) => throw new NotImplementedException(); + public Task RemoveReferenceTokenAsync(string handle, CT ct) => throw new NotImplementedException(); - public Task RemoveReferenceTokensAsync(string subjectId, string clientId, string sessionId = null) => throw new NotImplementedException(); + public Task RemoveReferenceTokensAsync(string subjectId, string clientId, string sessionId, CT ct) => throw new NotImplementedException(); - public Task StoreReferenceTokenAsync(Token token) => throw new NotImplementedException(); + public Task StoreReferenceTokenAsync(Token token, CT ct) => throw new NotImplementedException(); } diff --git a/identity-server/test/IdentityServer.UnitTests/Services/Default/DefaultPersistedGrantServiceTests.cs b/identity-server/test/IdentityServer.UnitTests/Services/Default/DefaultPersistedGrantServiceTests.cs index 8762cc2a8..59bd89b97 100644 --- a/identity-server/test/IdentityServer.UnitTests/Services/Default/DefaultPersistedGrantServiceTests.cs +++ b/identity-server/test/IdentityServer.UnitTests/Services/Default/DefaultPersistedGrantServiceTests.cs @@ -85,7 +85,7 @@ public class DefaultPersistedGrantServiceTests new Claim("scope", "bar1"), new Claim("scope", "bar2") } - }); + }, _ct); var handle2 = await _referenceTokens.StoreReferenceTokenAsync(new Token() { @@ -98,7 +98,7 @@ public class DefaultPersistedGrantServiceTests new Claim("sub", "123"), new Claim("scope", "bar3") } - }); + }, _ct); var handle3 = await _referenceTokens.StoreReferenceTokenAsync(new Token() { @@ -111,7 +111,7 @@ public class DefaultPersistedGrantServiceTests new Claim("sub", "456"), new Claim("scope", "bar3") } - }); + }, _ct); var handle4 = await _refreshTokens.StoreRefreshTokenAsync(new RefreshToken() { @@ -223,7 +223,7 @@ public class DefaultPersistedGrantServiceTests new Claim("scope", "bar1"), new Claim("scope", "bar2") } - }); + }, _ct); var handle2 = await _referenceTokens.StoreReferenceTokenAsync(new Token() { @@ -237,7 +237,7 @@ public class DefaultPersistedGrantServiceTests new Claim("sub", "123"), new Claim("scope", "bar3") } - }); + }, _ct); var handle3 = await _referenceTokens.StoreReferenceTokenAsync(new Token() { @@ -251,7 +251,7 @@ public class DefaultPersistedGrantServiceTests new Claim("sub", "456"), new Claim("scope", "bar3") } - }); + }, _ct); var handle4 = await _refreshTokens.StoreRefreshTokenAsync(new RefreshToken() { @@ -316,9 +316,9 @@ public class DefaultPersistedGrantServiceTests await _subject.RemoveAllGrantsAsync("123", "client1"); - (await _referenceTokens.GetReferenceTokenAsync(handle1)).ShouldBeNull(); - (await _referenceTokens.GetReferenceTokenAsync(handle2)).ShouldNotBeNull(); - (await _referenceTokens.GetReferenceTokenAsync(handle3)).ShouldNotBeNull(); + (await _referenceTokens.GetReferenceTokenAsync(handle1, _ct)).ShouldBeNull(); + (await _referenceTokens.GetReferenceTokenAsync(handle2, _ct)).ShouldNotBeNull(); + (await _referenceTokens.GetReferenceTokenAsync(handle3, _ct)).ShouldNotBeNull(); (await _refreshTokens.GetRefreshTokenAsync(handle4)).ShouldBeNull(); (await _refreshTokens.GetRefreshTokenAsync(handle5)).ShouldNotBeNull(); (await _refreshTokens.GetRefreshTokenAsync(handle6)).ShouldNotBeNull(); diff --git a/identity-server/test/IdentityServer.UnitTests/Stores/Default/DefaultPersistedGrantStoreTests.cs b/identity-server/test/IdentityServer.UnitTests/Stores/Default/DefaultPersistedGrantStoreTests.cs index 943d4bd4a..657cc044d 100644 --- a/identity-server/test/IdentityServer.UnitTests/Stores/Default/DefaultPersistedGrantStoreTests.cs +++ b/identity-server/test/IdentityServer.UnitTests/Stores/Default/DefaultPersistedGrantStoreTests.cs @@ -233,8 +233,8 @@ public class DefaultPersistedGrantStoreTests Version = 1 }; - var handle = await _referenceTokens.StoreReferenceTokenAsync(token1); - var token2 = await _referenceTokens.GetReferenceTokenAsync(handle); + var handle = await _referenceTokens.StoreReferenceTokenAsync(token1, _ct); + var token2 = await _referenceTokens.GetReferenceTokenAsync(handle, _ct); token1.ClientId.ShouldBe(token2.ClientId); token1.Audiences.Count.ShouldBe(1); @@ -262,9 +262,9 @@ public class DefaultPersistedGrantStoreTests Version = 1 }; - var handle = await _referenceTokens.StoreReferenceTokenAsync(token1); - await _referenceTokens.RemoveReferenceTokenAsync(handle); - var token2 = await _referenceTokens.GetReferenceTokenAsync(handle); + var handle = await _referenceTokens.StoreReferenceTokenAsync(token1, _ct); + await _referenceTokens.RemoveReferenceTokenAsync(handle, _ct); + var token2 = await _referenceTokens.GetReferenceTokenAsync(handle, _ct); token2.ShouldBeNull(); } @@ -285,13 +285,13 @@ public class DefaultPersistedGrantStoreTests Version = 1 }; - var handle1 = await _referenceTokens.StoreReferenceTokenAsync(token1); - var handle2 = await _referenceTokens.StoreReferenceTokenAsync(token1); - await _referenceTokens.RemoveReferenceTokensAsync("123", "client"); + var handle1 = await _referenceTokens.StoreReferenceTokenAsync(token1, _ct); + var handle2 = await _referenceTokens.StoreReferenceTokenAsync(token1, _ct); + await _referenceTokens.RemoveReferenceTokensAsync("123", "client", null, _ct); - var token2 = await _referenceTokens.GetReferenceTokenAsync(handle1); + var token2 = await _referenceTokens.GetReferenceTokenAsync(handle1, _ct); token2.ShouldBeNull(); - token2 = await _referenceTokens.GetReferenceTokenAsync(handle2); + token2 = await _referenceTokens.GetReferenceTokenAsync(handle2, _ct); token2.ShouldBeNull(); } @@ -349,7 +349,7 @@ public class DefaultPersistedGrantStoreTests new Claim("scope", "bar1"), new Claim("scope", "bar2") } - }); + }, _ct); await _refreshTokens.StoreRefreshTokenAsync(new RefreshToken() { @@ -374,6 +374,6 @@ public class DefaultPersistedGrantStoreTests // the -1 is needed because internally we append a version/suffix the handle for encoding (await _codes.GetAuthorizationCodeAsync("key-1", _ct)).Lifetime.ShouldBe(30); (await _refreshTokens.GetRefreshTokenAsync("key-1")).Lifetime.ShouldBe(20); - (await _referenceTokens.GetReferenceTokenAsync("key-1")).Lifetime.ShouldBe(10); + (await _referenceTokens.GetReferenceTokenAsync("key-1", _ct)).Lifetime.ShouldBe(10); } } diff --git a/identity-server/test/IdentityServer.UnitTests/Validation/AccessTokenValidation.cs b/identity-server/test/IdentityServer.UnitTests/Validation/AccessTokenValidation.cs index ae051ff7c..201b7404d 100644 --- a/identity-server/test/IdentityServer.UnitTests/Validation/AccessTokenValidation.cs +++ b/identity-server/test/IdentityServer.UnitTests/Validation/AccessTokenValidation.cs @@ -49,7 +49,7 @@ public class AccessTokenValidation var token = TokenFactory.CreateAccessToken(new Client { ClientId = "roclient" }, "valid", 600, "read", "write"); - var handle = await store.StoreReferenceTokenAsync(token); + var handle = await store.StoreReferenceTokenAsync(token, _ct); var result = await validator.ValidateAccessTokenAsync(handle, null, _ct); @@ -73,7 +73,7 @@ public class AccessTokenValidation var token = TokenFactory.CreateAccessToken(new Client { ClientId = "roclient" }, "valid", 600, "read", "write"); - var handle = await store.StoreReferenceTokenAsync(token); + var handle = await store.StoreReferenceTokenAsync(token, _ct); var result = await validator.ValidateAccessTokenAsync(handle, "read", _ct); @@ -89,7 +89,7 @@ public class AccessTokenValidation var token = TokenFactory.CreateAccessToken(new Client { ClientId = "roclient" }, "valid", 600, "read", "write"); - var handle = await store.StoreReferenceTokenAsync(token); + var handle = await store.StoreReferenceTokenAsync(token, _ct); var result = await validator.ValidateAccessTokenAsync(handle, "missing", _ct); @@ -135,7 +135,7 @@ public class AccessTokenValidation var token = TokenFactory.CreateAccessToken(new Client { ClientId = "roclient" }, "valid", 2, "read", "write"); token.CreationTime = now; - var handle = await store.StoreReferenceTokenAsync(token); + var handle = await store.StoreReferenceTokenAsync(token, _ct); now = now.AddSeconds(3); _timeProvider.SetUtcNow(now); @@ -292,7 +292,7 @@ public class AccessTokenValidation var token = TokenFactory.CreateAccessToken(new Client { ClientId = "unknown" }, "valid", 600, "read", "write"); - var handle = await store.StoreReferenceTokenAsync(token); + var handle = await store.StoreReferenceTokenAsync(token, _ct); var result = await validator.ValidateAccessTokenAsync(handle, null, _ct); diff --git a/identity-server/test/IdentityServer.UnitTests/Validation/IntrospectionRequestValidatorTests.cs b/identity-server/test/IdentityServer.UnitTests/Validation/IntrospectionRequestValidatorTests.cs index ca6c30aef..24a8ecb6e 100644 --- a/identity-server/test/IdentityServer.UnitTests/Validation/IntrospectionRequestValidatorTests.cs +++ b/identity-server/test/IdentityServer.UnitTests/Validation/IntrospectionRequestValidatorTests.cs @@ -45,7 +45,7 @@ public class IntrospectionRequestValidatorTests new System.Security.Claims.Claim("scope", "b") } }; - var handle = await _referenceTokenStore.StoreReferenceTokenAsync(token); + var handle = await _referenceTokenStore.StoreReferenceTokenAsync(token, _ct); var param = new NameValueCollection() { @@ -135,7 +135,7 @@ public class IntrospectionRequestValidatorTests } }; - var handle = await _referenceTokenStore.StoreReferenceTokenAsync(token); + var handle = await _referenceTokenStore.StoreReferenceTokenAsync(token, _ct); var param = new NameValueCollection { { "token", handle }