Add CT parameter to IReferenceTokenStore, flow through all implementations and tests

This commit is contained in:
Damian Hickey 2026-02-20 16:53:51 +01:00
parent 6b12c7ba92
commit fb05c70893
10 changed files with 54 additions and 50 deletions

View file

@ -102,14 +102,14 @@ public class TokenRevocationResponseGenerator : ITokenRevocationResponseGenerato
/// </summary>
protected virtual async Task<bool> RevokeAccessTokenAsync(TokenRevocationRequestValidationResult validationResult)
{
var token = await ReferenceTokenStore.GetReferenceTokenAsync(validationResult.Token);
var token = await ReferenceTokenStore.GetReferenceTokenAsync(validationResult.Token, default);
if (token != null)
{
if (token.ClientId == validationResult.Client.ClientId)
{
Logger.LogDebug("Access token revoked");
await ReferenceTokenStore.RemoveReferenceTokenAsync(validationResult.Token);
await ReferenceTokenStore.RemoveReferenceTokenAsync(validationResult.Token, default);
}
else
{
@ -135,7 +135,7 @@ public class TokenRevocationResponseGenerator : ITokenRevocationResponseGenerato
{
Logger.LogDebug("Refresh token revoked");
await RefreshTokenStore.RemoveRefreshTokenAsync(validationResult.Token);
await ReferenceTokenStore.RemoveReferenceTokensAsync(token.SubjectId, token.ClientId, token.SessionId);
await ReferenceTokenStore.RemoveReferenceTokensAsync(token.SubjectId, token.ClientId, token.SessionId, default);
}
else
{

View file

@ -256,7 +256,7 @@ public class DefaultTokenService : ITokenService
{
Logger.LogTrace("Creating reference access token");
var handle = await ReferenceTokenStore.StoreReferenceTokenAsync(token);
var handle = await ReferenceTokenStore.StoreReferenceTokenAsync(token, default);
tokenResult = handle;
}

View file

@ -31,34 +31,34 @@ public class DefaultReferenceTokenStore : DefaultGrantStore<Token>, IReferenceTo
}
/// <inheritdoc/>
public Task<string> StoreReferenceTokenAsync(Token token)
public Task<string> StoreReferenceTokenAsync(Token token, CT ct)
{
using var activity = Tracing.StoreActivitySource.StartActivity("DefaultReferenceTokenStore.StoreReferenceToken");
return CreateItemAsync(token, token.ClientId, token.SubjectId, token.SessionId, token.Description, token.CreationTime, token.Lifetime, default);
return CreateItemAsync(token, token.ClientId, token.SubjectId, token.SessionId, token.Description, token.CreationTime, token.Lifetime, ct);
}
/// <inheritdoc/>
public Task<Token> GetReferenceTokenAsync(string handle)
public Task<Token> GetReferenceTokenAsync(string handle, CT ct)
{
using var activity = Tracing.StoreActivitySource.StartActivity("DefaultReferenceTokenStore.GetReferenceToken");
return GetItemAsync(handle, default);
return GetItemAsync(handle, ct);
}
/// <inheritdoc/>
public Task RemoveReferenceTokenAsync(string handle)
public Task RemoveReferenceTokenAsync(string handle, CT ct)
{
using var activity = Tracing.StoreActivitySource.StartActivity("DefaultReferenceTokenStore.RemoveReferenceToken");
return RemoveItemAsync(handle, default);
return RemoveItemAsync(handle, ct);
}
/// <inheritdoc/>
public Task RemoveReferenceTokensAsync(string subjectId, string clientId, string sessionId = null)
public Task RemoveReferenceTokensAsync(string subjectId, string clientId, string sessionId, CT ct)
{
using var activity = Tracing.StoreActivitySource.StartActivity("DefaultReferenceTokenStore.RemoveReferenceTokens");
return RemoveAllAsync(subjectId, clientId, sessionId);
return RemoveAllAsync(subjectId, clientId, sessionId, ct);
}
}

View file

@ -371,7 +371,7 @@ internal class TokenValidator : ITokenValidator
using var activity = Tracing.BasicActivitySource.StartActivity("TokenValidator.ValidateReferenceAccessToken");
_log.TokenHandle = tokenHandle;
var token = await _referenceTokenStore.GetReferenceTokenAsync(tokenHandle);
var token = await _referenceTokenStore.GetReferenceTokenAsync(tokenHandle, ct);
if (token == null)
{
@ -383,7 +383,7 @@ internal class TokenValidator : ITokenValidator
{
LogError("Token expired.");
await _referenceTokenStore.RemoveReferenceTokenAsync(tokenHandle);
await _referenceTokenStore.RemoveReferenceTokenAsync(tokenHandle, ct);
return Invalid(OidcConstants.ProtectedResourceErrors.ExpiredToken);
}

View file

@ -17,22 +17,25 @@ public interface IReferenceTokenStore
/// Stores the reference token.
/// </summary>
/// <param name="token">The token.</param>
/// <param name="ct">The <see cref="CT"/> used to propagate notifications that the operation should be canceled.</param>
/// <returns></returns>
Task<string> StoreReferenceTokenAsync(Token token);
Task<string> StoreReferenceTokenAsync(Token token, CT ct);
/// <summary>
/// Gets the reference token.
/// </summary>
/// <param name="handle">The handle.</param>
/// <param name="ct">The <see cref="CT"/> used to propagate notifications that the operation should be canceled.</param>
/// <returns></returns>
Task<Token?> GetReferenceTokenAsync(string handle);
Task<Token?> GetReferenceTokenAsync(string handle, CT ct);
/// <summary>
/// Removes the reference token.
/// </summary>
/// <param name="handle">The handle.</param>
/// <param name="ct">The <see cref="CT"/> used to propagate notifications that the operation should be canceled.</param>
/// <returns></returns>
Task RemoveReferenceTokenAsync(string handle);
Task RemoveReferenceTokenAsync(string handle, CT ct);
/// <summary>
/// Removes the reference tokens.
@ -40,6 +43,7 @@ public interface IReferenceTokenStore
/// <param name="subjectId">The subject identifier.</param>
/// <param name="clientId">The client identifier.</param>
/// <param name="sessionId">The session identifier.</param>
/// <param name="ct">The <see cref="CT"/> used to propagate notifications that the operation should be canceled.</param>
/// <returns></returns>
Task RemoveReferenceTokensAsync(string subjectId, string clientId, string? sessionId = null);
Task RemoveReferenceTokensAsync(string subjectId, string clientId, string? sessionId, CT ct);
}

View file

@ -9,11 +9,11 @@ namespace UnitTests.Common;
internal class MockReferenceTokenStore : IReferenceTokenStore
{
public Task<Token> GetReferenceTokenAsync(string handle) => throw new NotImplementedException();
public Task<Token> GetReferenceTokenAsync(string handle, CT ct) => throw new NotImplementedException();
public Task RemoveReferenceTokenAsync(string handle) => throw new NotImplementedException();
public Task RemoveReferenceTokenAsync(string handle, CT ct) => throw new NotImplementedException();
public Task RemoveReferenceTokensAsync(string subjectId, string clientId, string sessionId = null) => throw new NotImplementedException();
public Task RemoveReferenceTokensAsync(string subjectId, string clientId, string sessionId, CT ct) => throw new NotImplementedException();
public Task<string> StoreReferenceTokenAsync(Token token) => throw new NotImplementedException();
public Task<string> StoreReferenceTokenAsync(Token token, CT ct) => throw new NotImplementedException();
}

View file

@ -85,7 +85,7 @@ public class DefaultPersistedGrantServiceTests
new Claim("scope", "bar1"),
new Claim("scope", "bar2")
}
});
}, _ct);
var handle2 = await _referenceTokens.StoreReferenceTokenAsync(new Token()
{
@ -98,7 +98,7 @@ public class DefaultPersistedGrantServiceTests
new Claim("sub", "123"),
new Claim("scope", "bar3")
}
});
}, _ct);
var handle3 = await _referenceTokens.StoreReferenceTokenAsync(new Token()
{
@ -111,7 +111,7 @@ public class DefaultPersistedGrantServiceTests
new Claim("sub", "456"),
new Claim("scope", "bar3")
}
});
}, _ct);
var handle4 = await _refreshTokens.StoreRefreshTokenAsync(new RefreshToken()
{
@ -223,7 +223,7 @@ public class DefaultPersistedGrantServiceTests
new Claim("scope", "bar1"),
new Claim("scope", "bar2")
}
});
}, _ct);
var handle2 = await _referenceTokens.StoreReferenceTokenAsync(new Token()
{
@ -237,7 +237,7 @@ public class DefaultPersistedGrantServiceTests
new Claim("sub", "123"),
new Claim("scope", "bar3")
}
});
}, _ct);
var handle3 = await _referenceTokens.StoreReferenceTokenAsync(new Token()
{
@ -251,7 +251,7 @@ public class DefaultPersistedGrantServiceTests
new Claim("sub", "456"),
new Claim("scope", "bar3")
}
});
}, _ct);
var handle4 = await _refreshTokens.StoreRefreshTokenAsync(new RefreshToken()
{
@ -316,9 +316,9 @@ public class DefaultPersistedGrantServiceTests
await _subject.RemoveAllGrantsAsync("123", "client1");
(await _referenceTokens.GetReferenceTokenAsync(handle1)).ShouldBeNull();
(await _referenceTokens.GetReferenceTokenAsync(handle2)).ShouldNotBeNull();
(await _referenceTokens.GetReferenceTokenAsync(handle3)).ShouldNotBeNull();
(await _referenceTokens.GetReferenceTokenAsync(handle1, _ct)).ShouldBeNull();
(await _referenceTokens.GetReferenceTokenAsync(handle2, _ct)).ShouldNotBeNull();
(await _referenceTokens.GetReferenceTokenAsync(handle3, _ct)).ShouldNotBeNull();
(await _refreshTokens.GetRefreshTokenAsync(handle4)).ShouldBeNull();
(await _refreshTokens.GetRefreshTokenAsync(handle5)).ShouldNotBeNull();
(await _refreshTokens.GetRefreshTokenAsync(handle6)).ShouldNotBeNull();

View file

@ -233,8 +233,8 @@ public class DefaultPersistedGrantStoreTests
Version = 1
};
var handle = await _referenceTokens.StoreReferenceTokenAsync(token1);
var token2 = await _referenceTokens.GetReferenceTokenAsync(handle);
var handle = await _referenceTokens.StoreReferenceTokenAsync(token1, _ct);
var token2 = await _referenceTokens.GetReferenceTokenAsync(handle, _ct);
token1.ClientId.ShouldBe(token2.ClientId);
token1.Audiences.Count.ShouldBe(1);
@ -262,9 +262,9 @@ public class DefaultPersistedGrantStoreTests
Version = 1
};
var handle = await _referenceTokens.StoreReferenceTokenAsync(token1);
await _referenceTokens.RemoveReferenceTokenAsync(handle);
var token2 = await _referenceTokens.GetReferenceTokenAsync(handle);
var handle = await _referenceTokens.StoreReferenceTokenAsync(token1, _ct);
await _referenceTokens.RemoveReferenceTokenAsync(handle, _ct);
var token2 = await _referenceTokens.GetReferenceTokenAsync(handle, _ct);
token2.ShouldBeNull();
}
@ -285,13 +285,13 @@ public class DefaultPersistedGrantStoreTests
Version = 1
};
var handle1 = await _referenceTokens.StoreReferenceTokenAsync(token1);
var handle2 = await _referenceTokens.StoreReferenceTokenAsync(token1);
await _referenceTokens.RemoveReferenceTokensAsync("123", "client");
var handle1 = await _referenceTokens.StoreReferenceTokenAsync(token1, _ct);
var handle2 = await _referenceTokens.StoreReferenceTokenAsync(token1, _ct);
await _referenceTokens.RemoveReferenceTokensAsync("123", "client", null, _ct);
var token2 = await _referenceTokens.GetReferenceTokenAsync(handle1);
var token2 = await _referenceTokens.GetReferenceTokenAsync(handle1, _ct);
token2.ShouldBeNull();
token2 = await _referenceTokens.GetReferenceTokenAsync(handle2);
token2 = await _referenceTokens.GetReferenceTokenAsync(handle2, _ct);
token2.ShouldBeNull();
}
@ -349,7 +349,7 @@ public class DefaultPersistedGrantStoreTests
new Claim("scope", "bar1"),
new Claim("scope", "bar2")
}
});
}, _ct);
await _refreshTokens.StoreRefreshTokenAsync(new RefreshToken()
{
@ -374,6 +374,6 @@ public class DefaultPersistedGrantStoreTests
// the -1 is needed because internally we append a version/suffix the handle for encoding
(await _codes.GetAuthorizationCodeAsync("key-1", _ct)).Lifetime.ShouldBe(30);
(await _refreshTokens.GetRefreshTokenAsync("key-1")).Lifetime.ShouldBe(20);
(await _referenceTokens.GetReferenceTokenAsync("key-1")).Lifetime.ShouldBe(10);
(await _referenceTokens.GetReferenceTokenAsync("key-1", _ct)).Lifetime.ShouldBe(10);
}
}

View file

@ -49,7 +49,7 @@ public class AccessTokenValidation
var token = TokenFactory.CreateAccessToken(new Client { ClientId = "roclient" }, "valid", 600, "read", "write");
var handle = await store.StoreReferenceTokenAsync(token);
var handle = await store.StoreReferenceTokenAsync(token, _ct);
var result = await validator.ValidateAccessTokenAsync(handle, null, _ct);
@ -73,7 +73,7 @@ public class AccessTokenValidation
var token = TokenFactory.CreateAccessToken(new Client { ClientId = "roclient" }, "valid", 600, "read", "write");
var handle = await store.StoreReferenceTokenAsync(token);
var handle = await store.StoreReferenceTokenAsync(token, _ct);
var result = await validator.ValidateAccessTokenAsync(handle, "read", _ct);
@ -89,7 +89,7 @@ public class AccessTokenValidation
var token = TokenFactory.CreateAccessToken(new Client { ClientId = "roclient" }, "valid", 600, "read", "write");
var handle = await store.StoreReferenceTokenAsync(token);
var handle = await store.StoreReferenceTokenAsync(token, _ct);
var result = await validator.ValidateAccessTokenAsync(handle, "missing", _ct);
@ -135,7 +135,7 @@ public class AccessTokenValidation
var token = TokenFactory.CreateAccessToken(new Client { ClientId = "roclient" }, "valid", 2, "read", "write");
token.CreationTime = now;
var handle = await store.StoreReferenceTokenAsync(token);
var handle = await store.StoreReferenceTokenAsync(token, _ct);
now = now.AddSeconds(3);
_timeProvider.SetUtcNow(now);
@ -292,7 +292,7 @@ public class AccessTokenValidation
var token = TokenFactory.CreateAccessToken(new Client { ClientId = "unknown" }, "valid", 600, "read", "write");
var handle = await store.StoreReferenceTokenAsync(token);
var handle = await store.StoreReferenceTokenAsync(token, _ct);
var result = await validator.ValidateAccessTokenAsync(handle, null, _ct);

View file

@ -45,7 +45,7 @@ public class IntrospectionRequestValidatorTests
new System.Security.Claims.Claim("scope", "b")
}
};
var handle = await _referenceTokenStore.StoreReferenceTokenAsync(token);
var handle = await _referenceTokenStore.StoreReferenceTokenAsync(token, _ct);
var param = new NameValueCollection()
{
@ -135,7 +135,7 @@ public class IntrospectionRequestValidatorTests
}
};
var handle = await _referenceTokenStore.StoreReferenceTokenAsync(token);
var handle = await _referenceTokenStore.StoreReferenceTokenAsync(token, _ct);
var param = new NameValueCollection
{
{ "token", handle }