Make CT required in IUserInfoResponseGenerator, ISessionCoordinationService.ValidateSessionAsync, and IReturnUrlParser, flow through implementations, callers, and tests

This commit is contained in:
Damian Hickey 2026-02-20 22:45:10 +01:00
parent a5e42b70ab
commit f7d6f09c4e
15 changed files with 44 additions and 31 deletions

View file

@ -89,7 +89,7 @@ internal class UserInfoEndpoint : IEndpointHandler
// generate response
_logger.LogTrace("Calling into userinfo response generator: {type}", _responseGenerator.GetType().FullName);
var response = await _responseGenerator.ProcessAsync(validationResult);
var response = await _responseGenerator.ProcessAsync(validationResult, context.RequestAborted);
_logger.LogDebug("End userinfo request");
return new UserInfoResult(response);

View file

@ -51,9 +51,10 @@ public class UserInfoResponseGenerator : IUserInfoResponseGenerator
/// Creates the response.
/// </summary>
/// <param name="validationResult">The userinfo request validation result.</param>
/// <param name="ct">The cancellation token.</param>
/// <returns></returns>
/// <exception cref="System.InvalidOperationException">Profile service returned incorrect subject value</exception>
public virtual async Task<Dictionary<string, object>> ProcessAsync(UserInfoRequestValidationResult validationResult)
public virtual async Task<Dictionary<string, object>> ProcessAsync(UserInfoRequestValidationResult validationResult, CT ct)
{
using var activity = Tracing.BasicActivitySource.StartActivity("UserInfoResponseGenerator.Process");
@ -62,7 +63,7 @@ public class UserInfoResponseGenerator : IUserInfoResponseGenerator
// extract scopes and turn into requested claim types
var scopes = validationResult.TokenValidationResult.Claims.Where(c => c.Type == JwtClaimTypes.Scope).Select(c => c.Value);
var validatedResources = await GetRequestedResourcesAsync(scopes);
var validatedResources = await GetRequestedResourcesAsync(scopes, ct);
var requestedClaimTypes = await GetRequestedClaimTypesAsync(validatedResources);
Logger.LogDebug("Requested claim types: {claimTypes}", requestedClaimTypes.ToSpaceSeparatedString());
@ -75,7 +76,7 @@ public class UserInfoResponseGenerator : IUserInfoResponseGenerator
requestedClaimTypes);
context.RequestedResources = validatedResources;
await Profile.GetProfileDataAsync(context, default);
await Profile.GetProfileDataAsync(context, ct);
var profileClaims = context.IssuedClaims;
// construct outgoing claims
@ -109,8 +110,9 @@ public class UserInfoResponseGenerator : IUserInfoResponseGenerator
/// Gets the identity resources from the scopes.
/// </summary>
/// <param name="scopes"></param>
/// <param name="ct">The cancellation token.</param>
/// <returns></returns>
protected internal virtual async Task<ResourceValidationResult> GetRequestedResourcesAsync(IEnumerable<string> scopes)
protected internal virtual async Task<ResourceValidationResult> GetRequestedResourcesAsync(IEnumerable<string> scopes, CT ct)
{
if (scopes == null || !scopes.Any())
{
@ -121,7 +123,7 @@ public class UserInfoResponseGenerator : IUserInfoResponseGenerator
Logger.LogDebug("Scopes in access token: {scopes}", scopeString);
// if we ever parameterized identity scopes, then we would need to invoke the resource validator's parse API here
var identityResources = await Resources.FindEnabledIdentityResourcesByScopeAsync(scopes, default);
var identityResources = await Resources.FindEnabledIdentityResourcesByScopeAsync(scopes, ct);
var resources = new Resources(identityResources, Enumerable.Empty<ApiResource>(), Enumerable.Empty<ApiScope>());
var result = new ResourceValidationResult(resources);

View file

@ -15,6 +15,7 @@ public interface IUserInfoResponseGenerator
/// Creates the response.
/// </summary>
/// <param name="validationResult">The userinfo request validation result.</param>
/// <param name="ct">The cancellation token.</param>
/// <returns></returns>
Task<Dictionary<string, object>> ProcessAsync(UserInfoRequestValidationResult validationResult);
Task<Dictionary<string, object>> ProcessAsync(UserInfoRequestValidationResult validationResult, CT ct);
}

View file

@ -48,7 +48,7 @@ internal class DefaultIdentityServerInteractionService : IIdentityServerInteract
{
using var activity = Tracing.ServiceActivitySource.StartActivity("DefaultIdentityServerInteractionService.GetAuthorizationContext");
var result = await _returnUrlParser.ParseAsync(returnUrl);
var result = await _returnUrlParser.ParseAsync(returnUrl, default);
if (result != null)
{

View file

@ -194,7 +194,7 @@ public class DefaultSessionCoordinationService : ISessionCoordinationService
/// <inheritdoc/>
public virtual async Task<bool> ValidateSessionAsync(SessionValidationRequest request)
public virtual async Task<bool> ValidateSessionAsync(SessionValidationRequest request, CT ct)
{
if (ServerSideSessionStore != null)
{
@ -208,7 +208,7 @@ public class DefaultSessionCoordinationService : ISessionCoordinationService
{
SubjectId = request.SubjectId,
SessionId = request.SessionId
}, default);
}, ct);
var valid = sessions.Count > 0 &&
sessions.Any(x => x.Expires == null || DateTime.UtcNow < x.Expires.Value);
@ -238,6 +238,7 @@ public class DefaultSessionCoordinationService : ISessionCoordinationService
//result in the cookie never being renewed and expiring in a surprising way. Renewing
//the ticket also updates the session, so we don't need to do both.
if (Options.Authentication.CookieSlidingExpiration &&
#pragma warning disable CA2016 // ITicketStore interface has no CT parameter
await ServerSideTicketStore.RetrieveAsync(session.Key) is
{ Properties: { IsPersistent: true, AllowRefresh: null or true } } ticket)
{
@ -245,10 +246,11 @@ public class DefaultSessionCoordinationService : ISessionCoordinationService
ticket.Properties.IssuedUtc = session.Renewed;
ticket.Properties.ExpiresUtc = session.Expires;
await ServerSideTicketStore.RenewAsync(session.Key, ticket);
#pragma warning restore CA2016
}
else
{
await ServerSideSessionStore.UpdateSessionAsync(session, default);
await ServerSideSessionStore.UpdateSessionAsync(session, ct);
}
}
}

View file

@ -38,7 +38,7 @@ internal class OidcReturnUrlParser : IReturnUrlParser
_authorizationParametersMessageStore = authorizationParametersMessageStore;
}
public async Task<AuthorizationRequest> ParseAsync(string returnUrl)
public async Task<AuthorizationRequest> ParseAsync(string returnUrl, CT ct)
{
using var activity = Tracing.ValidationActivitySource.StartActivity("OidcReturnUrlParser.Parse");
@ -48,11 +48,11 @@ internal class OidcReturnUrlParser : IReturnUrlParser
if (_authorizationParametersMessageStore != null)
{
var messageStoreId = parameters[Constants.AuthorizationParamsStore.MessageStoreIdParameterName];
var entry = await _authorizationParametersMessageStore.ReadAsync(messageStoreId, default);
var entry = await _authorizationParametersMessageStore.ReadAsync(messageStoreId, ct);
parameters = entry?.Data.FromFullDictionary() ?? new NameValueCollection();
}
var user = await _userSession.GetUserAsync(default);
var user = await _userSession.GetUserAsync(ct);
var result = await _validator.ValidateAsync(parameters, user);
if (!result.IsError)
{

View file

@ -23,14 +23,15 @@ public class ReturnUrlParser
/// Parses the return URL.
/// </summary>
/// <param name="returnUrl">The return URL.</param>
/// <param name="ct">The cancellation token.</param>
/// <returns></returns>
public virtual async Task<AuthorizationRequest> ParseAsync(string returnUrl)
public virtual async Task<AuthorizationRequest> ParseAsync(string returnUrl, CT ct)
{
using var activity = Tracing.ValidationActivitySource.StartActivity("ReturnUrlParser.Parse");
foreach (var parser in _parsers)
{
var result = await parser.ParseAsync(returnUrl);
var result = await parser.ParseAsync(returnUrl, ct);
if (result != null)
{
return result;

View file

@ -57,7 +57,7 @@ internal class ServerSideSessionRefreshTokenService : IRefreshTokenService
SessionId = result.RefreshToken.SessionId,
Client = result.Client,
Type = SessionValidationType.RefreshToken
});
}, ct);
if (!valid)
{

View file

@ -17,8 +17,9 @@ public interface IReturnUrlParser
/// Parses a return URL.
/// </summary>
/// <param name="returnUrl">The return URL.</param>
/// <param name="ct">The cancellation token.</param>
/// <returns></returns>
Task<AuthorizationRequest?> ParseAsync(string returnUrl);
Task<AuthorizationRequest?> ParseAsync(string returnUrl, CT ct);
/// <summary>
/// Determines whether the return URL is valid.

View file

@ -27,7 +27,9 @@ public interface ISessionCoordinationService
/// Validates client request, and if valid extends server-side session.
/// Returns false if the session is invalid, true otherwise.
/// </summary>
Task<bool> ValidateSessionAsync(SessionValidationRequest request);
/// <param name="request">The session validation request.</param>
/// <param name="ct">The cancellation token.</param>
Task<bool> ValidateSessionAsync(SessionValidationRequest request, CT ct);
}
/// <summary>

View file

@ -230,7 +230,7 @@ internal class TokenValidator : ITokenValidator
SessionId = sid,
Client = result.Client,
Type = SessionValidationType.AccessToken
});
}, ct);
if (!sessionResult)
{

View file

@ -16,7 +16,7 @@ public class MockReturnUrlParser : ReturnUrlParser
{
}
public override Task<AuthorizationRequest> ParseAsync(string returnUrl) => Task.FromResult(AuthorizationRequestResult);
public override Task<AuthorizationRequest> ParseAsync(string returnUrl, CT ct) => Task.FromResult(AuthorizationRequestResult);
public override bool IsValidReturnUrl(string returnUrl) => IsValidReturnUrlResult;
}

View file

@ -13,5 +13,5 @@ internal class StubSessionCoordinationService : ISessionCoordinationService
public Task ProcessLogoutAsync(UserSession session, CT _) => Task.CompletedTask;
public Task<bool> ValidateSessionAsync(SessionValidationRequest request) => Task.FromResult(true);
public Task<bool> ValidateSessionAsync(SessionValidationRequest request, CT _) => Task.FromResult(true);
}

View file

@ -15,6 +15,8 @@ namespace UnitTests.ResponseHandling;
public class UserInfoResponseGeneratorTests
{
private readonly CT _ct = TestContext.Current.CancellationToken;
private UserInfoResponseGenerator _subject;
private MockProfileService _mockProfileService = new MockProfileService();
private ClaimsPrincipal _user;
@ -50,7 +52,7 @@ public class UserInfoResponseGeneratorTests
[Fact]
public async Task GetRequestedClaimTypesAsync_when_no_scopes_requested_should_return_empty_claim_types()
{
var resources = await _subject.GetRequestedResourcesAsync(null);
var resources = await _subject.GetRequestedResourcesAsync(null, _ct);
var claims = await _subject.GetRequestedClaimTypesAsync(resources);
claims.ShouldBe(new string[] { });
}
@ -61,7 +63,7 @@ public class UserInfoResponseGeneratorTests
_identityResources.Add(new IdentityResource("id1", new[] { "c1", "c2" }));
_identityResources.Add(new IdentityResource("id2", new[] { "c2", "c3" }));
var resources = await _subject.GetRequestedResourcesAsync(new[] { "id1", "id2", "id3" });
var resources = await _subject.GetRequestedResourcesAsync(new[] { "id1", "id2", "id3" }, _ct);
var claims = await _subject.GetRequestedClaimTypesAsync(resources);
claims.ShouldBe(["c1", "c2", "c3"]);
}
@ -72,7 +74,7 @@ public class UserInfoResponseGeneratorTests
_identityResources.Add(new IdentityResource("id1", new[] { "c1", "c2" }) { Enabled = false });
_identityResources.Add(new IdentityResource("id2", new[] { "c2", "c3" }));
var resources = await _subject.GetRequestedResourcesAsync(new[] { "id1", "id2", "id3" });
var resources = await _subject.GetRequestedResourcesAsync(new[] { "id1", "id2", "id3" }, _ct);
var claims = await _subject.GetRequestedClaimTypesAsync(resources);
claims.ShouldBe(["c2", "c3"]);
}
@ -98,7 +100,7 @@ public class UserInfoResponseGeneratorTests
}
};
var claims = await _subject.ProcessAsync(result);
var claims = await _subject.ProcessAsync(result, _ct);
_mockProfileService.GetProfileWasCalled.ShouldBeTrue();
_mockProfileService.ProfileContext.RequestedClaimTypes.ShouldBe(["foo", "bar"]);
@ -141,7 +143,7 @@ public class UserInfoResponseGeneratorTests
}
};
var claims = await _subject.ProcessAsync(result);
var claims = await _subject.ProcessAsync(result, _ct);
claims.ShouldContainKey("email");
claims["email"].ShouldBe("fred@gmail.com");
@ -178,7 +180,7 @@ public class UserInfoResponseGeneratorTests
}
};
var claims = await _subject.ProcessAsync(result);
var claims = await _subject.ProcessAsync(result, _ct);
claims.ShouldContainKey("sub");
claims["sub"].ShouldBe("bob");
@ -209,7 +211,7 @@ public class UserInfoResponseGeneratorTests
}
};
Func<Task> act = () => _subject.ProcessAsync(result);
Func<Task> act = () => _subject.ProcessAsync(result, _ct);
var exception = await act.ShouldThrowAsync<InvalidOperationException>();
exception.Message.ShouldMatch(".*subject.*");

View file

@ -13,6 +13,8 @@ namespace UnitTests.Validation;
public class IsLocalUrlTests
{
private readonly CT _ct = TestContext.Current.CancellationToken;
private const string queryParameters = "?client_id=mvc.code" +
"&redirect_uri=https%3A%2F%2Flocalhost%3A44302%2Fsignin-oidc" +
"&response_type=code" +
@ -105,7 +107,7 @@ public class IsLocalUrlTests
public async Task OidcReturnUrlParser_ParseAsync(string returnUrl, bool expected)
{
var oidcParser = GetOidcReturnUrlParser();
var actual = await oidcParser.ParseAsync(returnUrl);
var actual = await oidcParser.ParseAsync(returnUrl, _ct);
if (expected)
{
actual.ShouldNotBeNull();
@ -138,7 +140,7 @@ public class IsLocalUrlTests
public async Task ReturnUrlParser_ParseAsync(string returnUrl, bool expected)
{
var parser = GetReturnUrlParser();
var actual = await parser.ParseAsync(returnUrl);
var actual = await parser.ParseAsync(returnUrl, _ct);
if (expected)
{
actual.ShouldNotBeNull();