From b13446109d9663ccabef07995eb25cf9dff34f37 Mon Sep 17 00:00:00 2001 From: Arda TANRIKULU Date: Mon, 12 Jan 2026 04:30:53 -0500 Subject: [PATCH] feat(sdk-rs): builder pattern, multiple endpoints & circuit breaker (#7379) --- .changeset/busy-cloths-search.md | 56 +++ .changeset/light-walls-vanish.md | 20 ++ .changeset/violet-waves-happen.md | 17 + .github/workflows/tests-integration.yaml | 1 + configs/cargo/Cargo.lock | 62 +++- .../tests/apollo-router/apollo-router.test.ts | 64 ++-- packages/libraries/router/Cargo.toml | 1 + packages/libraries/router/src/main.rs | 1 + .../router/src/persisted_documents.rs | 78 ++++- packages/libraries/router/src/registry.rs | 43 +-- packages/libraries/router/src/usage.rs | 259 ++++++-------- packages/libraries/sdk-rs/Cargo.toml | 9 +- packages/libraries/sdk-rs/src/agent/buffer.rs | 39 +++ .../libraries/sdk-rs/src/agent/builder.rs | 229 +++++++++++++ packages/libraries/sdk-rs/src/agent/mod.rs | 4 + .../src/{agent.rs => agent/usage_agent.rs} | 318 +++++++---------- .../sdk-rs/src/{graphql.rs => agent/utils.rs} | 0 .../libraries/sdk-rs/src/circuit_breaker.rs | 67 ++++ packages/libraries/sdk-rs/src/lib.rs | 2 +- .../sdk-rs/src/persisted_documents.rs | 321 +++++++++++++----- .../sdk-rs/src/supergraph_fetcher.rs | 260 -------------- .../sdk-rs/src/supergraph_fetcher/async_.rs | 141 ++++++++ .../sdk-rs/src/supergraph_fetcher/builder.rs | 135 ++++++++ .../sdk-rs/src/supergraph_fetcher/mod.rs | 51 +++ .../sdk-rs/src/supergraph_fetcher/sync.rs | 191 +++++++++++ 25 files changed, 1619 insertions(+), 750 deletions(-) create mode 100644 .changeset/busy-cloths-search.md create mode 100644 .changeset/light-walls-vanish.md create mode 100644 .changeset/violet-waves-happen.md create mode 100644 packages/libraries/sdk-rs/src/agent/buffer.rs create mode 100644 packages/libraries/sdk-rs/src/agent/builder.rs create mode 100644 packages/libraries/sdk-rs/src/agent/mod.rs rename packages/libraries/sdk-rs/src/{agent.rs => agent/usage_agent.rs} (62%) rename packages/libraries/sdk-rs/src/{graphql.rs => agent/utils.rs} (100%) create mode 100644 packages/libraries/sdk-rs/src/circuit_breaker.rs delete mode 100644 packages/libraries/sdk-rs/src/supergraph_fetcher.rs create mode 100644 packages/libraries/sdk-rs/src/supergraph_fetcher/async_.rs create mode 100644 packages/libraries/sdk-rs/src/supergraph_fetcher/builder.rs create mode 100644 packages/libraries/sdk-rs/src/supergraph_fetcher/mod.rs create mode 100644 packages/libraries/sdk-rs/src/supergraph_fetcher/sync.rs diff --git a/.changeset/busy-cloths-search.md b/.changeset/busy-cloths-search.md new file mode 100644 index 000000000..7aed840f7 --- /dev/null +++ b/.changeset/busy-cloths-search.md @@ -0,0 +1,56 @@ +--- +'hive-console-sdk-rs': minor +--- + +Breaking Changes to avoid future breaking changes; + +Switch to [Builder](https://rust-unofficial.github.io/patterns/patterns/creational/builder.html) pattern for `SupergraphFetcher`, `PersistedDocumentsManager` and `UsageAgent` structs. + +No more `try_new` or `try_new_async` or `try_new_sync` functions, instead use `SupergraphFetcherBuilder`, `PersistedDocumentsManagerBuilder` and `UsageAgentBuilder` structs to create instances. + +Benefits; + +- No need to provide all parameters at once when creating an instance even for default values. + +Example; +```rust +// Before +let fetcher = SupergraphFetcher::try_new_async( + "SOME_ENDPOINT", // endpoint + "SOME_KEY", + "MyUserAgent/1.0".to_string(), + Duration::from_secs(5), // connect_timeout + Duration::from_secs(10), // request_timeout + false, // accept_invalid_certs + 3, // retry_count + )?; + +// After +// No need to provide all parameters at once, can use default values +let fetcher = SupergraphFetcherBuilder::new() + .endpoint("SOME_ENDPOINT".to_string()) + .key("SOME_KEY".to_string()) + .build_async()?; +``` + +- Easier to add new configuration options in the future without breaking existing code. + +Example; + +```rust +let fetcher = SupergraphFetcher::try_new_async( + "SOME_ENDPOINT", // endpoint + "SOME_KEY", + "MyUserAgent/1.0".to_string(), + Duration::from_secs(5), // connect_timeout + Duration::from_secs(10), // request_timeout + false, // accept_invalid_certs + 3, // retry_count + circuit_breaker_config, // Breaking Change -> new parameter added + )?; + +let fetcher = SupergraphFetcherBuilder::new() + .endpoint("SOME_ENDPOINT".to_string()) + .key("SOME_KEY".to_string()) + .build_async()?; // No breaking change, circuit_breaker_config can be added later if needed +``` \ No newline at end of file diff --git a/.changeset/light-walls-vanish.md b/.changeset/light-walls-vanish.md new file mode 100644 index 000000000..7a1f4f891 --- /dev/null +++ b/.changeset/light-walls-vanish.md @@ -0,0 +1,20 @@ +--- +'hive-console-sdk-rs': patch +--- + +Circuit Breaker Implementation and Multiple Endpoints Support + +Implementation of Circuit Breakers in Hive Console Rust SDK, you can learn more [here](https://the-guild.dev/graphql/hive/product-updates/2025-12-04-cdn-mirror-and-circuit-breaker) + +Breaking Changes: + +Now `endpoint` configuration accepts multiple endpoints as an array for `SupergraphFetcherBuilder` and `PersistedDocumentsManager`. + +```diff +SupergraphFetcherBuilder::default() +- .endpoint(endpoint) ++ .add_endpoint(endpoint1) ++ .add_endpoint(endpoint2) +``` + +This change requires updating the configuration structure to accommodate multiple endpoints. diff --git a/.changeset/violet-waves-happen.md b/.changeset/violet-waves-happen.md new file mode 100644 index 000000000..67aabe639 --- /dev/null +++ b/.changeset/violet-waves-happen.md @@ -0,0 +1,17 @@ +--- +'hive-apollo-router-plugin': major +--- + +- Multiple endpoints support for `HiveRegistry` and `PersistedOperationsPlugin` + +Breaking Changes: +- Now there is no `endpoint` field in the configuration, it has been replaced with `endpoints`, which is an array of strings. You are not affected if you use environment variables to set the endpoint. + +```diff +HiveRegistry::new( + Some( + HiveRegistryConfig { +- endpoint: String::from("CDN_ENDPOINT"), ++ endpoints: vec![String::from("CDN_ENDPOINT1"), String::from("CDN_ENDPOINT2")], + ) +) diff --git a/.github/workflows/tests-integration.yaml b/.github/workflows/tests-integration.yaml index 8357ef98e..5b974cd65 100644 --- a/.github/workflows/tests-integration.yaml +++ b/.github/workflows/tests-integration.yaml @@ -23,6 +23,7 @@ jobs: integration: runs-on: ubuntu-22.04 strategy: + fail-fast: false matrix: # Divide integration tests into 3 shards, to run them in parallel. shardIndex: [1, 2, 3, 'apollo-router'] diff --git a/configs/cargo/Cargo.lock b/configs/cargo/Cargo.lock index 0ae529044..aa72fdc58 100644 --- a/configs/cargo/Cargo.lock +++ b/configs/cargo/Cargo.lock @@ -122,9 +122,9 @@ dependencies = [ [[package]] name = "apollo-federation" -version = "2.10.0" +version = "2.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "034b556462798d1f076d865791410c088fd549c920f2a5c0d83a0611141559b4" +checksum = "6b26138e298ecff0fb0b972175a93174d7d87d74cdc646af6a7932dc435d86c8" dependencies = [ "apollo-compiler", "derive_more", @@ -171,9 +171,9 @@ dependencies = [ [[package]] name = "apollo-router" -version = "2.10.0" +version = "2.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db37c08af3b5844565e98567e92e68ae69ecd6bb4d3e674bb2a2f27d865d3680" +checksum = "4d5b4da5efb2fb56a7076ded5912eb829005d1c3c2fec35359855f689db3fc57" dependencies = [ "addr2line", "ahash", @@ -292,6 +292,7 @@ dependencies = [ "similar", "static_assertions", "strum", + "strum_macros", "sys-info", "sysinfo", "thiserror 2.0.17", @@ -436,6 +437,19 @@ dependencies = [ "tokio", ] +[[package]] +name = "async-dropper-simple" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7c4748dfe8cd3d625ec68fc424fa80c134319881185866f9e173af9e5d8add8" +dependencies = [ + "async-scoped", + "async-trait", + "futures", + "rustc_version", + "tokio", +] + [[package]] name = "async-executor" version = "1.13.3" @@ -521,6 +535,17 @@ dependencies = [ "rustix", ] +[[package]] +name = "async-scoped" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4042078ea593edffc452eef14e99fdb2b120caa4ad9618bcdeabc4a023b98740" +dependencies = [ + "futures", + "pin-project", + "tokio", +] + [[package]] name = "async-signal" version = "0.2.13" @@ -2596,6 +2621,7 @@ dependencies = [ "serde_json", "sha2", "tokio", + "tokio-util", "tower 0.5.2", "tracing", ] @@ -2605,17 +2631,24 @@ name = "hive-console-sdk" version = "0.2.3" dependencies = [ "anyhow", + "async-dropper-simple", "async-trait", "axum-core 0.5.5", + "futures-util", "graphql-parser", "graphql-tools", + "lazy_static", "md5", "mockito", "moka", + "once_cell", + "recloser", + "regex-automata", "regress", "reqwest", "reqwest-middleware", "reqwest-retry", + "retry-policies", "serde", "serde_json", "sha2", @@ -3531,9 +3564,9 @@ dependencies = [ [[package]] name = "mockall" -version = "0.14.0" +version = "0.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f58d964098a5f9c6b63d0798e5372fd04708193510a7af313c22e9f29b7b620b" +checksum = "39a6bfcc6c8c7eed5ee98b9c3e33adc726054389233e201c95dab2d41a3839d2" dependencies = [ "cfg-if", "downcast", @@ -3545,9 +3578,9 @@ dependencies = [ [[package]] name = "mockall_derive" -version = "0.14.0" +version = "0.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca41ce716dda6a9be188b385aa78ee5260fc25cd3802cb2a8afdc6afbe6b6dbf" +checksum = "25ca3004c2efe9011bd4e461bd8256445052b9615405b4f7ea43fc8ca5c20898" dependencies = [ "cfg-if", "proc-macro2", @@ -4589,6 +4622,16 @@ dependencies = [ "getrandom 0.3.4", ] +[[package]] +name = "recloser" +version = "1.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40ac0d06281c3556fea72cef9e5372d9ac172335be0d71c3b4f3db900483e0eb" +dependencies = [ + "crossbeam-epoch", + "pin-project", +] + [[package]] name = "redis-protocol" version = "6.0.0" @@ -5533,9 +5576,6 @@ name = "strum" version = "0.27.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "af23d6f6c1a224baef9d3f61e287d2761385a5b88fdab4eb4c6f11aeb54c4bcf" -dependencies = [ - "strum_macros", -] [[package]] name = "strum_macros" diff --git a/integration-tests/tests/apollo-router/apollo-router.test.ts b/integration-tests/tests/apollo-router/apollo-router.test.ts index 980d26a9c..f8bd227b2 100644 --- a/integration-tests/tests/apollo-router/apollo-router.test.ts +++ b/integration-tests/tests/apollo-router/apollo-router.test.ts @@ -2,14 +2,13 @@ import { existsSync, rmSync, writeFileSync } from 'node:fs'; import { createServer } from 'node:http'; import { tmpdir } from 'node:os'; import { join } from 'node:path'; +import { MaybePromise } from 'slonik/dist/src/types'; import { ProjectType } from 'testkit/gql/graphql'; import { initSeed } from 'testkit/seed'; import { getServiceHost } from 'testkit/utils'; import { execa } from '@esm2cjs/execa'; describe('Apollo Router Integration', () => { - const getBaseEndpoint = () => - getServiceHost('server', 8082).then(v => `http://${v}/artifacts/v1/`); const getAvailablePort = () => new Promise(resolve => { const server = createServer(); @@ -18,12 +17,19 @@ describe('Apollo Router Integration', () => { if (address && typeof address === 'object') { const port = address.port; server.close(() => resolve(port)); + } else { + throw new Error('Could not get available port'); } }); }); + function defer(deferFn: () => MaybePromise) { + return { + async [Symbol.asyncDispose]() { + return deferFn(); + }, + }; + } it('fetches the supergraph and sends usage reports', async () => { - const routerConfigPath = join(tmpdir(), `apollo-router-config-${Date.now()}.yaml`); - const endpointBaseUrl = await getBaseEndpoint(); const { createOrg } = await initSeed().createOwner(); const { createProject } = await createOrg(); const { createTargetAccessToken, createCdnAccess, target, waitForOperationsCollected } = @@ -50,6 +56,7 @@ describe('Apollo Router Integration', () => { .then(r => r.expectNoGraphQLErrors()); expect(publishSchemaResult.schemaPublish.__typename).toBe('SchemaPublishSuccess'); + const cdnAccessResult = await createCdnAccess(); const usageAddress = await getServiceHost('usage', 8081); @@ -60,6 +67,7 @@ describe('Apollo Router Integration', () => { `Apollo Router binary not found at path: ${routerBinPath}, make sure to build it first with 'cargo build'`, ); } + const routerPort = await getAvailablePort(); const routerConfigContent = ` supergraph: @@ -67,17 +75,22 @@ supergraph: plugins: hive.usage: {} `.trim(); + const routerConfigPath = join(tmpdir(), `apollo-router-config-${Date.now()}.yaml`); writeFileSync(routerConfigPath, routerConfigContent, 'utf-8'); + + const cdnEndpoint = await getServiceHost('server', 8082).then( + v => `http://${v}/artifacts/v1/${target.id}`, + ); const routerProc = execa(routerBinPath, ['--dev', '--config', routerConfigPath], { all: true, env: { - HIVE_CDN_ENDPOINT: endpointBaseUrl + target.id, + HIVE_CDN_ENDPOINT: cdnEndpoint, HIVE_CDN_KEY: cdnAccessResult.secretAccessToken, HIVE_ENDPOINT: `http://${usageAddress}`, HIVE_TOKEN: writeToken.secret, - HIVE_TARGET_ID: target.id, }, }); + let log = ''; await new Promise((resolve, reject) => { routerProc.catch(err => { if (!err.isCanceled) { @@ -88,17 +101,22 @@ plugins: if (!routerProcOut) { return reject(new Error('No stdout from Apollo Router process')); } - let log = ''; routerProcOut.on('data', data => { log += data.toString(); if (log.includes('GraphQL endpoint exposed at')) { resolve(true); } + process.stdout.write(log); }); }); - try { - const url = `http://localhost:${routerPort}/`; + await using _ = defer(() => { + rmSync(routerConfigPath); + routerProc.cancel(); + }); + + const url = `http://localhost:${routerPort}/`; + async function sendOperation(i: number) { const response = await fetch(url, { method: 'POST', headers: { @@ -107,13 +125,13 @@ plugins: }, body: JSON.stringify({ query: ` - query TestQuery { - me { - id - name - } - } - `, + query Query${i} { + me { + id + name + } + } + `, }), }); @@ -127,10 +145,16 @@ plugins: }, }, }); - await waitForOperationsCollected(1); - } finally { - routerProc.cancel(); - rmSync(routerConfigPath); } + const cnt = 1000; + const jobs = []; + for (let i = 0; i < cnt; i++) { + if (i % 100 === 0) { + await new Promise(res => setTimeout(res, 500)); + } + jobs.push(sendOperation(i)); + } + await Promise.all(jobs); + await waitForOperationsCollected(cnt); }); }); diff --git a/packages/libraries/router/Cargo.toml b/packages/libraries/router/Cargo.toml index b419f9134..7569b2619 100644 --- a/packages/libraries/router/Cargo.toml +++ b/packages/libraries/router/Cargo.toml @@ -35,6 +35,7 @@ http = "1" http-body-util = "0.1" graphql-parser = "0.4.1" rand = "0.9.0" +tokio-util = "0.7.16" [dev-dependencies] httpmock = "0.7.0" diff --git a/packages/libraries/router/src/main.rs b/packages/libraries/router/src/main.rs index 491a31ca4..910c4cf5f 100644 --- a/packages/libraries/router/src/main.rs +++ b/packages/libraries/router/src/main.rs @@ -21,6 +21,7 @@ fn main() { register_plugins(); // Initialize the Hive Registry and start the Apollo Router + // TODO: Look at builder pattern in Executable::builder().start() match HiveRegistry::new(None).and(apollo_router::main()) { Ok(_) => {} Err(e) => { diff --git a/packages/libraries/router/src/persisted_documents.rs b/packages/libraries/router/src/persisted_documents.rs index 2bd8d1057..160a67c3d 100644 --- a/packages/libraries/router/src/persisted_documents.rs +++ b/packages/libraries/router/src/persisted_documents.rs @@ -32,7 +32,7 @@ pub static PERSISTED_DOCUMENT_HASH_KEY: &str = "hive::persisted_document_hash"; pub struct Config { pub enabled: Option, /// GraphQL Hive persisted documents CDN endpoint URL. - pub endpoint: Option, + pub endpoint: Option, /// GraphQL Hive persisted documents CDN access token. pub key: Option, /// Whether arbitrary documents should be allowed along-side persisted documents. @@ -57,6 +57,25 @@ pub struct Config { pub cache_size: Option, } +#[derive(Clone, Debug, Deserialize, JsonSchema)] +#[serde(untagged)] +pub enum EndpointConfig { + Single(String), + Multiple(Vec), +} + +impl From<&str> for EndpointConfig { + fn from(value: &str) -> Self { + EndpointConfig::Single(value.into()) + } +} + +impl From<&[&str]> for EndpointConfig { + fn from(value: &[&str]) -> Self { + EndpointConfig::Multiple(value.iter().map(|s| s.to_string()).collect()) + } +} + pub struct PersistedDocumentsPlugin { persisted_documents_manager: Option>, allow_arbitrary_documents: bool, @@ -72,11 +91,14 @@ impl PersistedDocumentsPlugin { allow_arbitrary_documents, }); } - let endpoint = match &config.endpoint { - Some(ep) => ep.clone(), + let endpoints = match &config.endpoint { + Some(ep) => match ep { + EndpointConfig::Single(url) => vec![url.clone()], + EndpointConfig::Multiple(urls) => urls.clone(), + }, None => { if let Ok(ep) = env::var("HIVE_CDN_ENDPOINT") { - ep + vec![ep] } else { return Err( "Endpoint for persisted documents CDN is not configured. Please set it via the plugin configuration or HIVE_CDN_ENDPOINT environment variable." @@ -100,17 +122,41 @@ impl PersistedDocumentsPlugin { } }; + let mut persisted_documents_manager = PersistedDocumentsManager::builder() + .key(key) + .user_agent(format!("hive-apollo-router/{}", PLUGIN_VERSION)); + + for endpoint in endpoints { + persisted_documents_manager = persisted_documents_manager.add_endpoint(endpoint); + } + + if let Some(connect_timeout) = config.connect_timeout { + persisted_documents_manager = + persisted_documents_manager.connect_timeout(Duration::from_secs(connect_timeout)); + } + + if let Some(request_timeout) = config.request_timeout { + persisted_documents_manager = + persisted_documents_manager.request_timeout(Duration::from_secs(request_timeout)); + } + + if let Some(retry_count) = config.retry_count { + persisted_documents_manager = persisted_documents_manager.max_retries(retry_count); + } + + if let Some(accept_invalid_certs) = config.accept_invalid_certs { + persisted_documents_manager = + persisted_documents_manager.accept_invalid_certs(accept_invalid_certs); + } + + if let Some(cache_size) = config.cache_size { + persisted_documents_manager = persisted_documents_manager.cache_size(cache_size); + } + + let persisted_documents_manager = persisted_documents_manager.build()?; + Ok(PersistedDocumentsPlugin { - persisted_documents_manager: Some(Arc::new(PersistedDocumentsManager::new( - key, - endpoint, - config.accept_invalid_certs.unwrap_or(false), - Duration::from_secs(config.connect_timeout.unwrap_or(5)), - Duration::from_secs(config.request_timeout.unwrap_or(15)), - config.retry_count.unwrap_or(3), - config.cache_size.unwrap_or(1000), - format!("hive-apollo-router/{}", PLUGIN_VERSION), - ))), + persisted_documents_manager: Some(Arc::new(persisted_documents_manager)), allow_arbitrary_documents, }) } @@ -344,8 +390,8 @@ mod hive_persisted_documents_tests { Self { server } } - fn endpoint(&self) -> String { - self.server.url("") + fn endpoint(&self) -> EndpointConfig { + EndpointConfig::Single(self.server.url("")) } /// Registers a valid artifact URL with an actual GraphQL document diff --git a/packages/libraries/router/src/registry.rs b/packages/libraries/router/src/registry.rs index 243c160cb..fb48b76bc 100644 --- a/packages/libraries/router/src/registry.rs +++ b/packages/libraries/router/src/registry.rs @@ -1,14 +1,13 @@ use crate::consts::PLUGIN_VERSION; use crate::registry_logger::Logger; use anyhow::{anyhow, Result}; +use hive_console_sdk::supergraph_fetcher::sync::SupergraphFetcherSyncState; use hive_console_sdk::supergraph_fetcher::SupergraphFetcher; -use hive_console_sdk::supergraph_fetcher::SupergraphFetcherSyncState; use sha2::Digest; use sha2::Sha256; use std::env; use std::io::Write; use std::thread; -use std::time::Duration; #[derive(Debug)] pub struct HiveRegistry { @@ -18,7 +17,7 @@ pub struct HiveRegistry { } pub struct HiveRegistryConfig { - endpoint: Option, + endpoints: Vec, key: Option, poll_interval: Option, accept_invalid_certs: Option, @@ -29,7 +28,7 @@ impl HiveRegistry { #[allow(clippy::new_ret_no_self)] pub fn new(user_config: Option) -> Result<()> { let mut config = HiveRegistryConfig { - endpoint: None, + endpoints: vec![], key: None, poll_interval: None, accept_invalid_certs: Some(true), @@ -38,7 +37,7 @@ impl HiveRegistry { // Pass values from user's config if let Some(user_config) = user_config { - config.endpoint = user_config.endpoint; + config.endpoints = user_config.endpoints; config.key = user_config.key; config.poll_interval = user_config.poll_interval; config.accept_invalid_certs = user_config.accept_invalid_certs; @@ -47,9 +46,9 @@ impl HiveRegistry { // Pass values from environment variables if they are not set in the user's config - if config.endpoint.is_none() { + if config.endpoints.is_empty() { if let Ok(endpoint) = env::var("HIVE_CDN_ENDPOINT") { - config.endpoint = Some(endpoint); + config.endpoints.push(endpoint); } } @@ -86,7 +85,7 @@ impl HiveRegistry { } // Resolve values - let endpoint = config.endpoint.unwrap_or_default(); + let endpoint = config.endpoints; let key = config.key.unwrap_or_default(); let poll_interval: u64 = config.poll_interval.unwrap_or(10); let accept_invalid_certs = config.accept_invalid_certs.unwrap_or(false); @@ -120,19 +119,23 @@ impl HiveRegistry { .to_string_lossy() .to_string(), ); - env::set_var("APOLLO_ROUTER_SUPERGRAPH_PATH", file_name.clone()); - env::set_var("APOLLO_ROUTER_HOT_RELOAD", "true"); + unsafe { + env::set_var("APOLLO_ROUTER_SUPERGRAPH_PATH", file_name.clone()); + env::set_var("APOLLO_ROUTER_HOT_RELOAD", "true"); + } - let fetcher = SupergraphFetcher::try_new_sync( - endpoint, - &key, - format!("hive-apollo-router/{}", PLUGIN_VERSION), - Duration::from_secs(5), - Duration::from_secs(60), - accept_invalid_certs, - 3, - ) - .map_err(|e| anyhow!("Failed to create SupergraphFetcher: {}", e))?; + let mut fetcher = SupergraphFetcher::builder() + .key(key) + .user_agent(format!("hive-apollo-router/{}", PLUGIN_VERSION)) + .accept_invalid_certs(accept_invalid_certs); + + for ep in endpoint { + fetcher = fetcher.add_endpoint(ep); + } + + let fetcher = fetcher + .build_sync() + .map_err(|e| anyhow!("Failed to create SupergraphFetcher: {}", e))?; let registry = HiveRegistry { fetcher, diff --git a/packages/libraries/router/src/usage.rs b/packages/libraries/router/src/usage.rs index 1964a6e04..d8eadcd1d 100644 --- a/packages/libraries/router/src/usage.rs +++ b/packages/libraries/router/src/usage.rs @@ -8,8 +8,8 @@ use core::ops::Drop; use futures::StreamExt; use graphql_parser::parse_schema; use graphql_parser::schema::Document; -use hive_console_sdk::agent::UsageAgentExt; -use hive_console_sdk::agent::{ExecutionReport, UsageAgent}; +use hive_console_sdk::agent::usage_agent::UsageAgentExt; +use hive_console_sdk::agent::usage_agent::{ExecutionReport, UsageAgent}; use http::HeaderValue; use rand::Rng; use schemars::JsonSchema; @@ -19,6 +19,7 @@ use std::env; use std::sync::Arc; use std::time::{Duration, Instant}; use std::time::{SystemTime, UNIX_EPOCH}; +use tokio_util::sync::CancellationToken; use tower::BoxError; use tower::ServiceBuilder; use tower::ServiceExt; @@ -47,11 +48,12 @@ struct OperationConfig { pub struct UsagePlugin { config: OperationConfig, - agent: Option>, + agent: Option, schema: Arc>, + cancellation_token: Arc, } -#[derive(Clone, Debug, Deserialize, JsonSchema)] +#[derive(Clone, Debug, Deserialize, JsonSchema, Default)] pub struct Config { /// Default: true enabled: Option, @@ -95,26 +97,6 @@ pub struct Config { flush_interval: Option, } -impl Default for Config { - fn default() -> Self { - Self { - enabled: Some(true), - registry_token: None, - registry_usage_endpoint: Some(DEFAULT_HIVE_USAGE_ENDPOINT.into()), - sample_rate: Some(1.0), - exclude: None, - client_name_header: Some(String::from("graphql-client-name")), - client_version_header: Some(String::from("graphql-client-version")), - accept_invalid_certs: Some(false), - buffer_size: Some(1000), - connect_timeout: Some(5), - request_timeout: Some(15), - flush_interval: Some(5), - target: None, - } - } -} - impl UsagePlugin { fn populate_context(config: OperationConfig, req: &supergraph::Request) { let context = &req.context; @@ -179,108 +161,95 @@ impl UsagePlugin { } } -static DEFAULT_HIVE_USAGE_ENDPOINT: &str = "https://app.graphql-hive.com/usage"; - #[async_trait::async_trait] impl Plugin for UsagePlugin { type Config = Config; async fn new(init: PluginInit) -> Result { - let token = init - .config - .registry_token - .clone() - .or_else(|| env::var("HIVE_TOKEN").ok()); - - if token.is_none() { - return Err("Hive token is required".into()); - } - - let endpoint = init - .config - .registry_usage_endpoint - .clone() - .unwrap_or_else(|| { - env::var("HIVE_ENDPOINT").unwrap_or(DEFAULT_HIVE_USAGE_ENDPOINT.to_string()) - }); - - let target_id = init - .config - .target - .clone() - .or_else(|| env::var("HIVE_TARGET_ID").ok()); - - let default_config = Config::default(); let user_config = init.config; - let enabled = user_config - .enabled - .or(default_config.enabled) - .expect("enabled has default value"); - let buffer_size = user_config - .buffer_size - .or(default_config.buffer_size) - .expect("buffer_size has no default value"); - let accept_invalid_certs = user_config - .accept_invalid_certs - .or(default_config.accept_invalid_certs) - .expect("accept_invalid_certs has no default value"); - let connect_timeout = user_config - .connect_timeout - .or(default_config.connect_timeout) - .expect("connect_timeout has no default value"); - let request_timeout = user_config - .request_timeout - .or(default_config.request_timeout) - .expect("request_timeout has no default value"); - let flush_interval = user_config - .flush_interval - .or(default_config.flush_interval) - .expect("request_timeout has no default value"); + + let enabled = user_config.enabled.unwrap_or(true); if enabled { tracing::info!("Starting GraphQL Hive Usage plugin"); } - let schema = parse_schema(&init.supergraph_sdl) - .expect("Failed to parse schema") - .into_static(); + + let cancellation_token = Arc::new(CancellationToken::new()); let agent = if enabled { - let flush_interval = Duration::from_secs(flush_interval); - let agent = UsageAgent::try_new( - &token.expect("token is set"), - endpoint, - target_id, - buffer_size, - Duration::from_secs(connect_timeout), - Duration::from_secs(request_timeout), - accept_invalid_certs, - flush_interval, - format!("hive-apollo-router/{}", PLUGIN_VERSION), - ) - .map_err(Box::new)?; - start_flush_interval(agent.clone()); + let mut agent = + UsageAgent::builder().user_agent(format!("hive-apollo-router/{}", PLUGIN_VERSION)); + + if let Some(endpoint) = user_config.registry_usage_endpoint { + agent = agent.endpoint(endpoint); + } else if let Ok(env_endpoint) = env::var("HIVE_ENDPOINT") { + agent = agent.endpoint(env_endpoint); + } + + if let Some(token) = user_config.registry_token { + agent = agent.token(token); + } else if let Ok(env_token) = env::var("HIVE_TOKEN") { + agent = agent.token(env_token); + } + + if let Some(target_id) = user_config.target { + agent = agent.target_id(target_id); + } else if let Ok(env_target) = env::var("HIVE_TARGET_ID") { + agent = agent.target_id(env_target); + } + + if let Some(buffer_size) = user_config.buffer_size { + agent = agent.buffer_size(buffer_size); + } + + if let Some(connect_timeout) = user_config.connect_timeout { + agent = agent.connect_timeout(Duration::from_secs(connect_timeout)); + } + + if let Some(request_timeout) = user_config.request_timeout { + agent = agent.request_timeout(Duration::from_secs(request_timeout)); + } + + if let Some(accept_invalid_certs) = user_config.accept_invalid_certs { + agent = agent.accept_invalid_certs(accept_invalid_certs); + } + + if let Some(flush_interval) = user_config.flush_interval { + agent = agent.flush_interval(Duration::from_secs(flush_interval)); + } + + let agent = agent.build().map_err(Box::new)?; + + let cancellation_token_for_interval = cancellation_token.clone(); + let agent_for_interval = agent.clone(); + tokio::task::spawn(async move { + agent_for_interval + .start_flush_interval(&cancellation_token_for_interval) + .await; + }); Some(agent) } else { None }; + + let schema = parse_schema(&init.supergraph_sdl) + .expect("Failed to parse schema") + .into_static(); + Ok(UsagePlugin { schema: Arc::new(schema), config: OperationConfig { - sample_rate: user_config - .sample_rate - .or(default_config.sample_rate) - .expect("sample_rate has no default value"), - exclude: user_config.exclude.or(default_config.exclude), + sample_rate: user_config.sample_rate.unwrap_or(1.0), + exclude: user_config.exclude, client_name_header: user_config .client_name_header - .or(default_config.client_name_header) - .expect("client_name_header has no default value"), + .unwrap_or("graphql-client-name".to_string()), client_version_header: user_config .client_version_header - .or(default_config.client_version_header) - .expect("client_version_header has no default value"), + .unwrap_or("graphql-client-version".to_string()), }, agent, + cancellation_token, }) } @@ -342,58 +311,61 @@ impl Plugin for UsagePlugin { match result { Err(e) => { - agent - .add_report(ExecutionReport { - schema, - client_name, - client_version, - timestamp, - duration, - ok: false, - errors: 1, - operation_body, - operation_name, - persisted_document_hash, - }) - .unwrap_or_else(|e| { + tokio::spawn(async move { + let res = agent + .add_report(ExecutionReport { + schema, + client_name, + client_version, + timestamp, + duration, + ok: false, + errors: 1, + operation_body, + operation_name, + persisted_document_hash, + }) + .await; + if let Err(e) = res { tracing::error!("Error adding report: {}", e); - }); + } + }); Err(e) } Ok(router_response) => { let is_failure = !router_response.response.status().is_success(); Ok(router_response.map(move |response_stream| { - let client_name = client_name.clone(); - let client_version = client_version.clone(); - let operation_body = operation_body.clone(); - let operation_name = operation_name.clone(); - let res = response_stream .map(move |response| { // make sure we send a single report, not for each chunk let response_has_errors = !response.errors.is_empty(); - agent - .add_report(ExecutionReport { - schema: schema.clone(), - client_name: client_name.clone(), - client_version: client_version.clone(), - timestamp, - duration, - ok: !is_failure && !response_has_errors, - errors: response.errors.len(), - operation_body: operation_body.clone(), - operation_name: operation_name.clone(), - persisted_document_hash: - persisted_document_hash.clone(), - }) - .unwrap_or_else(|e| { + let agent = agent.clone(); + let execution_report = ExecutionReport { + schema: schema.clone(), + client_name: client_name.clone(), + client_version: client_version.clone(), + timestamp, + duration, + ok: !is_failure && !response_has_errors, + errors: response.errors.len(), + operation_body: operation_body.clone(), + operation_name: operation_name.clone(), + persisted_document_hash: + persisted_document_hash.clone(), + }; + tokio::spawn(async move { + let res = agent + .add_report(execution_report) + .await; + if let Err(e) = res { tracing::error!( "Error adding report: {}", e ); - }); + } + }); response }) @@ -415,17 +387,11 @@ impl Plugin for UsagePlugin { impl Drop for UsagePlugin { fn drop(&mut self) { - tracing::debug!("UsagePlugin has been dropped!"); - // TODO: flush the buffer + self.cancellation_token.cancel(); + // Flush already done by UsageAgent's Drop impl } } -pub fn start_flush_interval(agent_for_interval: Arc) { - tokio::task::spawn(async move { - agent_for_interval.start_flush_interval(None).await; - }); -} - #[cfg(test)] mod hive_usage_tests { use apollo_router::{ @@ -479,7 +445,7 @@ mod hive_usage_tests { } fn wait_for_processing(&self) -> tokio::time::Sleep { - tokio::time::sleep(tokio::time::Duration::from_secs(1)) + tokio::time::sleep(tokio::time::Duration::from_secs(2)) } fn activate_usage_mock(&'_ self) -> Mock<'_> { @@ -584,6 +550,7 @@ mod hive_usage_tests { instance.execute_operation(req).await.next_response().await; instance.wait_for_processing().await; + println!("Waiting done"); mock.assert(); mock.assert_hits(1); diff --git a/packages/libraries/sdk-rs/Cargo.toml b/packages/libraries/sdk-rs/Cargo.toml index 580f5541c..f75437ec7 100644 --- a/packages/libraries/sdk-rs/Cargo.toml +++ b/packages/libraries/sdk-rs/Cargo.toml @@ -20,7 +20,7 @@ reqwest = { version = "0.12.24", default-features = false, features = [ "blocking", ] } reqwest-retry = "0.8.0" -reqwest-middleware = "0.4.2" +reqwest-middleware = { version = "0.4.2", features = ["json"]} anyhow = "1" tracing = "0.1" serde = "1" @@ -32,8 +32,15 @@ serde_json = "1" moka = { version = "0.12.10", features = ["future", "sync"] } sha2 = { version = "0.10.8", features = ["std"] } tokio-util = "0.7.16" +regex-automata = "0.4.10" +once_cell = "1.21.3" +retry-policies = "0.5.0" +recloser = "1.3.1" +futures-util = "0.3.31" typify = "0.5.0" regress = "0.10.5" +lazy_static = "1.5.0" +async-dropper-simple = { version = "0.2.6", features = ["tokio", "no-default-bound"] } [dev-dependencies] mockito = "1.7.0" diff --git a/packages/libraries/sdk-rs/src/agent/buffer.rs b/packages/libraries/sdk-rs/src/agent/buffer.rs new file mode 100644 index 000000000..5c7add3f4 --- /dev/null +++ b/packages/libraries/sdk-rs/src/agent/buffer.rs @@ -0,0 +1,39 @@ +use std::collections::VecDeque; + +use tokio::sync::Mutex; + +pub struct Buffer { + max_size: usize, + queue: Mutex>, +} + +pub enum AddStatus { + Full { drained: Vec }, + Ok, +} + +impl Buffer { + pub fn new(max_size: usize) -> Self { + Self { + queue: Mutex::new(VecDeque::with_capacity(max_size)), + max_size, + } + } + + pub async fn add(&self, item: T) -> AddStatus { + let mut queue = self.queue.lock().await; + if queue.len() >= self.max_size { + let mut drained: Vec = queue.drain(..).collect(); + drained.push(item); + AddStatus::Full { drained } + } else { + queue.push_back(item); + AddStatus::Ok + } + } + + pub async fn drain(&self) -> Vec { + let mut queue = self.queue.lock().await; + queue.drain(..).collect() + } +} diff --git a/packages/libraries/sdk-rs/src/agent/builder.rs b/packages/libraries/sdk-rs/src/agent/builder.rs new file mode 100644 index 000000000..a7831a2ac --- /dev/null +++ b/packages/libraries/sdk-rs/src/agent/builder.rs @@ -0,0 +1,229 @@ +use std::{sync::Arc, time::Duration}; + +use async_dropper_simple::AsyncDropper; +use once_cell::sync::Lazy; +use recloser::AsyncRecloser; +use reqwest::header::{HeaderMap, HeaderValue}; +use reqwest_middleware::ClientBuilder; +use reqwest_retry::RetryTransientMiddleware; + +use crate::agent::buffer::Buffer; +use crate::agent::usage_agent::{non_empty_string, AgentError, UsageAgent, UsageAgentInner}; +use crate::agent::utils::OperationProcessor; +use crate::circuit_breaker; +use retry_policies::policies::ExponentialBackoff; + +pub struct UsageAgentBuilder { + token: Option, + endpoint: String, + target_id: Option, + buffer_size: usize, + connect_timeout: Duration, + request_timeout: Duration, + accept_invalid_certs: bool, + flush_interval: Duration, + retry_policy: ExponentialBackoff, + user_agent: Option, + circuit_breaker: Option, +} + +pub static DEFAULT_HIVE_USAGE_ENDPOINT: &str = "https://app.graphql-hive.com/usage"; + +impl Default for UsageAgentBuilder { + fn default() -> Self { + Self { + endpoint: DEFAULT_HIVE_USAGE_ENDPOINT.to_string(), + token: None, + target_id: None, + buffer_size: 1000, + connect_timeout: Duration::from_secs(5), + request_timeout: Duration::from_secs(15), + accept_invalid_certs: false, + flush_interval: Duration::from_secs(5), + retry_policy: ExponentialBackoff::builder().build_with_max_retries(3), + user_agent: None, + circuit_breaker: None, + } + } +} + +fn is_legacy_token(token: &str) -> bool { + !token.starts_with("hvo1/") && !token.starts_with("hvu1/") && !token.starts_with("hvp1/") +} + +impl UsageAgentBuilder { + /// Your [Registry Access Token](https://the-guild.dev/graphql/hive/docs/management/targets#registry-access-tokens) with write permission. + pub fn token(mut self, token: String) -> Self { + if let Some(token) = non_empty_string(Some(token)) { + self.token = Some(token); + } + self + } + /// For self-hosting, you can override `/usage` endpoint (defaults to `https://app.graphql-hive.com/usage`). + pub fn endpoint(mut self, endpoint: String) -> Self { + if let Some(endpoint) = non_empty_string(Some(endpoint)) { + self.endpoint = endpoint; + } + self + } + /// A target ID, this can either be a slug following the format “$organizationSlug/$projectSlug/$targetSlug” (e.g “the-guild/graphql-hive/staging”) or an UUID (e.g. “a0f4c605-6541-4350-8cfe-b31f21a4bf80”). To be used when the token is configured with an organization access token. + pub fn target_id(mut self, target_id: String) -> Self { + if let Some(target_id) = non_empty_string(Some(target_id)) { + self.target_id = Some(target_id); + } + self + } + /// A maximum number of operations to hold in a buffer before sending to Hive Console + /// Default: 1000 + pub fn buffer_size(mut self, buffer_size: usize) -> Self { + self.buffer_size = buffer_size; + self + } + /// A timeout for only the connect phase of a request to Hive Console + /// Default: 5 seconds + pub fn connect_timeout(mut self, connect_timeout: Duration) -> Self { + self.connect_timeout = connect_timeout; + self + } + /// A timeout for the entire request to Hive Console + /// Default: 15 seconds + pub fn request_timeout(mut self, request_timeout: Duration) -> Self { + self.request_timeout = request_timeout; + self + } + /// Accepts invalid SSL certificates + /// Default: false + pub fn accept_invalid_certs(mut self, accept_invalid_certs: bool) -> Self { + self.accept_invalid_certs = accept_invalid_certs; + self + } + /// Frequency of flushing the buffer to the server + /// Default: 5 seconds + pub fn flush_interval(mut self, flush_interval: Duration) -> Self { + self.flush_interval = flush_interval; + self + } + /// User-Agent header to be sent with each request + pub fn user_agent(mut self, user_agent: String) -> Self { + if let Some(user_agent) = non_empty_string(Some(user_agent)) { + self.user_agent = Some(user_agent); + } + self + } + /// Retry policy for sending reports + /// Default: ExponentialBackoff with max 3 retries + pub fn retry_policy(mut self, retry_policy: ExponentialBackoff) -> Self { + self.retry_policy = retry_policy; + self + } + /// Maximum number of retries for sending reports + /// Default: ExponentialBackoff with max 3 retries + pub fn max_retries(mut self, max_retries: u32) -> Self { + self.retry_policy = ExponentialBackoff::builder().build_with_max_retries(max_retries); + self + } + pub(crate) fn build_agent(self) -> Result { + let mut default_headers = HeaderMap::new(); + + default_headers.insert("X-Usage-API-Version", HeaderValue::from_static("2")); + + let token = match self.token { + Some(token) => token, + None => return Err(AgentError::MissingToken), + }; + + let mut authorization_header = HeaderValue::from_str(&format!("Bearer {}", token)) + .map_err(|_| AgentError::InvalidToken)?; + + authorization_header.set_sensitive(true); + + default_headers.insert(reqwest::header::AUTHORIZATION, authorization_header); + + default_headers.insert( + reqwest::header::CONTENT_TYPE, + HeaderValue::from_static("application/json"), + ); + + let mut reqwest_agent = reqwest::Client::builder() + .danger_accept_invalid_certs(self.accept_invalid_certs) + .connect_timeout(self.connect_timeout) + .timeout(self.request_timeout) + .default_headers(default_headers); + + if let Some(user_agent) = &self.user_agent { + reqwest_agent = reqwest_agent.user_agent(user_agent); + } + + let reqwest_agent = reqwest_agent + .build() + .map_err(AgentError::HTTPClientCreationError)?; + let client = ClientBuilder::new(reqwest_agent) + .with(RetryTransientMiddleware::new_with_policy(self.retry_policy)) + .build(); + + let mut endpoint = self.endpoint; + + match self.target_id { + Some(_) if is_legacy_token(&token) => return Err(AgentError::TargetIdWithLegacyToken), + Some(target_id) if !is_legacy_token(&token) => { + let target_id = validate_target_id(&target_id)?; + endpoint.push_str(&format!("/{}", target_id)); + } + None if !is_legacy_token(&token) => return Err(AgentError::MissingTargetId), + _ => {} + } + + let circuit_breaker = if let Some(cb) = self.circuit_breaker { + cb + } else { + circuit_breaker::CircuitBreakerBuilder::default() + .build_async() + .map_err(AgentError::CircuitBreakerCreationError)? + }; + + let buffer = Buffer::new(self.buffer_size); + + Ok(UsageAgentInner { + endpoint, + buffer, + processor: OperationProcessor::new(), + client, + flush_interval: self.flush_interval, + circuit_breaker, + }) + } + pub fn build(self) -> Result { + let agent = self.build_agent()?; + Ok(Arc::new(AsyncDropper::new(agent))) + } +} + +// Target ID regexp for validation: slug format +static SLUG_REGEX: Lazy = Lazy::new(|| { + regex_automata::meta::Regex::new(r"^[a-zA-Z0-9-_]+\/[a-zA-Z0-9-_]+\/[a-zA-Z0-9-_]+$").unwrap() +}); +// Target ID regexp for validation: UUID format +static UUID_REGEX: Lazy = Lazy::new(|| { + regex_automata::meta::Regex::new( + r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$", + ) + .unwrap() +}); + +fn validate_target_id(target_id: &str) -> Result<&str, AgentError> { + let trimmed_s = target_id.trim(); + if trimmed_s.is_empty() { + Err(AgentError::InvalidTargetId("".to_string())) + } else { + if SLUG_REGEX.is_match(trimmed_s) { + return Ok(trimmed_s); + } + if UUID_REGEX.is_match(trimmed_s) { + return Ok(trimmed_s); + } + Err(AgentError::InvalidTargetId(format!( + "Invalid target_id format: '{}'. It must be either in slug format '$organizationSlug/$projectSlug/$targetSlug' or UUID format 'a0f4c605-6541-4350-8cfe-b31f21a4bf80'", + trimmed_s + ))) + } +} diff --git a/packages/libraries/sdk-rs/src/agent/mod.rs b/packages/libraries/sdk-rs/src/agent/mod.rs new file mode 100644 index 000000000..e52fa1ad0 --- /dev/null +++ b/packages/libraries/sdk-rs/src/agent/mod.rs @@ -0,0 +1,4 @@ +pub mod buffer; +pub mod builder; +pub mod usage_agent; +pub mod utils; diff --git a/packages/libraries/sdk-rs/src/agent.rs b/packages/libraries/sdk-rs/src/agent/usage_agent.rs similarity index 62% rename from packages/libraries/sdk-rs/src/agent.rs rename to packages/libraries/sdk-rs/src/agent/usage_agent.rs index 15031f63b..5cecea51c 100644 --- a/packages/libraries/sdk-rs/src/agent.rs +++ b/packages/libraries/sdk-rs/src/agent/usage_agent.rs @@ -1,16 +1,18 @@ -use super::graphql::OperationProcessor; +use async_dropper_simple::{AsyncDrop, AsyncDropper}; use graphql_parser::schema::Document; -use reqwest::header::{HeaderMap, HeaderValue}; -use reqwest_middleware::{ClientBuilder, ClientWithMiddleware}; -use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware}; +use recloser::AsyncRecloser; +use reqwest_middleware::ClientWithMiddleware; use std::{ - collections::{hash_map::Entry, HashMap, VecDeque}, - sync::{Arc, Mutex}, + collections::{hash_map::Entry, HashMap}, + sync::Arc, time::Duration, }; use thiserror::Error; use tokio_util::sync::CancellationToken; +use crate::agent::{buffer::AddStatus, utils::OperationProcessor}; +use crate::agent::{buffer::Buffer, builder::UsageAgentBuilder}; + #[derive(Debug, Clone)] pub struct ExecutionReport { pub schema: Arc>, @@ -27,44 +29,16 @@ pub struct ExecutionReport { typify::import_types!(schema = "./usage-report-v2.schema.json"); -#[derive(Debug, Default)] -pub struct Buffer(Mutex>); - -impl Buffer { - fn new() -> Self { - Self(Mutex::new(VecDeque::new())) - } - - fn lock_buffer( - &self, - ) -> Result>, AgentError> { - let buffer: Result>, AgentError> = - self.0.lock().map_err(|e| AgentError::Lock(e.to_string())); - buffer - } - - pub fn push(&self, report: ExecutionReport) -> Result { - let mut buffer = self.lock_buffer()?; - buffer.push_back(report); - Ok(buffer.len()) - } - - pub fn drain(&self) -> Result, AgentError> { - let mut buffer = self.lock_buffer()?; - let reports: Vec = buffer.drain(..).collect(); - Ok(reports) - } -} -pub struct UsageAgent { - buffer_size: usize, - endpoint: String, - buffer: Buffer, - processor: OperationProcessor, - client: ClientWithMiddleware, - flush_interval: Duration, +pub struct UsageAgentInner { + pub(crate) endpoint: String, + pub(crate) buffer: Buffer, + pub(crate) processor: OperationProcessor, + pub(crate) client: ClientWithMiddleware, + pub(crate) flush_interval: Duration, + pub(crate) circuit_breaker: AsyncRecloser, } -fn non_empty_string(value: Option) -> Option { +pub fn non_empty_string(value: Option) -> Option { value.filter(|str| !str.is_empty()) } @@ -72,81 +46,45 @@ fn non_empty_string(value: Option) -> Option { pub enum AgentError { #[error("unable to acquire lock: {0}")] Lock(String), - #[error("unable to send report: token is missing")] + #[error("unable to send report: unauthorized")] Unauthorized, #[error("unable to send report: no access")] Forbidden, #[error("unable to send report: rate limited")] RateLimited, - #[error("invalid token provided: {0}")] - InvalidToken(String), + #[error("missing token")] + MissingToken, + #[error("your access token requires providing a 'target_id' option.")] + MissingTargetId, + #[error("using 'target_id' with legacy tokens is not supported")] + TargetIdWithLegacyToken, + #[error("invalid token provided")] + InvalidToken, + #[error("invalid target id provided: {0}, it should be either a slug like \"$organizationSlug/$projectSlug/$targetSlug\" or an UUID")] + InvalidTargetId(String), #[error("unable to instantiate the http client for reports sending: {0}")] HTTPClientCreationError(reqwest::Error), + #[error("unable to create circuit breaker: {0}")] + CircuitBreakerCreationError(#[from] crate::circuit_breaker::CircuitBreakerError), + #[error("rejected by the circuit breaker")] + CircuitBreakerRejected, #[error("unable to send report: {0}")] Unknown(String), } -impl UsageAgent { - #[allow(clippy::too_many_arguments)] - pub fn try_new( - token: &str, - endpoint: String, - target_id: Option, - buffer_size: usize, - connect_timeout: Duration, - request_timeout: Duration, - accept_invalid_certs: bool, - flush_interval: Duration, - user_agent: String, - ) -> Result, AgentError> { - let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3); +pub type UsageAgent = Arc>; - let mut default_headers = HeaderMap::new(); - - default_headers.insert("X-Usage-API-Version", HeaderValue::from_static("2")); - - let mut authorization_header = HeaderValue::from_str(&format!("Bearer {}", token)) - .map_err(|_| AgentError::InvalidToken(token.to_string()))?; - - authorization_header.set_sensitive(true); - - default_headers.insert(reqwest::header::AUTHORIZATION, authorization_header); - - default_headers.insert( - reqwest::header::CONTENT_TYPE, - HeaderValue::from_static("application/json"), - ); - - let reqwest_agent = reqwest::Client::builder() - .danger_accept_invalid_certs(accept_invalid_certs) - .connect_timeout(connect_timeout) - .timeout(request_timeout) - .user_agent(user_agent) - .default_headers(default_headers) - .build() - .map_err(AgentError::HTTPClientCreationError)?; - let client = ClientBuilder::new(reqwest_agent) - .with(RetryTransientMiddleware::new_with_policy(retry_policy)) - .build(); - - let mut endpoint = endpoint; - - if token.starts_with("hvo1/") || token.starts_with("hvu1/") || token.starts_with("hvp1/") { - if let Some(target_id) = target_id { - endpoint.push_str(&format!("/{}", target_id)); - } - } - - Ok(Arc::new(Self { - buffer_size, - endpoint, - buffer: Buffer::new(), - processor: OperationProcessor::new(), - client, - flush_interval, - })) +#[async_trait::async_trait] +pub trait UsageAgentExt { + fn builder() -> UsageAgentBuilder { + UsageAgentBuilder::default() } + async fn flush(&self) -> Result<(), AgentError>; + async fn start_flush_interval(&self, token: &CancellationToken); + async fn add_report(&self, execution_report: ExecutionReport) -> Result<(), AgentError>; +} +impl UsageAgentInner { fn produce_report(&self, reports: Vec) -> Result { let mut report = Report { size: 0, @@ -233,21 +171,21 @@ impl UsageAgent { Ok(report) } - pub async fn send_report(&self, report: Report) -> Result<(), AgentError> { + async fn send_report(&self, report: Report) -> Result<(), AgentError> { if report.size == 0 { return Ok(()); } - let report_body = - serde_json::to_vec(&report).map_err(|e| AgentError::Unknown(e.to_string()))?; // Based on https://the-guild.dev/graphql/hive/docs/specs/usage-reports#data-structure + let resp_fut = self.client.post(&self.endpoint).json(&report).send(); + let resp = self - .client - .post(&self.endpoint) - .header(reqwest::header::CONTENT_LENGTH, report_body.len()) - .body(report_body) - .send() + .circuit_breaker + .call(resp_fut) .await - .map_err(|e| AgentError::Unknown(e.to_string()))?; + .map_err(|e| match e { + recloser::Error::Inner(e) => AgentError::Unknown(e.to_string()), + recloser::Error::Rejected => AgentError::CircuitBreakerRejected, + })?; match resp.status() { reqwest::StatusCode::OK => Ok(()), @@ -262,67 +200,57 @@ impl UsageAgent { } } - pub async fn flush(&self) { - let execution_reports = match self.buffer.drain() { - Ok(res) => res, - Err(e) => { - tracing::error!("Unable to acquire lock for State in drain_reports: {}", e); - Vec::new() - } - }; - let size = execution_reports.len(); - - if size > 0 { - match self.produce_report(execution_reports) { - Ok(report) => match self.send_report(report).await { - Ok(_) => tracing::debug!("Reported {} operations", size), - Err(e) => tracing::error!("{}", e), - }, - Err(e) => tracing::error!("{}", e), - } + async fn handle_drained(&self, drained: Vec) -> Result<(), AgentError> { + if drained.is_empty() { + return Ok(()); } + let report = self.produce_report(drained)?; + self.send_report(report).await } - pub async fn start_flush_interval(&self, token: Option) { - let mut tokio_interval = tokio::time::interval(self.flush_interval); - match token { - Some(token) => loop { - tokio::select! { - _ = tokio_interval.tick() => { self.flush().await; } - _ = token.cancelled() => { println!("Shutting down."); return; } - } - }, - None => loop { - tokio_interval.tick().await; - self.flush().await; - }, - } + async fn flush(&self) -> Result<(), AgentError> { + let execution_reports = self.buffer.drain().await; + + self.handle_drained(execution_reports).await?; + + Ok(()) } } -pub trait UsageAgentExt { - fn add_report(&self, execution_report: ExecutionReport) -> Result<(), AgentError>; - fn flush_if_full(&self, size: usize) -> Result<(), AgentError>; -} +#[async_trait::async_trait] +impl UsageAgentExt for UsageAgent { + async fn flush(&self) -> Result<(), AgentError> { + self.inner().flush().await + } -impl UsageAgentExt for Arc { - fn flush_if_full(&self, size: usize) -> Result<(), AgentError> { - if size >= self.buffer_size { - let cloned_self = self.clone(); - tokio::task::spawn(async move { - cloned_self.flush().await; - }); + async fn start_flush_interval(&self, token: &CancellationToken) { + loop { + tokio::time::sleep(self.inner().flush_interval).await; + if token.is_cancelled() { + println!("Shutting down."); + return; + } + self.flush() + .await + .unwrap_or_else(|e| tracing::error!("Failed to flush usage reports: {}", e)); + } + } + + async fn add_report(&self, execution_report: ExecutionReport) -> Result<(), AgentError> { + if let AddStatus::Full { drained } = self.inner().buffer.add(execution_report).await { + self.inner().handle_drained(drained).await?; } Ok(()) } +} - fn add_report(&self, execution_report: ExecutionReport) -> Result<(), AgentError> { - let size = self.buffer.push(execution_report)?; - - self.flush_if_full(size)?; - - Ok(()) +#[async_trait::async_trait] +impl AsyncDrop for UsageAgentInner { + async fn async_drop(&mut self) { + if let Err(e) = self.flush().await { + tracing::error!("Failed to flush usage reports during drop: {}", e); + } } } @@ -333,14 +261,14 @@ mod tests { use graphql_parser::{parse_query, parse_schema}; use reqwest::header::{AUTHORIZATION, CONTENT_TYPE, USER_AGENT}; - use crate::agent::{ExecutionReport, Report, UsageAgent, UsageAgentExt}; + use crate::agent::usage_agent::{ExecutionReport, Report, UsageAgent, UsageAgentExt}; const CONTENT_TYPE_VALUE: &'static str = "application/json"; const GRAPHQL_CLIENT_NAME: &'static str = "Hive Client"; const GRAPHQL_CLIENT_VERSION: &'static str = "1.0.0"; - #[tokio::test] - async fn should_send_data_to_hive() { + #[tokio::test(flavor = "multi_thread")] + async fn should_send_data_to_hive() -> Result<(), Box> { let token = "Token"; let mut server = mockito::Server::new_async().await; @@ -349,13 +277,13 @@ mod tests { let timestamp = 1625247600; let duration = Duration::from_millis(20); - let user_agent = format!("hive-router-sdk-test"); + let user_agent = "hive-router-sdk-test"; let mock = server .mock("POST", "/200") .match_header(AUTHORIZATION, format!("Bearer {}", token).as_str()) .match_header(CONTENT_TYPE, CONTENT_TYPE_VALUE) - .match_header(USER_AGENT, user_agent.as_str()) + .match_header(USER_AGENT, user_agent) .match_header("X-Usage-API-Version", "2") .match_request(move |request| { let request_body = request.body().expect("Failed to extract body"); @@ -469,8 +397,7 @@ mod tests { CUSTOM } "#, - ) - .expect("Failed to parse schema"); + )?; let op: graphql_tools::static_graphql::query::Document = parse_query( r#" @@ -493,39 +420,34 @@ mod tests { type } "#, - ) - .expect("Failed to parse query"); + )?; - let usage_agent = UsageAgent::try_new( - token, - format!("{}/200", server_url), - None, - 10, - Duration::from_millis(500), - Duration::from_millis(500), - false, - Duration::from_millis(10), - user_agent, - ) - .expect("Failed to create UsageAgent"); + // Testing async drop + { + let usage_agent = UsageAgent::builder() + .token(token.into()) + .endpoint(format!("{}/200", server_url)) + .user_agent(user_agent.into()) + .build()?; - usage_agent - .add_report(ExecutionReport { - schema: Arc::new(schema), - operation_body: op.to_string(), - operation_name: Some("deleteProject".to_string()), - client_name: Some(GRAPHQL_CLIENT_NAME.to_string()), - client_version: Some(GRAPHQL_CLIENT_VERSION.to_string()), - timestamp: timestamp.try_into().unwrap(), - duration, - ok: true, - errors: 0, - persisted_document_hash: None, - }) - .expect("Failed to add report"); - - usage_agent.flush().await; + usage_agent + .add_report(ExecutionReport { + schema: Arc::new(schema), + operation_body: op.to_string(), + operation_name: Some("deleteProject".to_string()), + client_name: Some(GRAPHQL_CLIENT_NAME.to_string()), + client_version: Some(GRAPHQL_CLIENT_VERSION.to_string()), + timestamp, + duration, + ok: true, + errors: 0, + persisted_document_hash: None, + }) + .await?; + } mock.assert_async().await; + + Ok(()) } } diff --git a/packages/libraries/sdk-rs/src/graphql.rs b/packages/libraries/sdk-rs/src/agent/utils.rs similarity index 100% rename from packages/libraries/sdk-rs/src/graphql.rs rename to packages/libraries/sdk-rs/src/agent/utils.rs diff --git a/packages/libraries/sdk-rs/src/circuit_breaker.rs b/packages/libraries/sdk-rs/src/circuit_breaker.rs new file mode 100644 index 000000000..0dbd4529c --- /dev/null +++ b/packages/libraries/sdk-rs/src/circuit_breaker.rs @@ -0,0 +1,67 @@ +use std::time::Duration; + +use recloser::{AsyncRecloser, Recloser}; + +#[derive(Clone)] +pub struct CircuitBreakerBuilder { + error_threshold: f32, + volume_threshold: usize, + reset_timeout: Duration, +} + +impl Default for CircuitBreakerBuilder { + fn default() -> Self { + Self { + error_threshold: 0.5, + volume_threshold: 5, + reset_timeout: Duration::from_secs(30), + } + } +} + +#[derive(Debug, thiserror::Error)] +pub enum CircuitBreakerError { + #[error("Invalid error threshold: {0}. It must be between 0.0 and 1.0")] + InvalidErrorThreshold(f32), +} + +impl CircuitBreakerBuilder { + /// Percentage after what the circuit breaker should kick in. + /// Default: .5 + pub fn error_threshold(mut self, percentage: f32) -> Self { + self.error_threshold = percentage; + self + } + /// Count of requests before starting evaluating. + /// Default: 5 + pub fn volume_threshold(mut self, threshold: usize) -> Self { + self.volume_threshold = threshold; + self + } + /// After what time the circuit breaker is attempting to retry sending requests in milliseconds. + /// Default: 30s + pub fn reset_timeout(mut self, timeout: Duration) -> Self { + self.reset_timeout = timeout; + self + } + + pub fn build_async(self) -> Result { + let recloser = self.build_sync()?; + Ok(AsyncRecloser::from(recloser)) + } + pub fn build_sync(self) -> Result { + let error_threshold = if self.error_threshold < 0.0 || self.error_threshold > 1.0 { + return Err(CircuitBreakerError::InvalidErrorThreshold( + self.error_threshold, + )); + } else { + self.error_threshold + }; + let recloser = Recloser::custom() + .error_rate(error_threshold) + .closed_len(self.volume_threshold) + .open_wait(self.reset_timeout) + .build(); + Ok(recloser) + } +} diff --git a/packages/libraries/sdk-rs/src/lib.rs b/packages/libraries/sdk-rs/src/lib.rs index ec6f97886..0201c9cc2 100644 --- a/packages/libraries/sdk-rs/src/lib.rs +++ b/packages/libraries/sdk-rs/src/lib.rs @@ -1,4 +1,4 @@ pub mod agent; -pub mod graphql; +pub mod circuit_breaker; pub mod persisted_documents; pub mod supergraph_fetcher; diff --git a/packages/libraries/sdk-rs/src/persisted_documents.rs b/packages/libraries/sdk-rs/src/persisted_documents.rs index 02ce66132..4a5aab95c 100644 --- a/packages/libraries/sdk-rs/src/persisted_documents.rs +++ b/packages/libraries/sdk-rs/src/persisted_documents.rs @@ -1,18 +1,22 @@ use std::time::Duration; +use crate::agent::usage_agent::non_empty_string; +use crate::circuit_breaker::CircuitBreakerBuilder; use moka::future::Cache; +use recloser::AsyncRecloser; use reqwest::header::HeaderMap; use reqwest::header::HeaderValue; use reqwest_middleware::ClientBuilder; use reqwest_middleware::ClientWithMiddleware; -use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware}; +use reqwest_retry::RetryTransientMiddleware; +use retry_policies::policies::ExponentialBackoff; use tracing::{debug, info, warn}; #[derive(Debug)] pub struct PersistedDocumentsManager { - agent: ClientWithMiddleware, + client: ClientWithMiddleware, cache: Cache, - endpoint: String, + endpoints_with_circuit_breakers: Vec<(String, AsyncRecloser)>, } #[derive(Debug, thiserror::Error)] @@ -31,6 +35,18 @@ pub enum PersistedDocumentsError { FailedToReadCDNResponse(reqwest::Error), #[error("No persisted document provided, or document id cannot be resolved.")] PersistedDocumentRequired, + #[error("Missing required configuration option: {0}")] + MissingConfigurationOption(String), + #[error("Invalid CDN key {0}")] + InvalidCDNKey(String), + #[error("Failed to create HTTP client: {0}")] + HTTPClientCreationError(reqwest::Error), + #[error("unable to create circuit breaker: {0}")] + CircuitBreakerCreationError(#[from] crate::circuit_breaker::CircuitBreakerError), + #[error("rejected by the circuit breaker")] + CircuitBreakerRejected, + #[error("unknown error")] + Unknown, } impl PersistedDocumentsError { @@ -51,47 +67,75 @@ impl PersistedDocumentsError { PersistedDocumentsError::PersistedDocumentRequired => { "PERSISTED_DOCUMENT_REQUIRED".into() } + PersistedDocumentsError::MissingConfigurationOption(_) => { + "MISSING_CONFIGURATION_OPTION".into() + } + PersistedDocumentsError::InvalidCDNKey(_) => "INVALID_CDN_KEY".into(), + PersistedDocumentsError::HTTPClientCreationError(_) => { + "HTTP_CLIENT_CREATION_ERROR".into() + } + PersistedDocumentsError::CircuitBreakerCreationError(_) => { + "CIRCUIT_BREAKER_CREATION_ERROR".into() + } + PersistedDocumentsError::CircuitBreakerRejected => "CIRCUIT_BREAKER_REJECTED".into(), + PersistedDocumentsError::Unknown => "UNKNOWN_ERROR".into(), } } } impl PersistedDocumentsManager { - #[allow(clippy::too_many_arguments)] - pub fn new( - key: String, - endpoint: String, - accept_invalid_certs: bool, - connect_timeout: Duration, - request_timeout: Duration, - retry_count: u32, - cache_size: u64, - user_agent: String, - ) -> Self { - let retry_policy = ExponentialBackoff::builder().build_with_max_retries(retry_count); - - let mut default_headers = HeaderMap::new(); - default_headers.insert("X-Hive-CDN-Key", HeaderValue::from_str(&key).unwrap()); - let reqwest_agent = reqwest::Client::builder() - .danger_accept_invalid_certs(accept_invalid_certs) - .connect_timeout(connect_timeout) - .timeout(request_timeout) - .user_agent(user_agent) - .default_headers(default_headers) - .build() - .expect("Failed to create reqwest client"); - let agent = ClientBuilder::new(reqwest_agent) - .with(RetryTransientMiddleware::new_with_policy(retry_policy)) - .build(); - - let cache = Cache::::new(cache_size); - - Self { - agent, - cache, - endpoint, - } + pub fn builder() -> PersistedDocumentsManagerBuilder { + PersistedDocumentsManagerBuilder::default() } + async fn resolve_from_endpoint( + &self, + endpoint: &str, + document_id: &str, + circuit_breaker: &AsyncRecloser, + ) -> Result { + let cdn_document_id = str::replace(document_id, "~", "/"); + let cdn_artifact_url = format!("{}/apps/{}", endpoint, cdn_document_id); + info!( + "Fetching document {} from CDN: {}", + document_id, cdn_artifact_url + ); + let response_fut = self.client.get(cdn_artifact_url).send(); + let response = circuit_breaker + .call(response_fut) + .await + .map_err(|e| match e { + recloser::Error::Inner(e) => PersistedDocumentsError::FailedToFetchFromCDN(e), + recloser::Error::Rejected => PersistedDocumentsError::CircuitBreakerRejected, + })?; + + if response.status().is_success() { + let document = response + .text() + .await + .map_err(PersistedDocumentsError::FailedToReadCDNResponse)?; + debug!( + "Document fetched from CDN: {}, storing in local cache", + document + ); + self.cache + .insert(document_id.into(), document.clone()) + .await; + + return Ok(document); + } + + warn!( + "Document fetch from CDN failed: HTTP {}, Body: {:?}", + response.status(), + response + .text() + .await + .unwrap_or_else(|_| "Unavailable".to_string()) + ); + + Err(PersistedDocumentsError::DocumentNotFound) + } /// Resolves the document from the cache, or from the CDN pub async fn resolve_document( &self, @@ -110,50 +154,173 @@ impl PersistedDocumentsManager { "Document {} not found in cache. Fetching from CDN", document_id ); - let cdn_document_id = str::replace(document_id, "~", "/"); - let cdn_artifact_url = format!("{}/apps/{}", &self.endpoint, cdn_document_id); - info!( - "Fetching document {} from CDN: {}", - document_id, cdn_artifact_url - ); - let cdn_response = self.agent.get(cdn_artifact_url).send().await; - - match cdn_response { - Ok(response) => { - if response.status().is_success() { - let document = response - .text() - .await - .map_err(PersistedDocumentsError::FailedToReadCDNResponse)?; - debug!( - "Document fetched from CDN: {}, storing in local cache", - document - ); - self.cache - .insert(document_id.into(), document.clone()) - .await; - - return Ok(document); + let mut last_error: Option = None; + for (endpoint, circuit_breaker) in &self.endpoints_with_circuit_breakers { + let result = self + .resolve_from_endpoint(endpoint, document_id, circuit_breaker) + .await; + match result { + Ok(document) => return Ok(document), + Err(e) => { + last_error = Some(e); } - - warn!( - "Document fetch from CDN failed: HTTP {}, Body: {:?}", - response.status(), - response - .text() - .await - .unwrap_or_else(|_| "Unavailable".to_string()) - ); - - Err(PersistedDocumentsError::DocumentNotFound) - } - Err(e) => { - warn!("Failed to fetch document from CDN: {:?}", e); - - Err(PersistedDocumentsError::FailedToFetchFromCDN(e)) } } + match last_error { + Some(e) => Err(e), + None => Err(PersistedDocumentsError::Unknown), + } } } } } + +pub struct PersistedDocumentsManagerBuilder { + key: Option, + endpoints: Vec, + accept_invalid_certs: bool, + connect_timeout: Duration, + request_timeout: Duration, + retry_policy: ExponentialBackoff, + cache_size: u64, + user_agent: Option, + circuit_breaker: CircuitBreakerBuilder, +} + +impl Default for PersistedDocumentsManagerBuilder { + fn default() -> Self { + Self { + key: None, + endpoints: vec![], + accept_invalid_certs: false, + connect_timeout: Duration::from_secs(5), + request_timeout: Duration::from_secs(15), + retry_policy: ExponentialBackoff::builder().build_with_max_retries(3), + cache_size: 10_000, + user_agent: None, + circuit_breaker: CircuitBreakerBuilder::default(), + } + } +} + +impl PersistedDocumentsManagerBuilder { + /// The CDN Access Token with from the Hive Console target. + pub fn key(mut self, key: String) -> Self { + self.key = non_empty_string(Some(key)); + self + } + + /// The CDN endpoint from Hive Console target. + pub fn add_endpoint(mut self, endpoint: String) -> Self { + if let Some(endpoint) = non_empty_string(Some(endpoint)) { + self.endpoints.push(endpoint); + } + self + } + + /// Accept invalid SSL certificates + /// default: false + pub fn accept_invalid_certs(mut self, accept_invalid_certs: bool) -> Self { + self.accept_invalid_certs = accept_invalid_certs; + self + } + + /// Connection timeout for the Hive Console CDN requests. + /// Default: 5 seconds + pub fn connect_timeout(mut self, connect_timeout: Duration) -> Self { + self.connect_timeout = connect_timeout; + self + } + + /// Request timeout for the Hive Console CDN requests. + /// Default: 15 seconds + pub fn request_timeout(mut self, request_timeout: Duration) -> Self { + self.request_timeout = request_timeout; + self + } + + /// Retry policy for fetching persisted documents + /// Default: ExponentialBackoff with max 3 retries + pub fn retry_policy(mut self, retry_policy: ExponentialBackoff) -> Self { + self.retry_policy = retry_policy; + self + } + + /// Maximum number of retries for fetching persisted documents + /// Default: ExponentialBackoff with max 3 retries + pub fn max_retries(mut self, max_retries: u32) -> Self { + self.retry_policy = ExponentialBackoff::builder().build_with_max_retries(max_retries); + self + } + + /// Size of the in-memory cache for persisted documents + /// Default: 10,000 entries + pub fn cache_size(mut self, cache_size: u64) -> Self { + self.cache_size = cache_size; + self + } + + /// User-Agent header to be sent with each request + pub fn user_agent(mut self, user_agent: String) -> Self { + self.user_agent = non_empty_string(Some(user_agent)); + self + } + + pub fn build(self) -> Result { + let mut default_headers = HeaderMap::new(); + let key = match self.key { + Some(key) => key, + None => { + return Err(PersistedDocumentsError::MissingConfigurationOption( + "key".to_string(), + )); + } + }; + default_headers.insert( + "X-Hive-CDN-Key", + HeaderValue::from_str(&key) + .map_err(|e| PersistedDocumentsError::InvalidCDNKey(e.to_string()))?, + ); + let mut reqwest_agent = reqwest::Client::builder() + .danger_accept_invalid_certs(self.accept_invalid_certs) + .connect_timeout(self.connect_timeout) + .timeout(self.request_timeout) + .default_headers(default_headers); + + if let Some(user_agent) = self.user_agent { + reqwest_agent = reqwest_agent.user_agent(user_agent); + } + + let reqwest_agent = reqwest_agent + .build() + .map_err(PersistedDocumentsError::HTTPClientCreationError)?; + let client = ClientBuilder::new(reqwest_agent) + .with(RetryTransientMiddleware::new_with_policy(self.retry_policy)) + .build(); + + let cache = Cache::::new(self.cache_size); + + if self.endpoints.is_empty() { + return Err(PersistedDocumentsError::MissingConfigurationOption( + "endpoints".to_string(), + )); + } + + Ok(PersistedDocumentsManager { + client, + cache, + endpoints_with_circuit_breakers: self + .endpoints + .into_iter() + .map(move |endpoint| { + let circuit_breaker = self + .circuit_breaker + .clone() + .build_async() + .map_err(PersistedDocumentsError::CircuitBreakerCreationError)?; + Ok((endpoint, circuit_breaker)) + }) + .collect::, PersistedDocumentsError>>()?, + }) + } +} diff --git a/packages/libraries/sdk-rs/src/supergraph_fetcher.rs b/packages/libraries/sdk-rs/src/supergraph_fetcher.rs deleted file mode 100644 index 98c2540ee..000000000 --- a/packages/libraries/sdk-rs/src/supergraph_fetcher.rs +++ /dev/null @@ -1,260 +0,0 @@ -use std::fmt::Display; -use std::sync::RwLock; -use std::time::Duration; -use std::time::SystemTime; - -use reqwest::header::HeaderMap; -use reqwest::header::HeaderValue; -use reqwest::header::InvalidHeaderValue; -use reqwest::header::IF_NONE_MATCH; -use reqwest_middleware::ClientBuilder; -use reqwest_middleware::ClientWithMiddleware; -use reqwest_retry::policies::ExponentialBackoff; -use reqwest_retry::RetryDecision; -use reqwest_retry::RetryPolicy; -use reqwest_retry::RetryTransientMiddleware; - -#[derive(Debug)] -pub struct SupergraphFetcher { - client: SupergraphFetcherAsyncOrSyncClient, - endpoint: String, - etag: RwLock>, - state: std::marker::PhantomData, -} - -#[derive(Debug)] -pub struct SupergraphFetcherAsyncState; -#[derive(Debug)] -pub struct SupergraphFetcherSyncState; - -#[derive(Debug)] -enum SupergraphFetcherAsyncOrSyncClient { - Async { - reqwest_client: ClientWithMiddleware, - }, - Sync { - reqwest_client: reqwest::blocking::Client, - retry_policy: ExponentialBackoff, - }, -} - -pub enum SupergraphFetcherError { - FetcherCreationError(reqwest::Error), - NetworkError(reqwest_middleware::Error), - NetworkResponseError(reqwest::Error), - Lock(String), - InvalidKey(InvalidHeaderValue), -} - -impl Display for SupergraphFetcherError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - SupergraphFetcherError::FetcherCreationError(e) => { - write!(f, "Creating fetcher failed: {}", e) - } - SupergraphFetcherError::NetworkError(e) => write!(f, "Network error: {}", e), - SupergraphFetcherError::NetworkResponseError(e) => { - write!(f, "Network response error: {}", e) - } - SupergraphFetcherError::Lock(e) => write!(f, "Lock error: {}", e), - SupergraphFetcherError::InvalidKey(e) => write!(f, "Invalid CDN key: {}", e), - } - } -} - -fn prepare_client_config( - mut endpoint: String, - key: &str, - retry_count: u32, -) -> Result<(String, HeaderMap, ExponentialBackoff), SupergraphFetcherError> { - if !endpoint.ends_with("/supergraph") { - if endpoint.ends_with("/") { - endpoint.push_str("supergraph"); - } else { - endpoint.push_str("/supergraph"); - } - } - - let mut headers = HeaderMap::new(); - let mut cdn_key_header = - HeaderValue::from_str(key).map_err(SupergraphFetcherError::InvalidKey)?; - cdn_key_header.set_sensitive(true); - headers.insert("X-Hive-CDN-Key", cdn_key_header); - - let retry_policy = ExponentialBackoff::builder().build_with_max_retries(retry_count); - - Ok((endpoint, headers, retry_policy)) -} - -impl SupergraphFetcher { - #[allow(clippy::too_many_arguments)] - pub fn try_new_sync( - endpoint: String, - key: &str, - user_agent: String, - connect_timeout: Duration, - request_timeout: Duration, - accept_invalid_certs: bool, - retry_count: u32, - ) -> Result { - let (endpoint, headers, retry_policy) = prepare_client_config(endpoint, key, retry_count)?; - - Ok(Self { - client: SupergraphFetcherAsyncOrSyncClient::Sync { - reqwest_client: reqwest::blocking::Client::builder() - .danger_accept_invalid_certs(accept_invalid_certs) - .connect_timeout(connect_timeout) - .timeout(request_timeout) - .user_agent(user_agent) - .default_headers(headers) - .build() - .map_err(SupergraphFetcherError::FetcherCreationError)?, - retry_policy, - }, - endpoint, - etag: RwLock::new(None), - state: std::marker::PhantomData, - }) - } - - pub fn fetch_supergraph(&self) -> Result, SupergraphFetcherError> { - let request_start_time = SystemTime::now(); - // Implementing retry logic for sync client - let mut n_past_retries = 0; - let (reqwest_client, retry_policy) = match &self.client { - SupergraphFetcherAsyncOrSyncClient::Sync { - reqwest_client, - retry_policy, - } => (reqwest_client, retry_policy), - _ => unreachable!(), - }; - let resp = loop { - let mut req = reqwest_client.get(&self.endpoint); - let etag = self.get_latest_etag()?; - if let Some(etag) = etag { - req = req.header(IF_NONE_MATCH, etag); - } - let response = req.send(); - - match response { - Ok(resp) => break resp, - Err(e) => match retry_policy.should_retry(request_start_time, n_past_retries) { - RetryDecision::DoNotRetry => { - return Err(SupergraphFetcherError::NetworkError( - reqwest_middleware::Error::Reqwest(e), - )); - } - RetryDecision::Retry { execute_after } => { - n_past_retries += 1; - if let Ok(duration) = execute_after.elapsed() { - std::thread::sleep(duration); - } - } - }, - } - }; - - if resp.status().as_u16() == 304 { - return Ok(None); - } - - let etag = resp.headers().get("etag"); - self.update_latest_etag(etag)?; - - let text = resp - .text() - .map_err(SupergraphFetcherError::NetworkResponseError)?; - - Ok(Some(text)) - } -} - -impl SupergraphFetcher { - #[allow(clippy::too_many_arguments)] - pub fn try_new_async( - endpoint: String, - key: &str, - user_agent: String, - connect_timeout: Duration, - request_timeout: Duration, - accept_invalid_certs: bool, - retry_count: u32, - ) -> Result { - let (endpoint, headers, retry_policy) = prepare_client_config(endpoint, key, retry_count)?; - - let reqwest_agent = reqwest::Client::builder() - .danger_accept_invalid_certs(accept_invalid_certs) - .connect_timeout(connect_timeout) - .timeout(request_timeout) - .default_headers(headers) - .user_agent(user_agent) - .build() - .map_err(SupergraphFetcherError::FetcherCreationError)?; - let reqwest_client = ClientBuilder::new(reqwest_agent) - .with(RetryTransientMiddleware::new_with_policy(retry_policy)) - .build(); - - Ok(Self { - client: SupergraphFetcherAsyncOrSyncClient::Async { reqwest_client }, - endpoint, - etag: RwLock::new(None), - state: std::marker::PhantomData, - }) - } - pub async fn fetch_supergraph(&self) -> Result, SupergraphFetcherError> { - let reqwest_client = match &self.client { - SupergraphFetcherAsyncOrSyncClient::Async { reqwest_client } => reqwest_client, - _ => unreachable!(), - }; - let mut req = reqwest_client.get(&self.endpoint); - let etag = self.get_latest_etag()?; - if let Some(etag) = etag { - req = req.header(IF_NONE_MATCH, etag); - } - - let resp = req - .send() - .await - .map_err(SupergraphFetcherError::NetworkError)?; - - if resp.status().as_u16() == 304 { - return Ok(None); - } - - let etag = resp.headers().get("etag"); - self.update_latest_etag(etag)?; - - let text = resp - .text() - .await - .map_err(SupergraphFetcherError::NetworkResponseError)?; - - Ok(Some(text)) - } -} - -impl SupergraphFetcher { - fn get_latest_etag(&self) -> Result, SupergraphFetcherError> { - let guard: std::sync::RwLockReadGuard<'_, Option> = - self.etag.try_read().map_err(|e| { - SupergraphFetcherError::Lock(format!("Failed to read the etag record: {:?}", e)) - })?; - - Ok(guard.clone()) - } - - fn update_latest_etag(&self, etag: Option<&HeaderValue>) -> Result<(), SupergraphFetcherError> { - let mut guard: std::sync::RwLockWriteGuard<'_, Option> = - self.etag.try_write().map_err(|e| { - SupergraphFetcherError::Lock(format!("Failed to update the etag record: {:?}", e)) - })?; - - if let Some(etag_value) = etag { - *guard = Some(etag_value.clone()); - } else { - *guard = None; - } - - Ok(()) - } -} diff --git a/packages/libraries/sdk-rs/src/supergraph_fetcher/async_.rs b/packages/libraries/sdk-rs/src/supergraph_fetcher/async_.rs new file mode 100644 index 000000000..b0bb1eddb --- /dev/null +++ b/packages/libraries/sdk-rs/src/supergraph_fetcher/async_.rs @@ -0,0 +1,141 @@ +use futures_util::TryFutureExt; +use recloser::AsyncRecloser; +use reqwest::header::{HeaderValue, IF_NONE_MATCH}; +use reqwest_middleware::{ClientBuilder, ClientWithMiddleware}; +use reqwest_retry::RetryTransientMiddleware; +use tokio::sync::RwLock; + +use crate::supergraph_fetcher::{ + builder::SupergraphFetcherBuilder, SupergraphFetcher, SupergraphFetcherError, +}; + +#[derive(Debug)] +pub struct SupergraphFetcherAsyncState { + endpoints_with_circuit_breakers: Vec<(String, AsyncRecloser)>, + reqwest_client: ClientWithMiddleware, +} + +impl SupergraphFetcher { + pub async fn fetch_supergraph(&self) -> Result, SupergraphFetcherError> { + let mut last_error: Option = None; + let mut last_resp = None; + for (endpoint, circuit_breaker) in &self.state.endpoints_with_circuit_breakers { + let mut req = self.state.reqwest_client.get(endpoint); + let etag = self.get_latest_etag().await; + if let Some(etag) = etag { + req = req.header(IF_NONE_MATCH, etag); + } + let resp_fut = async { + let mut resp = req.send().await.map_err(SupergraphFetcherError::Network); + // Server errors (5xx) are considered errors + if let Ok(ok_res) = resp { + resp = if ok_res.status().is_server_error() { + return Err(SupergraphFetcherError::Network( + reqwest_middleware::Error::Middleware(anyhow::anyhow!( + "Server error: {}", + ok_res.status() + )), + )); + } else { + Ok(ok_res) + } + } + resp + }; + let resp = circuit_breaker + .call(resp_fut) + // Map recloser errors to SupergraphFetcherError + .map_err(|e| match e { + recloser::Error::Inner(e) => e, + recloser::Error::Rejected => SupergraphFetcherError::RejectedByCircuitBreaker, + }) + .await; + match resp { + Err(err) => { + last_error = Some(err); + continue; + } + Ok(resp) => { + last_resp = Some(resp); + break; + } + } + } + + if let Some(last_resp) = last_resp { + let etag = last_resp.headers().get("etag"); + self.update_latest_etag(etag).await; + let text = last_resp + .text() + .await + .map_err(SupergraphFetcherError::ResponseParse)?; + Ok(Some(text)) + } else if let Some(error) = last_error { + Err(error) + } else { + Ok(None) + } + } + async fn get_latest_etag(&self) -> Option { + let guard = self.etag.read().await; + + guard.clone() + } + async fn update_latest_etag(&self, etag: Option<&HeaderValue>) -> () { + let mut guard = self.etag.write().await; + + if let Some(etag_value) = etag { + *guard = Some(etag_value.clone()); + } else { + *guard = None; + } + } +} + +impl SupergraphFetcherBuilder { + /// Builds an asynchronous SupergraphFetcher + pub fn build_async( + self, + ) -> Result, SupergraphFetcherError> { + self.validate_endpoints()?; + + let headers = self.prepare_headers()?; + + let mut reqwest_agent = reqwest::Client::builder() + .danger_accept_invalid_certs(self.accept_invalid_certs) + .connect_timeout(self.connect_timeout) + .timeout(self.request_timeout) + .default_headers(headers); + + if let Some(user_agent) = self.user_agent { + reqwest_agent = reqwest_agent.user_agent(user_agent); + } + + let reqwest_agent = reqwest_agent + .build() + .map_err(SupergraphFetcherError::HTTPClientCreation)?; + let reqwest_client = ClientBuilder::new(reqwest_agent) + .with(RetryTransientMiddleware::new_with_policy(self.retry_policy)) + .build(); + + Ok(SupergraphFetcher { + state: SupergraphFetcherAsyncState { + reqwest_client, + endpoints_with_circuit_breakers: self + .endpoints + .into_iter() + .map(|endpoint| { + let circuit_breaker = self + .circuit_breaker + .clone() + .unwrap_or_default() + .build_async() + .map_err(SupergraphFetcherError::CircuitBreakerCreation); + circuit_breaker.map(|cb| (endpoint, cb)) + }) + .collect::, _>>()?, + }, + etag: RwLock::new(None), + }) + } +} diff --git a/packages/libraries/sdk-rs/src/supergraph_fetcher/builder.rs b/packages/libraries/sdk-rs/src/supergraph_fetcher/builder.rs new file mode 100644 index 000000000..adddc0112 --- /dev/null +++ b/packages/libraries/sdk-rs/src/supergraph_fetcher/builder.rs @@ -0,0 +1,135 @@ +use std::time::Duration; + +use reqwest::header::{HeaderMap, HeaderValue}; +use retry_policies::policies::ExponentialBackoff; + +use crate::{ + agent::usage_agent::non_empty_string, circuit_breaker::CircuitBreakerBuilder, + supergraph_fetcher::SupergraphFetcherError, +}; + +pub struct SupergraphFetcherBuilder { + pub(crate) endpoints: Vec, + pub(crate) key: Option, + pub(crate) user_agent: Option, + pub(crate) connect_timeout: Duration, + pub(crate) request_timeout: Duration, + pub(crate) accept_invalid_certs: bool, + pub(crate) retry_policy: ExponentialBackoff, + pub(crate) circuit_breaker: Option, +} + +impl Default for SupergraphFetcherBuilder { + fn default() -> Self { + Self { + endpoints: vec![], + key: None, + user_agent: None, + connect_timeout: Duration::from_secs(5), + request_timeout: Duration::from_secs(60), + accept_invalid_certs: false, + retry_policy: ExponentialBackoff::builder().build_with_max_retries(3), + circuit_breaker: None, + } + } +} + +impl SupergraphFetcherBuilder { + pub fn new() -> Self { + Self::default() + } + + /// The CDN endpoint from Hive Console target. + pub fn add_endpoint(mut self, endpoint: String) -> Self { + if let Some(mut endpoint) = non_empty_string(Some(endpoint)) { + if !endpoint.ends_with("/supergraph") { + if endpoint.ends_with("/") { + endpoint.push_str("supergraph"); + } else { + endpoint.push_str("/supergraph"); + } + } + self.endpoints.push(endpoint); + } + self + } + + /// The CDN Access Token with from the Hive Console target. + pub fn key(mut self, key: String) -> Self { + self.key = Some(key); + self + } + + /// User-Agent header to be sent with each request + pub fn user_agent(mut self, user_agent: String) -> Self { + self.user_agent = Some(user_agent); + self + } + + /// Connection timeout for the Hive Console CDN requests. + /// Default: 5 seconds + pub fn connect_timeout(mut self, timeout: Duration) -> Self { + self.connect_timeout = timeout; + self + } + + /// Request timeout for the Hive Console CDN requests. + /// Default: 60 seconds + pub fn request_timeout(mut self, timeout: Duration) -> Self { + self.request_timeout = timeout; + self + } + + pub fn accept_invalid_certs(mut self, accept: bool) -> Self { + self.accept_invalid_certs = accept; + self + } + + /// Policy for retrying failed requests. + /// + /// By default, an exponential backoff retry policy is used, with 10 attempts. + pub fn retry_policy(mut self, retry_policy: ExponentialBackoff) -> Self { + self.retry_policy = retry_policy; + self + } + + /// Maximum number of retries for failed requests. + /// + /// By default, an exponential backoff retry policy is used, with 10 attempts. + pub fn max_retries(mut self, max_retries: u32) -> Self { + self.retry_policy = ExponentialBackoff::builder().build_with_max_retries(max_retries); + self + } + + pub fn circuit_breaker(&mut self, builder: CircuitBreakerBuilder) -> &mut Self { + self.circuit_breaker = Some(builder); + self + } + + pub(crate) fn validate_endpoints(&self) -> Result<(), SupergraphFetcherError> { + if self.endpoints.is_empty() { + return Err(SupergraphFetcherError::MissingConfigurationOption( + "endpoint".to_string(), + )); + } + Ok(()) + } + + pub(crate) fn prepare_headers(&self) -> Result { + let key = match &self.key { + Some(key) => key, + None => { + return Err(SupergraphFetcherError::MissingConfigurationOption( + "key".to_string(), + )) + } + }; + let mut headers = HeaderMap::new(); + let mut cdn_key_header = + HeaderValue::from_str(key).map_err(SupergraphFetcherError::InvalidKey)?; + cdn_key_header.set_sensitive(true); + headers.insert("X-Hive-CDN-Key", cdn_key_header); + + Ok(headers) + } +} diff --git a/packages/libraries/sdk-rs/src/supergraph_fetcher/mod.rs b/packages/libraries/sdk-rs/src/supergraph_fetcher/mod.rs new file mode 100644 index 000000000..2441ea371 --- /dev/null +++ b/packages/libraries/sdk-rs/src/supergraph_fetcher/mod.rs @@ -0,0 +1,51 @@ +use tokio::sync::RwLock; +use tokio::sync::TryLockError; + +use crate::circuit_breaker::CircuitBreakerError; +use crate::supergraph_fetcher::async_::SupergraphFetcherAsyncState; +use reqwest::header::HeaderValue; +use reqwest::header::InvalidHeaderValue; + +pub mod async_; +pub mod builder; +pub mod sync; + +#[derive(Debug)] +pub struct SupergraphFetcher { + state: State, + etag: RwLock>, +} + +// Doesn't matter which one we implement this for, both have the same builder +impl SupergraphFetcher { + pub fn builder() -> builder::SupergraphFetcherBuilder { + builder::SupergraphFetcherBuilder::default() + } +} + +pub enum LockErrorType { + Read, + Write, +} + +#[derive(Debug, thiserror::Error)] +pub enum SupergraphFetcherError { + #[error("Creating HTTP Client failed: {0}")] + HTTPClientCreation(reqwest::Error), + #[error("Network error: {0}")] + Network(reqwest_middleware::Error), + #[error("Parsing response failed: {0}")] + ResponseParse(reqwest::Error), + #[error("Reading the etag record failed: {0:?}")] + ETagRead(TryLockError), + #[error("Updating the etag record failed: {0:?}")] + ETagWrite(TryLockError), + #[error("Invalid CDN key: {0}")] + InvalidKey(InvalidHeaderValue), + #[error("Missing configuration option: {0}")] + MissingConfigurationOption(String), + #[error("Request rejected by circuit breaker")] + RejectedByCircuitBreaker, + #[error("Creating circuit breaker failed: {0}")] + CircuitBreakerCreation(CircuitBreakerError), +} diff --git a/packages/libraries/sdk-rs/src/supergraph_fetcher/sync.rs b/packages/libraries/sdk-rs/src/supergraph_fetcher/sync.rs new file mode 100644 index 000000000..f85aa58fe --- /dev/null +++ b/packages/libraries/sdk-rs/src/supergraph_fetcher/sync.rs @@ -0,0 +1,191 @@ +use std::time::SystemTime; + +use recloser::Recloser; +use reqwest::header::{HeaderValue, IF_NONE_MATCH}; +use reqwest_retry::{RetryDecision, RetryPolicy}; +use retry_policies::policies::ExponentialBackoff; +use tokio::sync::RwLock; + +use crate::supergraph_fetcher::{ + builder::SupergraphFetcherBuilder, SupergraphFetcher, SupergraphFetcherError, +}; + +#[derive(Debug)] +pub struct SupergraphFetcherSyncState { + endpoints_with_circuit_breakers: Vec<(String, Recloser)>, + reqwest_client: reqwest::blocking::Client, + retry_policy: ExponentialBackoff, +} + +impl SupergraphFetcher { + pub fn fetch_supergraph(&self) -> Result, SupergraphFetcherError> { + let mut last_error: Option = None; + let mut last_resp = None; + for (endpoint, circuit_breaker) in &self.state.endpoints_with_circuit_breakers { + let resp = { + circuit_breaker + .call(|| { + let request_start_time = SystemTime::now(); + // Implementing retry logic for sync client + let mut n_past_retries = 0; + loop { + let mut req = self.state.reqwest_client.get(endpoint); + let etag = self.get_latest_etag()?; + if let Some(etag) = etag { + req = req.header(IF_NONE_MATCH, etag); + } + let mut response = req.send().map_err(|err| { + SupergraphFetcherError::Network(reqwest_middleware::Error::Reqwest( + err, + )) + }); + + // Server errors (5xx) are considered retryable + if let Ok(ok_res) = response { + response = if ok_res.status().is_server_error() { + Err(SupergraphFetcherError::Network( + reqwest_middleware::Error::Middleware(anyhow::anyhow!( + "Server error: {}", + ok_res.status() + )), + )) + } else { + Ok(ok_res) + } + } + + match response { + Ok(resp) => break Ok(resp), + Err(e) => { + match self + .state + .retry_policy + .should_retry(request_start_time, n_past_retries) + { + RetryDecision::DoNotRetry => { + return Err(e); + } + RetryDecision::Retry { execute_after } => { + n_past_retries += 1; + match execute_after.elapsed() { + Ok(duration) => { + std::thread::sleep(duration); + } + Err(err) => { + tracing::error!( + "Error determining sleep duration for retry: {}", + err + ); + // If elapsed time cannot be determined, do not wait + return Err(e); + } + } + } + } + } + } + } + }) + // Map recloser errors to SupergraphFetcherError + .map_err(|e| match e { + recloser::Error::Inner(e) => e, + recloser::Error::Rejected => { + SupergraphFetcherError::RejectedByCircuitBreaker + } + }) + }; + match resp { + Err(e) => { + last_error = Some(e); + continue; + } + Ok(resp) => { + last_resp = Some(resp); + break; + } + } + } + + if let Some(last_resp) = last_resp { + if last_resp.status().as_u16() == 304 { + return Ok(None); + } + self.update_latest_etag(last_resp.headers().get("etag"))?; + let text = last_resp + .text() + .map_err(SupergraphFetcherError::ResponseParse)?; + Ok(Some(text)) + } else if let Some(error) = last_error { + Err(error) + } else { + Ok(None) + } + } + fn get_latest_etag(&self) -> Result, SupergraphFetcherError> { + let guard = self + .etag + .try_read() + .map_err(SupergraphFetcherError::ETagRead)?; + + Ok(guard.clone()) + } + fn update_latest_etag(&self, etag: Option<&HeaderValue>) -> Result<(), SupergraphFetcherError> { + let mut guard = self + .etag + .try_write() + .map_err(SupergraphFetcherError::ETagWrite)?; + + if let Some(etag_value) = etag { + *guard = Some(etag_value.clone()); + } else { + *guard = None; + } + + Ok(()) + } +} + +impl SupergraphFetcherBuilder { + /// Builds a synchronous SupergraphFetcher + pub fn build_sync( + self, + ) -> Result, SupergraphFetcherError> { + self.validate_endpoints()?; + let headers = self.prepare_headers()?; + + let mut reqwest_client = reqwest::blocking::Client::builder() + .danger_accept_invalid_certs(self.accept_invalid_certs) + .connect_timeout(self.connect_timeout) + .timeout(self.request_timeout) + .default_headers(headers); + + if let Some(user_agent) = &self.user_agent { + reqwest_client = reqwest_client.user_agent(user_agent); + } + + let reqwest_client = reqwest_client + .build() + .map_err(SupergraphFetcherError::HTTPClientCreation)?; + let fetcher = SupergraphFetcher { + state: SupergraphFetcherSyncState { + reqwest_client, + retry_policy: self.retry_policy, + endpoints_with_circuit_breakers: self + .endpoints + .into_iter() + .map(|endpoint| { + let circuit_breaker = self + .circuit_breaker + .clone() + .unwrap_or_default() + .build_sync() + .map_err(SupergraphFetcherError::CircuitBreakerCreation); + circuit_breaker.map(|cb| (endpoint, cb)) + }) + .collect::, _>>()?, + }, + etag: RwLock::new(None), + }; + Ok(fetcher) + } +}