Make CT required in IMessageStore, IConsentMessageStore, IAuthorizationParametersMessageStore, flow through implementations and tests

This commit is contained in:
Damian Hickey 2026-02-20 20:35:29 +01:00
parent 907e07f618
commit 10ba98cca0
19 changed files with 54 additions and 44 deletions

View file

@ -77,10 +77,10 @@ internal abstract class AuthorizeEndpointBase : IEndpointHandler
if (checkConsentResponse && _authorizationParametersMessageStore != null)
{
var messageStoreId = parameters[Constants.AuthorizationParamsStore.MessageStoreIdParameterName];
var entry = await _authorizationParametersMessageStore.ReadAsync(messageStoreId);
var entry = await _authorizationParametersMessageStore.ReadAsync(messageStoreId, ct);
parameters = entry?.Data.FromFullDictionary() ?? new NameValueCollection();
await _authorizationParametersMessageStore.DeleteAsync(messageStoreId);
await _authorizationParametersMessageStore.DeleteAsync(messageStoreId, ct);
}
// validate request
@ -105,7 +105,7 @@ internal abstract class AuthorizeEndpointBase : IEndpointHandler
{
var consentRequest = new ConsentRequest(result.ValidatedRequest.Raw, user?.GetSubjectId());
consentRequestId = consentRequest.Id;
consent = await _consentResponseStore.ReadAsync(consentRequestId);
consent = await _consentResponseStore.ReadAsync(consentRequestId, ct);
if (consent != null && consent.Data == null)
{
@ -155,7 +155,7 @@ internal abstract class AuthorizeEndpointBase : IEndpointHandler
{
if (consentRequestId != null)
{
await _consentResponseStore.DeleteAsync(consentRequestId);
await _consentResponseStore.DeleteAsync(consentRequestId, ct);
}
}
}

View file

@ -84,7 +84,7 @@ internal class AuthorizeInteractionPageHttpWriter : IHttpResponseWriter<Authoriz
#pragma warning disable CS0618 // Type or member is obsolete
var msg = new Message<IDictionary<string, string[]>>(result.Request.ToOptimizedFullDictionary());
#pragma warning restore CS0618 // Type or member is obsolete
var id = await _authorizationParametersMessageStore.WriteAsync(msg);
var id = await _authorizationParametersMessageStore.WriteAsync(msg, context.RequestAborted);
returnUrl = returnUrl.AddQueryString(Constants.AuthorizationParamsStore.MessageStoreIdParameterName, id);
}
else

View file

@ -227,7 +227,7 @@ public class AuthorizeHttpWriter : IHttpResponseWriter<AuthorizeResult>
var errorModel = await CreateErrorMessage(response, context);
var message = new Message<ErrorMessage>(errorModel, _timeProvider.GetUtcNow().UtcDateTime);
var id = await _errorMessageStore.WriteAsync(message);
var id = await _errorMessageStore.WriteAsync(message, context.RequestAborted);
var errorUrl = _options.UserInteraction.ErrorUrl;

View file

@ -66,7 +66,7 @@ internal class EndSessionHttpWriter : IHttpResponseWriter<EndSessionResult>
if (logoutMessage.ContainsPayload)
{
var msg = new Message<LogoutMessage>(logoutMessage, _timeProvider.GetUtcNow().UtcDateTime);
id = await _logoutMessageStore.WriteAsync(msg);
id = await _logoutMessageStore.WriteAsync(msg, context.RequestAborted);
}
}

View file

@ -98,7 +98,7 @@ public static class HttpContextExtensions
var msg = new Message<LogoutNotificationContext>(endSessionMsg, timeProvider.GetUtcNow().UtcDateTime);
var endSessionMessageStore = context.RequestServices.GetRequiredService<IMessageStore<LogoutNotificationContext>>();
var id = await endSessionMessageStore.WriteAsync(msg);
var id = await endSessionMessageStore.WriteAsync(msg, context.RequestAborted);
var urls = context.RequestServices.GetRequiredService<IServerUrls>();
var signoutIframeUrl = urls.BaseUrl.EnsureTrailingSlash() + ProtocolRoutePaths.EndSessionCallback;

View file

@ -66,7 +66,7 @@ internal class DefaultIdentityServerInteractionService : IIdentityServerInteract
{
using var activity = Tracing.ServiceActivitySource.StartActivity("DefaultIdentityServerInteractionService.GetLogoutContext");
var msg = await _logoutMessageStore.ReadAsync(logoutId);
var msg = await _logoutMessageStore.ReadAsync(logoutId, default);
var iframeUrl = await _context.HttpContext.GetIdentityServerSignoutFrameCallbackUrlAsync(msg?.Data);
return new LogoutRequest(iframeUrl, msg?.Data);
}
@ -88,7 +88,7 @@ internal class DefaultIdentityServerInteractionService : IIdentityServerInteract
SessionId = sid,
ClientIds = clientIds
}, _timeProvider.GetUtcNow().UtcDateTime);
var id = await _logoutMessageStore.WriteAsync(msg);
var id = await _logoutMessageStore.WriteAsync(msg, default);
return id;
}
}
@ -102,7 +102,7 @@ internal class DefaultIdentityServerInteractionService : IIdentityServerInteract
if (errorId != null)
{
var result = await _errorMessageStore.ReadAsync(errorId);
var result = await _errorMessageStore.ReadAsync(errorId, default);
var data = result?.Data;
if (data != null)
{
@ -136,7 +136,7 @@ internal class DefaultIdentityServerInteractionService : IIdentityServerInteract
}
var consentRequest = new ConsentRequest(request, subject);
await _consentMessageStore.WriteAsync(consentRequest.Id, new Message<ConsentResponse>(consent, _timeProvider.GetUtcNow().UtcDateTime));
await _consentMessageStore.WriteAsync(consentRequest.Id, new Message<ConsentResponse>(consent, _timeProvider.GetUtcNow().UtcDateTime), default);
}
public Task DenyAuthorizationAsync(AuthorizationRequest request, AuthorizationError error, string errorDescription = null)

View file

@ -48,7 +48,7 @@ internal class OidcReturnUrlParser : IReturnUrlParser
if (_authorizationParametersMessageStore != null)
{
var messageStoreId = parameters[Constants.AuthorizationParamsStore.MessageStoreIdParameterName];
var entry = await _authorizationParametersMessageStore.ReadAsync(messageStoreId);
var entry = await _authorizationParametersMessageStore.ReadAsync(messageStoreId, default);
parameters = entry?.Data.FromFullDictionary() ?? new NameValueCollection();
}

View file

@ -12,7 +12,7 @@ internal class ConsentMessageStore : IConsentMessageStore
public ConsentMessageStore(MessageCookie<ConsentResponse> cookie) => Cookie = cookie;
public virtual Task DeleteAsync(string id)
public virtual Task DeleteAsync(string id, CT ct)
{
using var activity = Tracing.StoreActivitySource.StartActivity("ConsentMessageStore.Delete");
@ -20,14 +20,14 @@ internal class ConsentMessageStore : IConsentMessageStore
return Task.CompletedTask;
}
public virtual Task<Message<ConsentResponse>> ReadAsync(string id)
public virtual Task<Message<ConsentResponse>> ReadAsync(string id, CT ct)
{
using var activity = Tracing.StoreActivitySource.StartActivity("ConsentMessageStore.Read");
return Task.FromResult(Cookie.Read(id));
}
public virtual Task WriteAsync(string id, Message<ConsentResponse> message)
public virtual Task WriteAsync(string id, Message<ConsentResponse> message, CT ct)
{
using var activity = Tracing.StoreActivitySource.StartActivity("ConsentMessageStore.Write");

View file

@ -31,7 +31,7 @@ public class DistributedCacheAuthorizationParametersMessageStore : IAuthorizatio
private static string CacheKeyPrefix => "DistributedCacheAuthorizationParametersMessageStore";
/// <inheritdoc/>
public virtual async Task<string> WriteAsync(Message<IDictionary<string, string[]>> message)
public virtual async Task<string> WriteAsync(Message<IDictionary<string, string[]>> message, CT ct)
{
using var activity = Tracing.StoreActivitySource.StartActivity("DistributedCacheAuthorizationParametersMessageStore.Write");
@ -49,18 +49,18 @@ public class DistributedCacheAuthorizationParametersMessageStore : IAuthorizatio
var options = new DistributedCacheEntryOptions();
options.SetSlidingExpiration(Constants.DefaultCacheDuration);
await _distributedCache.SetStringAsync(cacheKey, json, options);
await _distributedCache.SetStringAsync(cacheKey, json, options, ct);
return key;
}
/// <inheritdoc/>
public virtual async Task<Message<IDictionary<string, string[]>>> ReadAsync(string id)
public virtual async Task<Message<IDictionary<string, string[]>>> ReadAsync(string id, CT ct)
{
using var activity = Tracing.StoreActivitySource.StartActivity("DistributedCacheAuthorizationParametersMessageStore.Read");
var cacheKey = $"{CacheKeyPrefix}-{id}";
var json = await _distributedCache.GetStringAsync(cacheKey);
var json = await _distributedCache.GetStringAsync(cacheKey, ct);
if (json == null)
{
@ -71,11 +71,11 @@ public class DistributedCacheAuthorizationParametersMessageStore : IAuthorizatio
}
/// <inheritdoc/>
public virtual Task DeleteAsync(string id)
public virtual Task DeleteAsync(string id, CT ct)
{
using var activity = Tracing.StoreActivitySource.StartActivity("DistributedCacheAuthorizationParametersMessageStore.Delete");
var cacheKey = $"{CacheKeyPrefix}-{id}";
return _distributedCache.RemoveAsync(cacheKey);
return _distributedCache.RemoveAsync(cacheKey, ct);
}
}

View file

@ -40,7 +40,7 @@ public class ProtectedDataMessageStore<TModel> : IMessageStore<TModel>
}
/// <inheritdoc />
public virtual Task<Message<TModel>> ReadAsync(string value)
public virtual Task<Message<TModel>> ReadAsync(string value, CT ct)
{
using var activity = Tracing.StoreActivitySource.StartActivity("ProtectedDataMessageStore.Read");
@ -65,7 +65,7 @@ public class ProtectedDataMessageStore<TModel> : IMessageStore<TModel>
}
/// <inheritdoc />
public virtual Task<string> WriteAsync(Message<TModel> message)
public virtual Task<string> WriteAsync(Message<TModel> message, CT ct)
{
using var activity = Tracing.StoreActivitySource.StartActivity("ProtectedDataMessageStore.Write");

View file

@ -10,18 +10,18 @@ namespace Duende.IdentityServer.Stores;
// internal just for testing
internal class QueryStringAuthorizationParametersMessageStore : IAuthorizationParametersMessageStore
{
public Task<string> WriteAsync(Message<IDictionary<string, string[]>> message)
public Task<string> WriteAsync(Message<IDictionary<string, string[]>> message, CT ct)
{
var queryString = message.Data.FromFullDictionary().ToQueryString();
return Task.FromResult(queryString);
}
public Task<Message<IDictionary<string, string[]>>> ReadAsync(string id)
public Task<Message<IDictionary<string, string[]>>> ReadAsync(string id, CT ct)
{
var values = id.ReadQueryStringAsNameValueCollection();
var msg = new Message<IDictionary<string, string[]>>(values.ToFullDictionary());
return Task.FromResult(msg);
}
public Task DeleteAsync(string id) => Task.CompletedTask;
public Task DeleteAsync(string id, CT ct) => Task.CompletedTask;
}

View file

@ -15,20 +15,23 @@ public interface IAuthorizationParametersMessageStore
/// Writes the authorization parameters.
/// </summary>
/// <param name="message">The message.</param>
/// <param name="ct">The cancellation token.</param>
/// <returns>The identifier for the stored message.</returns>
Task<string> WriteAsync(Message<IDictionary<string, string[]>> message);
Task<string> WriteAsync(Message<IDictionary<string, string[]>> message, CT ct);
/// <summary>
/// Reads the authorization parameters.
/// </summary>
/// <param name="id">The identifier.</param>
/// <param name="ct">The cancellation token.</param>
/// <returns></returns>
Task<Message<IDictionary<string, string[]>>> ReadAsync(string id);
Task<Message<IDictionary<string, string[]>>> ReadAsync(string id, CT ct);
/// <summary>
/// Deletes the authorization parameters.
/// </summary>
/// <param name="id">The identifier.</param>
/// <param name="ct">The cancellation token.</param>
/// <returns></returns>
Task DeleteAsync(string id);
Task DeleteAsync(string id, CT ct);
}

View file

@ -16,19 +16,22 @@ public interface IConsentMessageStore
/// </summary>
/// <param name="id">The id for the message.</param>
/// <param name="message">The message.</param>
Task WriteAsync(string id, Message<ConsentResponse> message);
/// <param name="ct">The cancellation token.</param>
Task WriteAsync(string id, Message<ConsentResponse> message, CT ct);
/// <summary>
/// Reads the consent response message.
/// </summary>
/// <param name="id">The identifier.</param>
/// <param name="ct">The cancellation token.</param>
/// <returns></returns>
Task<Message<ConsentResponse>> ReadAsync(string id);
Task<Message<ConsentResponse>> ReadAsync(string id, CT ct);
/// <summary>
/// Deletes the consent response message.
/// </summary>
/// <param name="id">The identifier.</param>
/// <param name="ct">The cancellation token.</param>
/// <returns></returns>
Task DeleteAsync(string id);
Task DeleteAsync(string id, CT ct);
}

View file

@ -16,13 +16,15 @@ public interface IMessageStore<TModel>
/// Writes the message.
/// </summary>
/// <param name="message">The message.</param>
/// <param name="ct">The cancellation token.</param>
/// <returns>An identifier for the message</returns>
Task<string> WriteAsync(Message<TModel> message);
Task<string> WriteAsync(Message<TModel> message, CT ct);
/// <summary>
/// Reads the message.
/// </summary>
/// <param name="id">The identifier.</param>
/// <param name="ct">The cancellation token.</param>
/// <returns></returns>
Task<Message<TModel>> ReadAsync(string id);
Task<Message<TModel>> ReadAsync(string id, CT ct);
}

View file

@ -230,7 +230,7 @@ public class EndSessionRequestValidator : IEndSessionRequestValidator
};
var endSessionId = parameters[Constants.UIConstants.DefaultRoutePathParams.EndSessionCallback];
var endSessionMessage = await EndSessionMessageStore.ReadAsync(endSessionId);
var endSessionMessage = await EndSessionMessageStore.ReadAsync(endSessionId, ct);
if (endSessionMessage?.Data?.ClientIds?.Any() == true)
{
result.IsError = false;

View file

@ -11,7 +11,7 @@ public class MockConsentMessageStore : IConsentMessageStore
{
public Dictionary<string, Message<ConsentResponse>> Messages { get; set; } = new Dictionary<string, Message<ConsentResponse>>();
public Task DeleteAsync(string id)
public Task DeleteAsync(string id, CT ct)
{
if (id != null && Messages.ContainsKey(id))
{
@ -20,7 +20,7 @@ public class MockConsentMessageStore : IConsentMessageStore
return Task.CompletedTask;
}
public Task<Message<ConsentResponse>> ReadAsync(string id)
public Task<Message<ConsentResponse>> ReadAsync(string id, CT ct)
{
Message<ConsentResponse> val = null;
if (id != null)
@ -30,7 +30,7 @@ public class MockConsentMessageStore : IConsentMessageStore
return Task.FromResult(val);
}
public Task WriteAsync(string id, Message<ConsentResponse> message)
public Task WriteAsync(string id, Message<ConsentResponse> message, CT ct)
{
Messages[id] = message;
return Task.CompletedTask;

View file

@ -11,7 +11,7 @@ public class MockMessageStore<TModel> : IMessageStore<TModel>
{
public Dictionary<string, Message<TModel>> Messages { get; set; } = new Dictionary<string, Message<TModel>>();
public Task<Message<TModel>> ReadAsync(string id)
public Task<Message<TModel>> ReadAsync(string id, CT ct)
{
Message<TModel> val = null;
if (id != null)
@ -21,7 +21,7 @@ public class MockMessageStore<TModel> : IMessageStore<TModel>
return Task.FromResult(val);
}
public Task<string> WriteAsync(Message<TModel> message)
public Task<string> WriteAsync(Message<TModel> message, CT ct)
{
var id = Guid.NewGuid().ToString();
Messages[id] = message;

View file

@ -23,6 +23,7 @@ namespace UnitTests.Endpoints.Results;
public class AuthorizeResultTests
{
private AuthorizeHttpWriter _subject;
private readonly CT _ct = TestContext.Current.CancellationToken;
private AuthorizeResponse _response = new AuthorizeResponse();
private IdentityServerOptions _options = new IdentityServerOptions();
@ -351,7 +352,7 @@ public class AuthorizeResultTests
var queryString = new Uri(location).Query;
var queryParams = QueryHelpers.ParseQuery(queryString);
var errorId = queryParams.First(kvp => kvp.Key == _options.UserInteraction.ErrorIdParameter).Value;
var errorMessage = await _mockErrorMessageStore.ReadAsync(errorId);
var errorMessage = await _mockErrorMessageStore.ReadAsync(errorId, _ct);
errorMessage.Data.RedirectUri.ShouldBeNull();
errorMessage.Data.ResponseMode.ShouldBeNull();
}

View file

@ -13,6 +13,7 @@ public class DistributedCacheAuthorizationParametersMessageStoreTests
{
private MockDistributedCache _mockCache = new MockDistributedCache();
private DistributedCacheAuthorizationParametersMessageStore _subject;
private readonly CT _ct = TestContext.Current.CancellationToken;
public DistributedCacheAuthorizationParametersMessageStoreTests() => _subject = new DistributedCacheAuthorizationParametersMessageStore(_mockCache, new DefaultHandleGenerationService());
[Fact]
@ -21,11 +22,11 @@ public class DistributedCacheAuthorizationParametersMessageStoreTests
_mockCache.Items.Count.ShouldBe(0);
var msg = new Message<IDictionary<string, string[]>>(new Dictionary<string, string[]>());
var id = await _subject.WriteAsync(msg);
var id = await _subject.WriteAsync(msg, _ct);
_mockCache.Items.Count.ShouldBe(1);
await _subject.DeleteAsync(id);
await _subject.DeleteAsync(id, _ct);
_mockCache.Items.Count.ShouldBe(0);
}