Make CT required in IIdentityProviderStore, IServerSideTicketStore, ITokenCleanupService, IOperationalStoreNotification, DiagnosticDataService, DiagnosticSummary, KeyManager internals, and related implementations and tests

This commit is contained in:
Damian Hickey 2026-02-21 10:55:03 +01:00
parent 47d5e8b47f
commit 04f6388b40
35 changed files with 122 additions and 108 deletions

View file

@ -5,5 +5,5 @@ namespace Duende.ConformanceReport;
internal interface IConformanceReportClientStore
{
Task<IEnumerable<ConformanceReportClient>> GetAllClientsAsync(CT ct = default);
Task<IEnumerable<ConformanceReportClient>> GetAllClientsAsync(CT ct);
}

View file

@ -99,7 +99,7 @@ public class ConformanceReportEndpointTests
private sealed class InMemoryClientStore(IEnumerable<ConformanceReportClient> clients) : IConformanceReportClientStore
{
public Task<IEnumerable<ConformanceReportClient>> GetAllClientsAsync(CancellationToken ct = default)
public Task<IEnumerable<ConformanceReportClient>> GetAllClientsAsync(CancellationToken ct)
=> Task.FromResult(clients);
}

View file

@ -104,7 +104,7 @@ public class ConformanceAssessmentServiceTests
private sealed class InMemoryClientStore(IEnumerable<ConformanceReportClient> clients) : IConformanceReportClientStore
{
public Task<IEnumerable<ConformanceReportClient>> GetAllClientsAsync(CancellationToken ct = default) => Task.FromResult(clients);
public Task<IEnumerable<ConformanceReportClient>> GetAllClientsAsync(CancellationToken ct) => Task.FromResult(clients);
}
private sealed class TestHttpContextAccessor : IHttpContextAccessor

View file

@ -12,7 +12,7 @@ public class TestOperationalStoreNotification : IOperationalStoreNotification
{
public TestOperationalStoreNotification() => Console.WriteLine("ctor");
public Task PersistedGrantsRemovedAsync(IEnumerable<PersistedGrant> persistedGrants, CT ct = default)
public Task PersistedGrantsRemovedAsync(IEnumerable<PersistedGrant> persistedGrants, CT ct)
{
ArgumentNullException.ThrowIfNull(persistedGrants);
foreach (var grant in persistedGrants)
@ -22,7 +22,7 @@ public class TestOperationalStoreNotification : IOperationalStoreNotification
return Task.CompletedTask;
}
public Task DeviceCodesRemovedAsync(IEnumerable<DeviceFlowCodes> deviceCodes, CT ct = default)
public Task DeviceCodesRemovedAsync(IEnumerable<DeviceFlowCodes> deviceCodes, CT ct)
{
ArgumentNullException.ThrowIfNull(deviceCodes);
foreach (var deviceCode in deviceCodes)

View file

@ -50,7 +50,7 @@ public class Index : PageModel
public async Task<IActionResult> OnGet(string? returnUrl)
{
await BuildModelAsync(returnUrl);
await BuildModelAsync(returnUrl, HttpContext.RequestAborted);
if (View.IsExternalLoginOnly)
{
@ -147,11 +147,11 @@ public class Index : PageModel
}
// something went wrong, show form with error
await BuildModelAsync(Input.ReturnUrl);
await BuildModelAsync(Input.ReturnUrl, HttpContext.RequestAborted);
return Page();
}
private async Task BuildModelAsync(string? returnUrl)
private async Task BuildModelAsync(string? returnUrl, CT ct)
{
Input = new InputModel
{
@ -193,7 +193,7 @@ public class Index : PageModel
displayName: x.DisplayName ?? x.Name
)).ToList();
var dynamicSchemes = (await _identityProviderStore.GetAllSchemeNamesAsync())
var dynamicSchemes = (await _identityProviderStore.GetAllSchemeNamesAsync(ct))
.Where(x => x.Enabled)
.Select(x => new ViewModel.ExternalProvider
(

View file

@ -46,7 +46,7 @@ public class Index : PageModel
public async Task<IActionResult> OnGet(string? returnUrl)
{
await BuildModelAsync(returnUrl);
await BuildModelAsync(returnUrl, HttpContext.RequestAborted);
if (View.IsExternalLoginOnly)
{
@ -157,11 +157,11 @@ public class Index : PageModel
}
// something went wrong, show form with error
await BuildModelAsync(Input.ReturnUrl);
await BuildModelAsync(Input.ReturnUrl, HttpContext.RequestAborted);
return Page();
}
private async Task BuildModelAsync(string? returnUrl)
private async Task BuildModelAsync(string? returnUrl, CT ct)
{
Input = new InputModel
{
@ -203,7 +203,7 @@ public class Index : PageModel
displayName: x.DisplayName ?? x.Name
)).ToList();
var dynamicSchemes = (await _identityProviderStore.GetAllSchemeNamesAsync())
var dynamicSchemes = (await _identityProviderStore.GetAllSchemeNamesAsync(ct))
.Where(x => x.Enabled)
.Select(x => new ViewModel.ExternalProvider
(

View file

@ -18,7 +18,7 @@ public static class DbContextExtensions
/// <summary>
/// Saves changes and handles concurrency exceptions.
/// </summary>
public static async Task<ICollection<T>> SaveChangesWithConcurrencyCheckAsync<T>(this IPersistedGrantDbContext context, ILogger logger, CT ct = default)
public static async Task<ICollection<T>> SaveChangesWithConcurrencyCheckAsync<T>(this IPersistedGrantDbContext context, ILogger logger, CT ct)
where T : class
{
var list = new List<T>();

View file

@ -48,7 +48,7 @@ public class IdentityProviderStore : IIdentityProviderStore
}
/// <inheritdoc/>
public async Task<IEnumerable<IdentityProviderName>> GetAllSchemeNamesAsync(CT ct = default)
public async Task<IEnumerable<IdentityProviderName>> GetAllSchemeNamesAsync(CT ct)
{
using var activity = Tracing.StoreActivitySource.StartActivity("IdentityProviderStore.GetAllSchemeNames");
@ -63,7 +63,7 @@ public class IdentityProviderStore : IIdentityProviderStore
}
/// <inheritdoc/>
public async Task<IdentityProvider> GetBySchemeAsync(string scheme, CT ct = default)
public async Task<IdentityProvider> GetBySchemeAsync(string scheme, CT ct)
{
using var activity = Tracing.StoreActivitySource.StartActivity("IdentityProviderStore.GetByScheme");
activity?.SetTag(Tracing.Properties.Scheme, scheme);

View file

@ -19,13 +19,13 @@ public interface IOperationalStoreNotification
/// <param name="persistedGrants"></param>
/// <param name="ct"></param>
/// <returns></returns>
Task PersistedGrantsRemovedAsync(IEnumerable<PersistedGrant> persistedGrants, CT ct = default);
Task PersistedGrantsRemovedAsync(IEnumerable<PersistedGrant> persistedGrants, CT ct);
/// <summary>
/// Notification for device codes being removed.
/// </summary>
/// <param name="deviceCodes"></param>
/// <param name="ct"></param>
/// <param name="deviceCodes">The device codes being removed.</param>
/// <param name="ct">The cancellation token.</param>
/// <returns></returns>
Task DeviceCodesRemovedAsync(IEnumerable<DeviceFlowCodes> deviceCodes, CT ct = default);
Task DeviceCodesRemovedAsync(IEnumerable<DeviceFlowCodes> deviceCodes, CT ct);
}

View file

@ -19,5 +19,5 @@ public interface ITokenCleanupService
/// <param name="ct">A token that propagates notification
/// that the cleanup operation should be canceled.</param>
/// <returns></returns>
Task CleanupGrantsAsync(CT ct = default);
Task CleanupGrantsAsync(CT ct);
}

View file

@ -43,7 +43,7 @@ public class TokenCleanupService : ITokenCleanupService
}
/// <inheritdoc/>
public async Task CleanupGrantsAsync(CT ct = default)
public async Task CleanupGrantsAsync(CT ct)
{
try
{
@ -63,7 +63,7 @@ public class TokenCleanupService : ITokenCleanupService
/// Removes the stale persisted grants.
/// </summary>
/// <returns></returns>
protected virtual async Task RemoveGrantsAsync(CT ct = default)
protected virtual async Task RemoveGrantsAsync(CT ct)
{
await RemoveExpiredPersistedGrantsAsync(ct);
if (_options.RemoveConsumedTokens)
@ -76,7 +76,7 @@ public class TokenCleanupService : ITokenCleanupService
/// Removes the expired persisted grants.
/// </summary>
/// <returns></returns>
protected virtual async Task RemoveExpiredPersistedGrantsAsync(CT ct = default)
protected virtual async Task RemoveExpiredPersistedGrantsAsync(CT ct)
{
var found = int.MaxValue;
@ -145,7 +145,7 @@ public class TokenCleanupService : ITokenCleanupService
/// Removes the consumed persisted grants.
/// </summary>
/// <returns></returns>
protected virtual async Task RemoveConsumedPersistedGrantsAsync(CT ct = default)
protected virtual async Task RemoveConsumedPersistedGrantsAsync(CT ct)
{
var found = int.MaxValue;
@ -208,7 +208,7 @@ public class TokenCleanupService : ITokenCleanupService
/// Removes the stale device codes.
/// </summary>
/// <returns></returns>
protected virtual async Task RemoveDeviceCodesAsync(CT ct = default)
protected virtual async Task RemoveDeviceCodesAsync(CT ct)
{
var found = int.MaxValue;
@ -264,7 +264,7 @@ public class TokenCleanupService : ITokenCleanupService
/// <summary>
/// Removes stale pushed authorization requests.
/// </summary>
protected virtual async Task RemovePushedAuthorizationRequestsAsync(CT ct = default)
protected virtual async Task RemovePushedAuthorizationRequestsAsync(CT ct)
{
var found = int.MaxValue;

View file

@ -121,7 +121,7 @@ public class TokenCleanupHost : IHostedService
}
}
private async Task RemoveExpiredGrantsAsync(CT ct = default)
private async Task RemoveExpiredGrantsAsync(CT ct)
{
try
{

View file

@ -14,7 +14,7 @@ internal sealed class IdentityServerClientStore(IClientStore clientStore) : ICon
#pragma warning restore CA1812
{
public async Task<IEnumerable<ConformanceReportClient>> GetAllClientsAsync(
CancellationToken ct = default)
CancellationToken ct)
{
var clients = new List<ConformanceReportClient>();
await foreach (var client in clientStore.GetAllClientsAsync(ct))

View file

@ -13,7 +13,7 @@ internal class InMemoryIdentityProviderStore : IIdentityProviderStore
public InMemoryIdentityProviderStore(IEnumerable<IdentityProvider> providers) => _providers = providers;
public Task<IEnumerable<IdentityProviderName>> GetAllSchemeNamesAsync(CT ct = default)
public Task<IEnumerable<IdentityProviderName>> GetAllSchemeNamesAsync(CT ct)
{
using var activity = Tracing.StoreActivitySource.StartActivity("InMemoryOidcProviderStore.GetAllSchemeNames");
@ -27,7 +27,7 @@ internal class InMemoryIdentityProviderStore : IIdentityProviderStore
return Task.FromResult(items);
}
public Task<IdentityProvider> GetBySchemeAsync(string scheme, CT ct = default)
public Task<IdentityProvider> GetBySchemeAsync(string scheme, CT ct)
{
using var activity = Tracing.StoreActivitySource.StartActivity("InMemoryOidcProviderStore.GetByScheme");

View file

@ -38,10 +38,10 @@ public class NonCachingIdentityProviderStore<T> : IIdentityProviderStore
}
/// <inheritdoc/>
public Task<IEnumerable<IdentityProviderName>> GetAllSchemeNamesAsync(CT ct = default) => _inner.GetAllSchemeNamesAsync(ct);
public Task<IEnumerable<IdentityProviderName>> GetAllSchemeNamesAsync(CT ct) => _inner.GetAllSchemeNamesAsync(ct);
/// <inheritdoc/>
public async Task<IdentityProvider> GetBySchemeAsync(string scheme, CT ct = default)
public async Task<IdentityProvider> GetBySchemeAsync(string scheme, CT ct)
{
if (_httpContextAccessor.HttpContext == null)
{

View file

@ -9,7 +9,7 @@ namespace Duende.IdentityServer.Hosting.DynamicProviders;
internal class NopIdentityProviderStore : IIdentityProviderStore
{
public Task<IEnumerable<IdentityProviderName>> GetAllSchemeNamesAsync(CT ct = default) => Task.FromResult(Enumerable.Empty<IdentityProviderName>());
public Task<IEnumerable<IdentityProviderName>> GetAllSchemeNamesAsync(CT ct) => Task.FromResult(Enumerable.Empty<IdentityProviderName>());
public Task<IdentityProvider> GetBySchemeAsync(string scheme, CT ct = default) => Task.FromResult<IdentityProvider>(null);
public Task<IdentityProvider> GetBySchemeAsync(string scheme, CT ct) => Task.FromResult<IdentityProvider>(null);
}

View file

@ -38,10 +38,10 @@ public class ValidatingIdentityProviderStore<T> : IIdentityProviderStore
}
/// <inheritdoc/>
public Task<IEnumerable<IdentityProviderName>> GetAllSchemeNamesAsync(CT ct = default) => _inner.GetAllSchemeNamesAsync(ct);
public Task<IEnumerable<IdentityProviderName>> GetAllSchemeNamesAsync(CT ct) => _inner.GetAllSchemeNamesAsync(ct);
/// <inheritdoc/>
public async Task<IdentityProvider> GetBySchemeAsync(string scheme, CT ct = default)
public async Task<IdentityProvider> GetBySchemeAsync(string scheme, CT ct)
{
var idp = await _inner.GetBySchemeAsync(scheme, ct);

View file

@ -68,7 +68,7 @@ public class ServerSideSessionCleanupHost(
logger.LogDebug("Stopping server-side session removal");
}
private async Task RunAsync(CT ct = default)
private async Task RunAsync(CT ct)
{
// this is here for testing
if (!options.ServerSideSessions.RemoveExpiredSessions)

View file

@ -19,7 +19,7 @@ internal class DiagnosticHostedService(DiagnosticSummary diagnosticSummary, IOpt
{
try
{
await diagnosticSummary.PrintSummary();
await diagnosticSummary.PrintSummary(stoppingToken);
}
catch (Exception ex)
{
@ -39,7 +39,7 @@ internal class DiagnosticHostedService(DiagnosticSummary diagnosticSummary, IOpt
public override async Task StopAsync(CT ct)
{
await diagnosticSummary.PrintSummary();
await diagnosticSummary.PrintSummary(ct);
await base.StopAsync(ct);
}

View file

@ -12,9 +12,9 @@ internal class DiagnosticSummary(DiagnosticDataService diagnosticDataService, Id
{
private readonly ILogger _logger = loggerFactory.CreateLogger("Duende.IdentityServer.Diagnostics.Summary");
public async Task PrintSummary()
public async Task PrintSummary(CT ct)
{
var jsonMemory = await diagnosticDataService.GetJsonBytesAsync();
var jsonMemory = await diagnosticDataService.GetJsonBytesAsync(ct);
var span = jsonMemory.Span;
using var diagnosticActivity = Tracing.DiagnosticsActivitySource.StartActivity("DiagnosticSummary");

View file

@ -95,7 +95,7 @@ public class KeyManager : IKeyManager
internal async Task<(IEnumerable<KeyContainer> allKeys, IEnumerable<KeyContainer> signingKeys)> GetAllKeysInternalAsync(CT ct = default)
internal async Task<(IEnumerable<KeyContainer> allKeys, IEnumerable<KeyContainer> signingKeys)> GetAllKeysInternalAsync(CT ct)
{
var cached = true;
var keys = await GetAllKeysFromCacheAsync(ct);
@ -265,7 +265,7 @@ public class KeyManager : IKeyManager
return false;
}
internal async Task<KeyContainer> CreateAndStoreNewKeyAsync(SigningAlgorithmOptions alg, CT ct = default)
internal async Task<KeyContainer> CreateAndStoreNewKeyAsync(SigningAlgorithmOptions alg, CT ct)
{
_logger.LogTrace("Creating new key.");
@ -307,7 +307,7 @@ public class KeyManager : IKeyManager
return container;
}
internal async Task<IEnumerable<KeyContainer>> GetAllKeysFromCacheAsync(CT ct = default)
internal async Task<IEnumerable<KeyContainer>> GetAllKeysFromCacheAsync(CT ct)
{
var cachedKeys = await _cache.GetKeysAsync(ct);
if (cachedKeys != null)
@ -340,7 +340,7 @@ public class KeyManager : IKeyManager
return result;
}
internal async Task<IEnumerable<SerializedKey>> FilterAndDeleteRetiredKeysAsync(IEnumerable<SerializedKey> keys, CT ct = default)
internal async Task<IEnumerable<SerializedKey>> FilterAndDeleteRetiredKeysAsync(IEnumerable<SerializedKey> keys, CT ct)
{
var retired = keys
.Where(x =>
@ -373,7 +373,7 @@ public class KeyManager : IKeyManager
return result;
}
internal async Task DeleteKeysAsync(IEnumerable<string> keys, CT ct = default)
internal async Task DeleteKeysAsync(IEnumerable<string> keys, CT ct)
{
if (keys == null || !keys.Any())
{
@ -399,7 +399,7 @@ public class KeyManager : IKeyManager
return result;
}
internal async Task CacheKeysAsync(IEnumerable<KeyContainer> keys, CT ct = default)
internal async Task CacheKeysAsync(IEnumerable<KeyContainer> keys, CT ct)
{
if (keys?.Any() == true)
{
@ -505,7 +505,7 @@ public class KeyManager : IKeyManager
internal async Task<(IEnumerable<KeyContainer> allKeys, IEnumerable<KeyContainer> activeKeys)> CreateNewKeysAndAddToCacheAsync(CT ct = default)
internal async Task<(IEnumerable<KeyContainer> allKeys, IEnumerable<KeyContainer> activeKeys)> CreateNewKeysAndAddToCacheAsync(CT ct)
{
var keys = new List<KeyContainer>();
keys.AddRange(await _cache.GetKeysAsync(ct) ?? Enumerable.Empty<KeyContainer>());

View file

@ -22,7 +22,7 @@ public class DiagnosticDataService
_entries = entries;
}
public async Task<ReadOnlyMemory<byte>> GetJsonBytesAsync(CT ct = default)
public async Task<ReadOnlyMemory<byte>> GetJsonBytesAsync(CT ct)
{
var bufferWriter = new ArrayBufferWriter<byte>();
await using var writer = new Utf8JsonWriter(bufferWriter, new JsonWriterOptions { Indented = false });
@ -42,7 +42,7 @@ public class DiagnosticDataService
return bufferWriter.WrittenMemory;
}
public async Task<string> GetJsonStringAsync(CT ct = default)
public async Task<string> GetJsonStringAsync(CT ct)
{
var bytes = await GetJsonBytesAsync(ct);
return Encoding.UTF8.GetString(bytes.Span);

View file

@ -185,7 +185,7 @@ public class ServerSideTicketStore : IServerSideTicketStore
}
/// <inheritdoc/>
public async Task<IReadOnlyCollection<UserSession>> GetSessionsAsync(SessionFilter filter, CT ct = default)
public async Task<IReadOnlyCollection<UserSession>> GetSessionsAsync(SessionFilter filter, CT ct)
{
using var activity = Tracing.StoreActivitySource.StartActivity("ServerSideTicketStore.GetSessions");
@ -196,7 +196,7 @@ public class ServerSideTicketStore : IServerSideTicketStore
}
/// <inheritdoc />
public async Task<QueryResult<UserSession>> QuerySessionsAsync(SessionQuery filter = null, CT ct = default)
public async Task<QueryResult<UserSession>> QuerySessionsAsync(SessionQuery filter, CT ct)
{
using var activity = Tracing.StoreActivitySource.StartActivity("ServerSideTicketStore.QuerySessions");
@ -219,7 +219,7 @@ public class ServerSideTicketStore : IServerSideTicketStore
}
/// <inheritdoc/>
public async Task<IReadOnlyCollection<UserSession>> GetAndRemoveExpiredSessionsAsync(int count, CT ct = default)
public async Task<IReadOnlyCollection<UserSession>> GetAndRemoveExpiredSessionsAsync(int count, CT ct)
{
using var activity = Tracing.StoreActivitySource.StartActivity("ServerSideTicketStore.GetAndRemoveExpiredSessions");

View file

@ -17,15 +17,21 @@ public interface IServerSideTicketStore : ITicketStore
/// <summary>
/// Gets sessions for a specific subject id and/or session id
/// </summary>
Task<IReadOnlyCollection<UserSession>> GetSessionsAsync(SessionFilter filter, CT ct = default);
/// <param name="filter">The session filter.</param>
/// <param name="ct">The cancellation token.</param>
Task<IReadOnlyCollection<UserSession>> GetSessionsAsync(SessionFilter filter, CT ct);
/// <summary>
/// Queries user sessions based on filter
/// </summary>
Task<QueryResult<UserSession>> QuerySessionsAsync(SessionQuery filter, CT ct = default);
/// <param name="filter">The session query filter.</param>
/// <param name="ct">The cancellation token.</param>
Task<QueryResult<UserSession>> QuerySessionsAsync(SessionQuery filter, CT ct);
/// <summary>
/// Removes and returns expired sessions
/// </summary>
Task<IReadOnlyCollection<UserSession>> GetAndRemoveExpiredSessionsAsync(int count, CT ct = default);
/// <param name="count">The maximum number of sessions to return.</param>
/// <param name="ct">The cancellation token.</param>
Task<IReadOnlyCollection<UserSession>> GetAndRemoveExpiredSessionsAsync(int count, CT ct);
}

View file

@ -16,14 +16,14 @@ public interface IIdentityProviderStore
/// <summary>
/// Gets all identity providers name.
/// </summary>
/// <param name="ct"></param>
Task<IEnumerable<IdentityProviderName>> GetAllSchemeNamesAsync(CT ct = default);
/// <param name="ct">The cancellation token.</param>
Task<IEnumerable<IdentityProviderName>> GetAllSchemeNamesAsync(CT ct);
/// <summary>
/// Gets the identity provider by scheme name.
/// </summary>
/// <param name="scheme"></param>
/// <param name="ct"></param>
/// <param name="scheme">The scheme name.</param>
/// <param name="ct">The cancellation token.</param>
/// <returns></returns>
Task<IdentityProvider?> GetBySchemeAsync(string scheme, CT ct = default);
Task<IdentityProvider?> GetBySchemeAsync(string scheme, CT ct);
}

View file

@ -11,6 +11,7 @@ namespace Duende.IdentityServer.IntegrationTests.Configuration;
public class DynamicClientRegistrationTests : ConfigurationIntegrationTestBase
{
private readonly CT _ct = TestContext.Current.CancellationToken;
[Fact]
public async Task valid_request_creates_new_client()
{
@ -29,7 +30,7 @@ public class DynamicClientRegistrationTests : ConfigurationIntegrationTestBase
var response = await httpResponse.Content.ReadFromJsonAsync<DynamicClientRegistrationResponse>();
response.ShouldNotBeNull();
var newClient = await IdentityServerHost.GetClientAsync(response!.ClientId); // Not null already asserted
var newClient = await IdentityServerHost.GetClientAsync(response!.ClientId, _ct); // Not null already asserted
newClient.ShouldNotBeNull();
newClient.ClientId.ShouldBe(response.ClientId);
newClient.AllowedGrantTypes.ShouldBe(request.GrantTypes);

View file

@ -15,14 +15,14 @@ public class MockOperationalStoreNotification : IOperationalStoreNotification
public Action<IEnumerable<PersistedGrant>> OnPersistedGrantsRemoved = _ => { };
public Action<IEnumerable<DeviceFlowCodes>> OnDeviceFlowCodesRemoved = _ => { };
public Task PersistedGrantsRemovedAsync(IEnumerable<PersistedGrant> persistedGrants, CT ct = default)
public Task PersistedGrantsRemovedAsync(IEnumerable<PersistedGrant> persistedGrants, CT ct)
{
OnPersistedGrantsRemoved(persistedGrants);
PersistedGrantNotifications.Add(persistedGrants);
return Task.CompletedTask;
}
public Task DeviceCodesRemovedAsync(IEnumerable<DeviceFlowCodes> deviceCodes, CT ct = default)
public Task DeviceCodesRemovedAsync(IEnumerable<DeviceFlowCodes> deviceCodes, CT ct)
{
OnDeviceFlowCodesRemoved(deviceCodes);
DeviceFlowCodeNotifications.Append(deviceCodes);

View file

@ -15,6 +15,8 @@ namespace Duende.IdentityServer.IntegrationTests.EntityFramework.Storage.Stores;
public class IdentityProviderStoreTests : IntegrationTest<IdentityProviderStoreTests, ConfigurationDbContext, ConfigurationStoreOptions>
{
private readonly CT _ct = TestContext.Current.CancellationToken;
public IdentityProviderStoreTests(DatabaseProviderFixture<ConfigurationDbContext> fixture) : base(fixture)
{
foreach (var options in TestDatabaseProviders)
@ -43,7 +45,7 @@ public class IdentityProviderStoreTests : IntegrationTest<IdentityProviderStoreT
await using (var context = new ConfigurationDbContext(options))
{
var store = new IdentityProviderStore(context, new NullLogger<IdentityProviderStore>(), new NoneCancellationTokenProvider());
var item = await store.GetBySchemeAsync("scheme1");
var item = await store.GetBySchemeAsync("scheme1", _ct);
item.ShouldNotBeNull();
}
@ -67,7 +69,7 @@ public class IdentityProviderStoreTests : IntegrationTest<IdentityProviderStoreT
await using (var context = new ConfigurationDbContext(options))
{
var store = new IdentityProviderStore(context, new NullLogger<IdentityProviderStore>(), new NoneCancellationTokenProvider());
var item = await store.GetBySchemeAsync("scheme2");
var item = await store.GetBySchemeAsync("scheme2", _ct);
item.ShouldBeNull();
}
@ -90,7 +92,7 @@ public class IdentityProviderStoreTests : IntegrationTest<IdentityProviderStoreT
await using (var context = new ConfigurationDbContext(options))
{
var store = new IdentityProviderStore(context, new NullLogger<IdentityProviderStore>(), new NoneCancellationTokenProvider());
var item = await store.GetBySchemeAsync("scheme3");
var item = await store.GetBySchemeAsync("scheme3", _ct);
item.ShouldBeNull();
}

View file

@ -18,6 +18,7 @@ namespace Duende.IdentityServer.IntegrationTests.EntityFramework.Storage.TokenCl
public class TokenCleanupTests : IntegrationTest<TokenCleanupTests, PersistedGrantDbContext, OperationalStoreOptions>
{
private readonly CT _ct = TestContext.Current.CancellationToken;
public TokenCleanupTests(DatabaseProviderFixture<PersistedGrantDbContext> fixture) : base(fixture)
{
foreach (var options in TestDatabaseProviders)
@ -57,7 +58,7 @@ public class TokenCleanupTests : IntegrationTest<TokenCleanupTests, PersistedGra
await context.SaveChangesAsync();
}
await CreateSut(options).CleanupGrantsAsync();
await CreateSut(options).CleanupGrantsAsync(_ct);
await using (var context = new PersistedGrantDbContext(options))
{
@ -84,7 +85,7 @@ public class TokenCleanupTests : IntegrationTest<TokenCleanupTests, PersistedGra
await context.SaveChangesAsync();
}
await CreateSut(options).CleanupGrantsAsync();
await CreateSut(options).CleanupGrantsAsync(_ct);
await using (var context = new PersistedGrantDbContext(options))
{
@ -128,7 +129,7 @@ public class TokenCleanupTests : IntegrationTest<TokenCleanupTests, PersistedGra
await context.SaveChangesAsync();
}
await CreateSut(options).CleanupGrantsAsync();
await CreateSut(options).CleanupGrantsAsync(_ct);
await using (var context = new PersistedGrantDbContext(options))
{
@ -159,7 +160,7 @@ public class TokenCleanupTests : IntegrationTest<TokenCleanupTests, PersistedGra
await context.SaveChangesAsync();
}
await CreateSut(options).CleanupGrantsAsync();
await CreateSut(options).CleanupGrantsAsync(_ct);
await using (var context = new PersistedGrantDbContext(options))
{
@ -187,7 +188,7 @@ public class TokenCleanupTests : IntegrationTest<TokenCleanupTests, PersistedGra
await context.SaveChangesAsync();
}
await CreateSut(options).CleanupGrantsAsync();
await CreateSut(options).CleanupGrantsAsync(_ct);
await using (var context = new PersistedGrantDbContext(options))
{
@ -216,7 +217,7 @@ public class TokenCleanupTests : IntegrationTest<TokenCleanupTests, PersistedGra
await context.SaveChangesAsync();
}
await CreateSut(options, removeConsumedTokens: true).CleanupGrantsAsync();
await CreateSut(options, removeConsumedTokens: true).CleanupGrantsAsync(_ct);
await using (var context = new PersistedGrantDbContext(options))
{
@ -244,7 +245,7 @@ public class TokenCleanupTests : IntegrationTest<TokenCleanupTests, PersistedGra
await context.SaveChangesAsync();
}
await CreateSut(options, removeConsumedTokens: false).CleanupGrantsAsync();
await CreateSut(options, removeConsumedTokens: false).CleanupGrantsAsync(_ct);
await using (var context = new PersistedGrantDbContext(options))
{
@ -287,7 +288,7 @@ public class TokenCleanupTests : IntegrationTest<TokenCleanupTests, PersistedGra
await CreateSut(options, svcs =>
{
svcs.AddSingleton<IOperationalStoreNotification>(mockNotifications);
}).CleanupGrantsAsync();
}).CleanupGrantsAsync(_ct);
// The right number of batches executed
mockNotifications.PersistedGrantNotifications.Count.ShouldBe(expectedPageCount);
@ -356,7 +357,7 @@ public class TokenCleanupTests : IntegrationTest<TokenCleanupTests, PersistedGra
await CreateSut(options, svcs =>
{
svcs.AddSingleton<IOperationalStoreNotification>(mockNotifications);
}).CleanupGrantsAsync();
}).CleanupGrantsAsync(_ct);
// Each batch created an extra grant, so we do an extra batch to clean up
// the extras
@ -417,7 +418,7 @@ public class TokenCleanupTests : IntegrationTest<TokenCleanupTests, PersistedGra
await context.SaveChangesAsync();
}
await CreateSut(options, removeConsumedTokens: true, delay).CleanupGrantsAsync();
await CreateSut(options, removeConsumedTokens: true, delay).CleanupGrantsAsync(_ct);
await using (var context = new PersistedGrantDbContext(options))
{

View file

@ -238,7 +238,7 @@ public class ServerSideSessionTests
await _pipeline.LoginAsync("alice");
_pipeline.RemoveLoginCookie();
var tickets = await _ticketService.GetSessionsAsync(new SessionFilter { SubjectId = "alice" });
var tickets = await _ticketService.GetSessionsAsync(new SessionFilter { SubjectId = "alice" }, _ct);
var sessions = await _sessionStore.GetSessionsAsync(new SessionFilter { SubjectId = "alice" }, _ct);
tickets.Select(x => x.SessionId).ShouldBe(sessions.Select(x => x.SessionId));

View file

@ -60,7 +60,7 @@ public class IdentityServerHost : GenericHost
}
public async Task<Client> GetClientAsync(string clientId, CT ct = default)
public async Task<Client> GetClientAsync(string clientId, CT ct)
{
var store = Resolve<ClientStore>();
return await store.FindClientByIdAsync(clientId, ct);

View file

@ -13,6 +13,8 @@ namespace IdentityServer.UnitTests.Licensing.V2;
public class DiagnosticSummaryTests
{
private readonly CT _ct = TestContext.Current.CancellationToken;
[Fact]
public async Task PrintSummary_ShouldCallWriteAsyncOnEveryDiagnosticEntry()
{
@ -29,7 +31,7 @@ public class DiagnosticSummaryTests
var diagnosticService = new DiagnosticDataService(DateTime.UtcNow, entries);
var summary = new DiagnosticSummary(diagnosticService, new IdentityServerOptions(), new StubLoggerFactory(logger));
await summary.PrintSummary();
await summary.PrintSummary(_ct);
firstDiagnosticEntry.WasCalled.ShouldBeTrue();
secondDiagnosticEntry.WasCalled.ShouldBeTrue();
@ -47,7 +49,7 @@ public class DiagnosticSummaryTests
var diagnosticService = new DiagnosticDataService(DateTime.UtcNow, [diagnosticEntry]);
var summary = new DiagnosticSummary(diagnosticService, options, new StubLoggerFactory(logger));
await summary.PrintSummary();
await summary.PrintSummary(_ct);
var logSnapshot = logger.Collector.GetSnapshot().Select(x => x.Message);
logSnapshot.ShouldBe([
@ -68,7 +70,7 @@ public class DiagnosticSummaryTests
var summary = new DiagnosticSummary(diagnosticService, options, new StubLoggerFactory(logger));
await summary.PrintSummary();
await summary.PrintSummary(_ct);
var logSnapshot = logger.Collector.GetSnapshot().Select(x => x.Message);
logSnapshot.ShouldBe(["Diagnostic data (1 of 3): {\"test\":", "Diagnostic data (2 of 3): \"\\u20AC\\", "Diagnostic data (3 of 3): u20AC\"}"]);
@ -85,7 +87,7 @@ public class DiagnosticSummaryTests
var summary = new DiagnosticSummary(diagnosticService, options, new StubLoggerFactory(logger));
await summary.PrintSummary();
await summary.PrintSummary(_ct);
foreach (var entry in logger.Collector.GetSnapshot())
{
entry.Message.Length.ShouldBeLessThanOrEqualTo(1024 * 8);
@ -101,7 +103,7 @@ public class DiagnosticSummaryTests
var diagnosticService = new DiagnosticDataService(DateTime.UtcNow, [diagnosticEntry]);
var summary = new DiagnosticSummary(diagnosticService, options, new StubLoggerFactory(logger));
await summary.PrintSummary();
await summary.PrintSummary(_ct);
var logSnapshot = logger.Collector.GetSnapshot();
logSnapshot.Count.ShouldBeGreaterThan(0);

View file

@ -314,7 +314,7 @@ public class DefaultPersistedGrantServiceTests
RequestedScopes = new string[] { "quux3" }
}, _ct);
await _subject.RemoveAllGrantsAsync("123", "client1");
await _subject.RemoveAllGrantsAsync("123", "client1", ct: _ct);
(await _referenceTokens.GetReferenceTokenAsync(handle1, _ct)).ShouldBeNull();
(await _referenceTokens.GetReferenceTokenAsync(handle2, _ct)).ShouldNotBeNull();
@ -358,7 +358,7 @@ public class DefaultPersistedGrantServiceTests
Lifetime = 10,
}, _ct);
await _subject.RemoveAllGrantsAsync("123");
await _subject.RemoveAllGrantsAsync("123", ct: _ct);
(await _refreshTokens.GetRefreshTokenAsync(handle1, _ct)).ShouldBeNull();
(await _refreshTokens.GetRefreshTokenAsync(handle2, _ct)).ShouldBeNull();
@ -396,7 +396,7 @@ public class DefaultPersistedGrantServiceTests
Lifetime = 10,
}, _ct);
await _subject.RemoveAllGrantsAsync("123", "client1");
await _subject.RemoveAllGrantsAsync("123", "client1", ct: _ct);
(await _refreshTokens.GetRefreshTokenAsync(handle1, _ct)).ShouldBeNull();
(await _refreshTokens.GetRefreshTokenAsync(handle2, _ct)).ShouldNotBeNull();
@ -442,7 +442,7 @@ public class DefaultPersistedGrantServiceTests
CreationTime = DateTime.UtcNow,
Lifetime = 10,
}, _ct);
await _subject.RemoveAllGrantsAsync("123", "client1", "session1");
await _subject.RemoveAllGrantsAsync("123", "client1", "session1", _ct);
(await _refreshTokens.GetRefreshTokenAsync(handle1, _ct)).ShouldBeNull();
(await _refreshTokens.GetRefreshTokenAsync(handle2, _ct)).ShouldNotBeNull();
@ -490,7 +490,7 @@ public class DefaultPersistedGrantServiceTests
CreationTime = DateTime.UtcNow,
Lifetime = 10,
}, _ct);
await _subject.RemoveAllGrantsAsync("123", sessionId: "session1");
await _subject.RemoveAllGrantsAsync("123", sessionId: "session1", ct: _ct);
(await _refreshTokens.GetRefreshTokenAsync(handle1, _ct)).ShouldBeNull();
(await _refreshTokens.GetRefreshTokenAsync(handle2, _ct)).ShouldBeNull();

View file

@ -354,7 +354,7 @@ public class KeyManagerTests
{
var id = CreateCacheAndStoreKey();
var keys = await _subject.GetAllKeysFromCacheAsync();
var keys = await _subject.GetAllKeysFromCacheAsync(_ct);
keys.Count().ShouldBe(1);
keys.Single().Id.ShouldBe(id);
@ -543,13 +543,13 @@ public class KeyManagerTests
public async Task CacheKeysAsync_should_not_store_empty_keys()
{
{
await _subject.CacheKeysAsync(null);
await _subject.CacheKeysAsync(null, _ct);
_mockKeyStoreCache.StoreKeysAsyncWasCalled.ShouldBeFalse();
}
{
await _subject.CacheKeysAsync(new RsaKeyContainer[0]);
await _subject.CacheKeysAsync(new RsaKeyContainer[0], _ct);
_mockKeyStoreCache.StoreKeysAsyncWasCalled.ShouldBeFalse();
}
@ -561,7 +561,7 @@ public class KeyManagerTests
var key1 = CreateKey(_options.KeyManagement.PropagationTime.Add(TimeSpan.FromMinutes(5)));
var key2 = CreateKey(_options.KeyManagement.PropagationTime.Add(TimeSpan.FromMinutes(10)));
await _subject.CacheKeysAsync(new[] { key1, key2 });
await _subject.CacheKeysAsync(new[] { key1, key2 }, _ct);
_mockKeyStoreCache.StoreKeysAsyncWasCalled.ShouldBeTrue();
_mockKeyStoreCache.StoreKeysAsyncDuration.ShouldBe(_options.KeyManagement.KeyCacheDuration);
@ -574,7 +574,7 @@ public class KeyManagerTests
{
var key1 = CreateKey();
await _subject.CacheKeysAsync(new[] { key1 });
await _subject.CacheKeysAsync(new[] { key1 }, _ct);
_mockKeyStoreCache.StoreKeysAsyncWasCalled.ShouldBeTrue();
_mockKeyStoreCache.StoreKeysAsyncDuration.ShouldBe(_options.KeyManagement.InitializationKeyCacheDuration);

View file

@ -10,6 +10,8 @@ namespace IdentityServer.UnitTests.Services;
public class DiagnosticDataServiceTests
{
private readonly CT _ct = TestContext.Current.CancellationToken;
[Fact]
public async Task GetJsonBytesAsync_WithNoEntries_ShouldReturnEmptyJsonObject()
{
@ -17,7 +19,7 @@ public class DiagnosticDataServiceTests
var entries = new List<IDiagnosticEntry>();
var service = new DiagnosticDataService(serverStartTime, entries);
var result = await service.GetJsonBytesAsync();
var result = await service.GetJsonBytesAsync(_ct);
var json = Encoding.UTF8.GetString(result.Span);
json.ShouldBe("{}");
@ -33,7 +35,7 @@ public class DiagnosticDataServiceTests
};
var service = new DiagnosticDataService(serverStartTime, entries);
var result = await service.GetJsonBytesAsync();
var result = await service.GetJsonBytesAsync(_ct);
var json = Encoding.UTF8.GetString(result.Span);
var jsonDoc = JsonDocument.Parse(json);
@ -52,7 +54,7 @@ public class DiagnosticDataServiceTests
};
var service = new DiagnosticDataService(serverStartTime, entries);
var result = await service.GetJsonBytesAsync();
var result = await service.GetJsonBytesAsync(_ct);
var json = Encoding.UTF8.GetString(result.Span);
var jsonDoc = JsonDocument.Parse(json);
@ -72,7 +74,7 @@ public class DiagnosticDataServiceTests
};
var service = new DiagnosticDataService(serverStartTime, entries);
await service.GetJsonBytesAsync();
await service.GetJsonBytesAsync(_ct);
capturedContext.Context.ShouldNotBeNull();
capturedContext.Context.ServerStartTime.ShouldBe(serverStartTime);
@ -90,7 +92,7 @@ public class DiagnosticDataServiceTests
};
var service = new DiagnosticDataService(serverStartTime, entries);
var result = await service.GetJsonBytesAsync();
var result = await service.GetJsonBytesAsync(_ct);
var json = Encoding.UTF8.GetString(result.Span);
json.ShouldNotContain("\n");
@ -105,7 +107,7 @@ public class DiagnosticDataServiceTests
var entries = new List<IDiagnosticEntry>();
var service = new DiagnosticDataService(serverStartTime, entries);
var result = await service.GetJsonStringAsync();
var result = await service.GetJsonStringAsync(_ct);
result.ShouldBe("{}");
}
@ -120,7 +122,7 @@ public class DiagnosticDataServiceTests
};
var service = new DiagnosticDataService(serverStartTime, entries);
var result = await service.GetJsonStringAsync();
var result = await service.GetJsonStringAsync(_ct);
var jsonDoc = JsonDocument.Parse(result);
jsonDoc.RootElement.GetProperty("TestProperty").GetString().ShouldBe("TestValue");
@ -138,7 +140,7 @@ public class DiagnosticDataServiceTests
};
var service = new DiagnosticDataService(serverStartTime, entries);
var result = await service.GetJsonStringAsync();
var result = await service.GetJsonStringAsync(_ct);
var jsonDoc = JsonDocument.Parse(result);
jsonDoc.RootElement.GetProperty("Property1").GetString().ShouldBe("Value1");
@ -156,7 +158,7 @@ public class DiagnosticDataServiceTests
};
var service = new DiagnosticDataService(serverStartTime, entries);
var result = await service.GetJsonStringAsync();
var result = await service.GetJsonStringAsync(_ct);
var jsonDoc = JsonDocument.Parse(result);
jsonDoc.RootElement.GetProperty("Property").GetString().ShouldBe("Value with émojis 🎉");
@ -173,8 +175,8 @@ public class DiagnosticDataServiceTests
};
var service = new DiagnosticDataService(serverStartTime, entries);
var stringResult = await service.GetJsonStringAsync();
var bytesResult = await service.GetJsonBytesAsync();
var stringResult = await service.GetJsonStringAsync(_ct);
var bytesResult = await service.GetJsonBytesAsync(_ct);
var stringFromBytes = Encoding.UTF8.GetString(bytesResult.Span);
stringResult.ShouldBe(stringFromBytes);
@ -190,7 +192,7 @@ public class DiagnosticDataServiceTests
};
var service = new DiagnosticDataService(serverStartTime, entries);
var result = await service.GetJsonBytesAsync();
var result = await service.GetJsonBytesAsync(_ct);
var json = Encoding.UTF8.GetString(result.Span);
var jsonDoc = JsonDocument.Parse(json);
@ -210,7 +212,7 @@ public class DiagnosticDataServiceTests
};
var service = new DiagnosticDataService(serverStartTime, entries);
var result = await service.GetJsonBytesAsync();
var result = await service.GetJsonBytesAsync(_ct);
var json = Encoding.UTF8.GetString(result.Span);
var jsonDoc = JsonDocument.Parse(json);