feat(sdk-rs): builder pattern, multiple endpoints & circuit breaker (#7379)

This commit is contained in:
Arda TANRIKULU 2026-01-12 04:30:53 -05:00 committed by GitHub
parent b90f215213
commit b13446109d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
25 changed files with 1619 additions and 750 deletions

View file

@ -0,0 +1,56 @@
---
'hive-console-sdk-rs': minor
---
Breaking changes now, in order to avoid future breaking changes:
Switch to [Builder](https://rust-unofficial.github.io/patterns/patterns/creational/builder.html) pattern for `SupergraphFetcher`, `PersistedDocumentsManager` and `UsageAgent` structs.
No more `try_new` or `try_new_async` or `try_new_sync` functions, instead use `SupergraphFetcherBuilder`, `PersistedDocumentsManagerBuilder` and `UsageAgentBuilder` structs to create instances.
Benefits:
- No need to provide all parameters at once when creating an instance even for default values.
Example:
```rust
// Before
let fetcher = SupergraphFetcher::try_new_async(
"SOME_ENDPOINT", // endpoint
"SOME_KEY",
"MyUserAgent/1.0".to_string(),
Duration::from_secs(5), // connect_timeout
Duration::from_secs(10), // request_timeout
false, // accept_invalid_certs
3, // retry_count
)?;
// After
// No need to provide all parameters at once, can use default values
let fetcher = SupergraphFetcherBuilder::new()
.endpoint("SOME_ENDPOINT".to_string())
.key("SOME_KEY".to_string())
.build_async()?;
```
- Easier to add new configuration options in the future without breaking existing code.
Example:
```rust
let fetcher = SupergraphFetcher::try_new_async(
"SOME_ENDPOINT", // endpoint
"SOME_KEY",
"MyUserAgent/1.0".to_string(),
Duration::from_secs(5), // connect_timeout
Duration::from_secs(10), // request_timeout
false, // accept_invalid_certs
3, // retry_count
circuit_breaker_config, // Breaking Change -> new parameter added
)?;
let fetcher = SupergraphFetcherBuilder::new()
.endpoint("SOME_ENDPOINT".to_string())
.key("SOME_KEY".to_string())
.build_async()?; // No breaking change, circuit_breaker_config can be added later if needed
```

View file

@ -0,0 +1,20 @@
---
'hive-console-sdk-rs': patch
---
Circuit Breaker Implementation and Multiple Endpoints Support
Implementation of Circuit Breakers in Hive Console Rust SDK, you can learn more [here](https://the-guild.dev/graphql/hive/product-updates/2025-12-04-cdn-mirror-and-circuit-breaker)
Breaking Changes:
Now `endpoint` configuration accepts multiple endpoints as an array for `SupergraphFetcherBuilder` and `PersistedDocumentsManager`.
```diff
SupergraphFetcherBuilder::default()
- .endpoint(endpoint)
+ .add_endpoint(endpoint1)
+ .add_endpoint(endpoint2)
```
This change requires updating the configuration structure to accommodate multiple endpoints.

View file

@ -0,0 +1,17 @@
---
'hive-apollo-router-plugin': major
---
- Multiple endpoints support for `HiveRegistry` and `PersistedOperationsPlugin`
Breaking Changes:
- Now there is no `endpoint` field in the configuration, it has been replaced with `endpoints`, which is an array of strings. You are not affected if you use environment variables to set the endpoint.
```diff
HiveRegistry::new(
Some(
HiveRegistryConfig {
- endpoint: String::from("CDN_ENDPOINT"),
+ endpoints: vec![String::from("CDN_ENDPOINT1"), String::from("CDN_ENDPOINT2")],
}
)
)

View file

@ -23,6 +23,7 @@ jobs:
integration:
runs-on: ubuntu-22.04
strategy:
fail-fast: false
matrix:
# Divide integration tests into 3 shards, to run them in parallel.
shardIndex: [1, 2, 3, 'apollo-router']

View file

@ -122,9 +122,9 @@ dependencies = [
[[package]]
name = "apollo-federation"
version = "2.10.0"
version = "2.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "034b556462798d1f076d865791410c088fd549c920f2a5c0d83a0611141559b4"
checksum = "6b26138e298ecff0fb0b972175a93174d7d87d74cdc646af6a7932dc435d86c8"
dependencies = [
"apollo-compiler",
"derive_more",
@ -171,9 +171,9 @@ dependencies = [
[[package]]
name = "apollo-router"
version = "2.10.0"
version = "2.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "db37c08af3b5844565e98567e92e68ae69ecd6bb4d3e674bb2a2f27d865d3680"
checksum = "4d5b4da5efb2fb56a7076ded5912eb829005d1c3c2fec35359855f689db3fc57"
dependencies = [
"addr2line",
"ahash",
@ -292,6 +292,7 @@ dependencies = [
"similar",
"static_assertions",
"strum",
"strum_macros",
"sys-info",
"sysinfo",
"thiserror 2.0.17",
@ -436,6 +437,19 @@ dependencies = [
"tokio",
]
[[package]]
name = "async-dropper-simple"
version = "0.2.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a7c4748dfe8cd3d625ec68fc424fa80c134319881185866f9e173af9e5d8add8"
dependencies = [
"async-scoped",
"async-trait",
"futures",
"rustc_version",
"tokio",
]
[[package]]
name = "async-executor"
version = "1.13.3"
@ -521,6 +535,17 @@ dependencies = [
"rustix",
]
[[package]]
name = "async-scoped"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4042078ea593edffc452eef14e99fdb2b120caa4ad9618bcdeabc4a023b98740"
dependencies = [
"futures",
"pin-project",
"tokio",
]
[[package]]
name = "async-signal"
version = "0.2.13"
@ -2596,6 +2621,7 @@ dependencies = [
"serde_json",
"sha2",
"tokio",
"tokio-util",
"tower 0.5.2",
"tracing",
]
@ -2605,17 +2631,24 @@ name = "hive-console-sdk"
version = "0.2.3"
dependencies = [
"anyhow",
"async-dropper-simple",
"async-trait",
"axum-core 0.5.5",
"futures-util",
"graphql-parser",
"graphql-tools",
"lazy_static",
"md5",
"mockito",
"moka",
"once_cell",
"recloser",
"regex-automata",
"regress",
"reqwest",
"reqwest-middleware",
"reqwest-retry",
"retry-policies",
"serde",
"serde_json",
"sha2",
@ -3531,9 +3564,9 @@ dependencies = [
[[package]]
name = "mockall"
version = "0.14.0"
version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f58d964098a5f9c6b63d0798e5372fd04708193510a7af313c22e9f29b7b620b"
checksum = "39a6bfcc6c8c7eed5ee98b9c3e33adc726054389233e201c95dab2d41a3839d2"
dependencies = [
"cfg-if",
"downcast",
@ -3545,9 +3578,9 @@ dependencies = [
[[package]]
name = "mockall_derive"
version = "0.14.0"
version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ca41ce716dda6a9be188b385aa78ee5260fc25cd3802cb2a8afdc6afbe6b6dbf"
checksum = "25ca3004c2efe9011bd4e461bd8256445052b9615405b4f7ea43fc8ca5c20898"
dependencies = [
"cfg-if",
"proc-macro2",
@ -4589,6 +4622,16 @@ dependencies = [
"getrandom 0.3.4",
]
[[package]]
name = "recloser"
version = "1.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "40ac0d06281c3556fea72cef9e5372d9ac172335be0d71c3b4f3db900483e0eb"
dependencies = [
"crossbeam-epoch",
"pin-project",
]
[[package]]
name = "redis-protocol"
version = "6.0.0"
@ -5533,9 +5576,6 @@ name = "strum"
version = "0.27.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "af23d6f6c1a224baef9d3f61e287d2761385a5b88fdab4eb4c6f11aeb54c4bcf"
dependencies = [
"strum_macros",
]
[[package]]
name = "strum_macros"

View file

@ -2,14 +2,13 @@ import { existsSync, rmSync, writeFileSync } from 'node:fs';
import { createServer } from 'node:http';
import { tmpdir } from 'node:os';
import { join } from 'node:path';
import { MaybePromise } from 'slonik/dist/src/types';
import { ProjectType } from 'testkit/gql/graphql';
import { initSeed } from 'testkit/seed';
import { getServiceHost } from 'testkit/utils';
import { execa } from '@esm2cjs/execa';
describe('Apollo Router Integration', () => {
const getBaseEndpoint = () =>
getServiceHost('server', 8082).then(v => `http://${v}/artifacts/v1/`);
const getAvailablePort = () =>
new Promise<number>(resolve => {
const server = createServer();
@ -18,12 +17,19 @@ describe('Apollo Router Integration', () => {
if (address && typeof address === 'object') {
const port = address.port;
server.close(() => resolve(port));
} else {
throw new Error('Could not get available port');
}
});
});
// Wraps a cleanup callback in an async-disposable object so it can be
// scheduled with `await using` and run automatically at scope exit.
function defer(deferFn: () => MaybePromise<void>) {
  return {
    [Symbol.asyncDispose]: async () => deferFn(),
  };
}
it('fetches the supergraph and sends usage reports', async () => {
const routerConfigPath = join(tmpdir(), `apollo-router-config-${Date.now()}.yaml`);
const endpointBaseUrl = await getBaseEndpoint();
const { createOrg } = await initSeed().createOwner();
const { createProject } = await createOrg();
const { createTargetAccessToken, createCdnAccess, target, waitForOperationsCollected } =
@ -50,6 +56,7 @@ describe('Apollo Router Integration', () => {
.then(r => r.expectNoGraphQLErrors());
expect(publishSchemaResult.schemaPublish.__typename).toBe('SchemaPublishSuccess');
const cdnAccessResult = await createCdnAccess();
const usageAddress = await getServiceHost('usage', 8081);
@ -60,6 +67,7 @@ describe('Apollo Router Integration', () => {
`Apollo Router binary not found at path: ${routerBinPath}, make sure to build it first with 'cargo build'`,
);
}
const routerPort = await getAvailablePort();
const routerConfigContent = `
supergraph:
@ -67,17 +75,22 @@ supergraph:
plugins:
hive.usage: {}
`.trim();
const routerConfigPath = join(tmpdir(), `apollo-router-config-${Date.now()}.yaml`);
writeFileSync(routerConfigPath, routerConfigContent, 'utf-8');
const cdnEndpoint = await getServiceHost('server', 8082).then(
v => `http://${v}/artifacts/v1/${target.id}`,
);
const routerProc = execa(routerBinPath, ['--dev', '--config', routerConfigPath], {
all: true,
env: {
HIVE_CDN_ENDPOINT: endpointBaseUrl + target.id,
HIVE_CDN_ENDPOINT: cdnEndpoint,
HIVE_CDN_KEY: cdnAccessResult.secretAccessToken,
HIVE_ENDPOINT: `http://${usageAddress}`,
HIVE_TOKEN: writeToken.secret,
HIVE_TARGET_ID: target.id,
},
});
let log = '';
await new Promise((resolve, reject) => {
routerProc.catch(err => {
if (!err.isCanceled) {
@ -88,17 +101,22 @@ plugins:
if (!routerProcOut) {
return reject(new Error('No stdout from Apollo Router process'));
}
let log = '';
routerProcOut.on('data', data => {
log += data.toString();
if (log.includes('GraphQL endpoint exposed at')) {
resolve(true);
}
process.stdout.write(log);
});
});
try {
const url = `http://localhost:${routerPort}/`;
await using _ = defer(() => {
rmSync(routerConfigPath);
routerProc.cancel();
});
const url = `http://localhost:${routerPort}/`;
async function sendOperation(i: number) {
const response = await fetch(url, {
method: 'POST',
headers: {
@ -107,13 +125,13 @@ plugins:
},
body: JSON.stringify({
query: `
query TestQuery {
me {
id
name
}
}
`,
query Query${i} {
me {
id
name
}
}
`,
}),
});
@ -127,10 +145,16 @@ plugins:
},
},
});
await waitForOperationsCollected(1);
} finally {
routerProc.cancel();
rmSync(routerConfigPath);
}
const cnt = 1000;
const jobs = [];
for (let i = 0; i < cnt; i++) {
if (i % 100 === 0) {
await new Promise(res => setTimeout(res, 500));
}
jobs.push(sendOperation(i));
}
await Promise.all(jobs);
await waitForOperationsCollected(cnt);
});
});

View file

@ -35,6 +35,7 @@ http = "1"
http-body-util = "0.1"
graphql-parser = "0.4.1"
rand = "0.9.0"
tokio-util = "0.7.16"
[dev-dependencies]
httpmock = "0.7.0"

View file

@ -21,6 +21,7 @@ fn main() {
register_plugins();
// Initialize the Hive Registry and start the Apollo Router
// TODO: Look at builder pattern in Executable::builder().start()
match HiveRegistry::new(None).and(apollo_router::main()) {
Ok(_) => {}
Err(e) => {

View file

@ -32,7 +32,7 @@ pub static PERSISTED_DOCUMENT_HASH_KEY: &str = "hive::persisted_document_hash";
pub struct Config {
pub enabled: Option<bool>,
/// GraphQL Hive persisted documents CDN endpoint URL.
pub endpoint: Option<String>,
pub endpoint: Option<EndpointConfig>,
/// GraphQL Hive persisted documents CDN access token.
pub key: Option<String>,
/// Whether arbitrary documents should be allowed along-side persisted documents.
@ -57,6 +57,25 @@ pub struct Config {
pub cache_size: Option<u64>,
}
/// CDN endpoint configuration: either a single URL or a list of URLs.
///
/// Deserialized untagged, so plugin config may provide a plain string or an
/// array of strings interchangeably.
#[derive(Clone, Debug, Deserialize, JsonSchema)]
#[serde(untagged)]
pub enum EndpointConfig {
    /// A single CDN endpoint URL.
    Single(String),
    /// Several CDN endpoint URLs.
    Multiple(Vec<String>),
}

impl From<&str> for EndpointConfig {
    fn from(value: &str) -> Self {
        EndpointConfig::Single(value.into())
    }
}

// Owned-string conversion, so callers holding a `String` don't need to
// re-borrow just to convert.
impl From<String> for EndpointConfig {
    fn from(value: String) -> Self {
        EndpointConfig::Single(value)
    }
}

impl From<&[&str]> for EndpointConfig {
    fn from(value: &[&str]) -> Self {
        EndpointConfig::Multiple(value.iter().map(|s| s.to_string()).collect())
    }
}

// Owned-vector conversion, mirroring the `From<String>` convenience above.
impl From<Vec<String>> for EndpointConfig {
    fn from(value: Vec<String>) -> Self {
        EndpointConfig::Multiple(value)
    }
}
pub struct PersistedDocumentsPlugin {
persisted_documents_manager: Option<Arc<PersistedDocumentsManager>>,
allow_arbitrary_documents: bool,
@ -72,11 +91,14 @@ impl PersistedDocumentsPlugin {
allow_arbitrary_documents,
});
}
let endpoint = match &config.endpoint {
Some(ep) => ep.clone(),
let endpoints = match &config.endpoint {
Some(ep) => match ep {
EndpointConfig::Single(url) => vec![url.clone()],
EndpointConfig::Multiple(urls) => urls.clone(),
},
None => {
if let Ok(ep) = env::var("HIVE_CDN_ENDPOINT") {
ep
vec![ep]
} else {
return Err(
"Endpoint for persisted documents CDN is not configured. Please set it via the plugin configuration or HIVE_CDN_ENDPOINT environment variable."
@ -100,17 +122,41 @@ impl PersistedDocumentsPlugin {
}
};
let mut persisted_documents_manager = PersistedDocumentsManager::builder()
.key(key)
.user_agent(format!("hive-apollo-router/{}", PLUGIN_VERSION));
for endpoint in endpoints {
persisted_documents_manager = persisted_documents_manager.add_endpoint(endpoint);
}
if let Some(connect_timeout) = config.connect_timeout {
persisted_documents_manager =
persisted_documents_manager.connect_timeout(Duration::from_secs(connect_timeout));
}
if let Some(request_timeout) = config.request_timeout {
persisted_documents_manager =
persisted_documents_manager.request_timeout(Duration::from_secs(request_timeout));
}
if let Some(retry_count) = config.retry_count {
persisted_documents_manager = persisted_documents_manager.max_retries(retry_count);
}
if let Some(accept_invalid_certs) = config.accept_invalid_certs {
persisted_documents_manager =
persisted_documents_manager.accept_invalid_certs(accept_invalid_certs);
}
if let Some(cache_size) = config.cache_size {
persisted_documents_manager = persisted_documents_manager.cache_size(cache_size);
}
let persisted_documents_manager = persisted_documents_manager.build()?;
Ok(PersistedDocumentsPlugin {
persisted_documents_manager: Some(Arc::new(PersistedDocumentsManager::new(
key,
endpoint,
config.accept_invalid_certs.unwrap_or(false),
Duration::from_secs(config.connect_timeout.unwrap_or(5)),
Duration::from_secs(config.request_timeout.unwrap_or(15)),
config.retry_count.unwrap_or(3),
config.cache_size.unwrap_or(1000),
format!("hive-apollo-router/{}", PLUGIN_VERSION),
))),
persisted_documents_manager: Some(Arc::new(persisted_documents_manager)),
allow_arbitrary_documents,
})
}
@ -344,8 +390,8 @@ mod hive_persisted_documents_tests {
Self { server }
}
fn endpoint(&self) -> String {
self.server.url("")
fn endpoint(&self) -> EndpointConfig {
EndpointConfig::Single(self.server.url(""))
}
/// Registers a valid artifact URL with an actual GraphQL document

View file

@ -1,14 +1,13 @@
use crate::consts::PLUGIN_VERSION;
use crate::registry_logger::Logger;
use anyhow::{anyhow, Result};
use hive_console_sdk::supergraph_fetcher::sync::SupergraphFetcherSyncState;
use hive_console_sdk::supergraph_fetcher::SupergraphFetcher;
use hive_console_sdk::supergraph_fetcher::SupergraphFetcherSyncState;
use sha2::Digest;
use sha2::Sha256;
use std::env;
use std::io::Write;
use std::thread;
use std::time::Duration;
#[derive(Debug)]
pub struct HiveRegistry {
@ -18,7 +17,7 @@ pub struct HiveRegistry {
}
pub struct HiveRegistryConfig {
endpoint: Option<String>,
endpoints: Vec<String>,
key: Option<String>,
poll_interval: Option<u64>,
accept_invalid_certs: Option<bool>,
@ -29,7 +28,7 @@ impl HiveRegistry {
#[allow(clippy::new_ret_no_self)]
pub fn new(user_config: Option<HiveRegistryConfig>) -> Result<()> {
let mut config = HiveRegistryConfig {
endpoint: None,
endpoints: vec![],
key: None,
poll_interval: None,
accept_invalid_certs: Some(true),
@ -38,7 +37,7 @@ impl HiveRegistry {
// Pass values from user's config
if let Some(user_config) = user_config {
config.endpoint = user_config.endpoint;
config.endpoints = user_config.endpoints;
config.key = user_config.key;
config.poll_interval = user_config.poll_interval;
config.accept_invalid_certs = user_config.accept_invalid_certs;
@ -47,9 +46,9 @@ impl HiveRegistry {
// Pass values from environment variables if they are not set in the user's config
if config.endpoint.is_none() {
if config.endpoints.is_empty() {
if let Ok(endpoint) = env::var("HIVE_CDN_ENDPOINT") {
config.endpoint = Some(endpoint);
config.endpoints.push(endpoint);
}
}
@ -86,7 +85,7 @@ impl HiveRegistry {
}
// Resolve values
let endpoint = config.endpoint.unwrap_or_default();
let endpoint = config.endpoints;
let key = config.key.unwrap_or_default();
let poll_interval: u64 = config.poll_interval.unwrap_or(10);
let accept_invalid_certs = config.accept_invalid_certs.unwrap_or(false);
@ -120,19 +119,23 @@ impl HiveRegistry {
.to_string_lossy()
.to_string(),
);
env::set_var("APOLLO_ROUTER_SUPERGRAPH_PATH", file_name.clone());
env::set_var("APOLLO_ROUTER_HOT_RELOAD", "true");
unsafe {
env::set_var("APOLLO_ROUTER_SUPERGRAPH_PATH", file_name.clone());
env::set_var("APOLLO_ROUTER_HOT_RELOAD", "true");
}
let fetcher = SupergraphFetcher::try_new_sync(
endpoint,
&key,
format!("hive-apollo-router/{}", PLUGIN_VERSION),
Duration::from_secs(5),
Duration::from_secs(60),
accept_invalid_certs,
3,
)
.map_err(|e| anyhow!("Failed to create SupergraphFetcher: {}", e))?;
let mut fetcher = SupergraphFetcher::builder()
.key(key)
.user_agent(format!("hive-apollo-router/{}", PLUGIN_VERSION))
.accept_invalid_certs(accept_invalid_certs);
for ep in endpoint {
fetcher = fetcher.add_endpoint(ep);
}
let fetcher = fetcher
.build_sync()
.map_err(|e| anyhow!("Failed to create SupergraphFetcher: {}", e))?;
let registry = HiveRegistry {
fetcher,

View file

@ -8,8 +8,8 @@ use core::ops::Drop;
use futures::StreamExt;
use graphql_parser::parse_schema;
use graphql_parser::schema::Document;
use hive_console_sdk::agent::UsageAgentExt;
use hive_console_sdk::agent::{ExecutionReport, UsageAgent};
use hive_console_sdk::agent::usage_agent::UsageAgentExt;
use hive_console_sdk::agent::usage_agent::{ExecutionReport, UsageAgent};
use http::HeaderValue;
use rand::Rng;
use schemars::JsonSchema;
@ -19,6 +19,7 @@ use std::env;
use std::sync::Arc;
use std::time::{Duration, Instant};
use std::time::{SystemTime, UNIX_EPOCH};
use tokio_util::sync::CancellationToken;
use tower::BoxError;
use tower::ServiceBuilder;
use tower::ServiceExt;
@ -47,11 +48,12 @@ struct OperationConfig {
pub struct UsagePlugin {
config: OperationConfig,
agent: Option<Arc<UsageAgent>>,
agent: Option<UsageAgent>,
schema: Arc<Document<'static, String>>,
cancellation_token: Arc<CancellationToken>,
}
#[derive(Clone, Debug, Deserialize, JsonSchema)]
#[derive(Clone, Debug, Deserialize, JsonSchema, Default)]
pub struct Config {
/// Default: true
enabled: Option<bool>,
@ -95,26 +97,6 @@ pub struct Config {
flush_interval: Option<u64>,
}
impl Default for Config {
fn default() -> Self {
Self {
enabled: Some(true),
registry_token: None,
registry_usage_endpoint: Some(DEFAULT_HIVE_USAGE_ENDPOINT.into()),
sample_rate: Some(1.0),
exclude: None,
client_name_header: Some(String::from("graphql-client-name")),
client_version_header: Some(String::from("graphql-client-version")),
accept_invalid_certs: Some(false),
buffer_size: Some(1000),
connect_timeout: Some(5),
request_timeout: Some(15),
flush_interval: Some(5),
target: None,
}
}
}
impl UsagePlugin {
fn populate_context(config: OperationConfig, req: &supergraph::Request) {
let context = &req.context;
@ -179,108 +161,95 @@ impl UsagePlugin {
}
}
static DEFAULT_HIVE_USAGE_ENDPOINT: &str = "https://app.graphql-hive.com/usage";
#[async_trait::async_trait]
impl Plugin for UsagePlugin {
type Config = Config;
async fn new(init: PluginInit<Config>) -> Result<Self, BoxError> {
let token = init
.config
.registry_token
.clone()
.or_else(|| env::var("HIVE_TOKEN").ok());
if token.is_none() {
return Err("Hive token is required".into());
}
let endpoint = init
.config
.registry_usage_endpoint
.clone()
.unwrap_or_else(|| {
env::var("HIVE_ENDPOINT").unwrap_or(DEFAULT_HIVE_USAGE_ENDPOINT.to_string())
});
let target_id = init
.config
.target
.clone()
.or_else(|| env::var("HIVE_TARGET_ID").ok());
let default_config = Config::default();
let user_config = init.config;
let enabled = user_config
.enabled
.or(default_config.enabled)
.expect("enabled has default value");
let buffer_size = user_config
.buffer_size
.or(default_config.buffer_size)
.expect("buffer_size has no default value");
let accept_invalid_certs = user_config
.accept_invalid_certs
.or(default_config.accept_invalid_certs)
.expect("accept_invalid_certs has no default value");
let connect_timeout = user_config
.connect_timeout
.or(default_config.connect_timeout)
.expect("connect_timeout has no default value");
let request_timeout = user_config
.request_timeout
.or(default_config.request_timeout)
.expect("request_timeout has no default value");
let flush_interval = user_config
.flush_interval
.or(default_config.flush_interval)
.expect("request_timeout has no default value");
let enabled = user_config.enabled.unwrap_or(true);
if enabled {
tracing::info!("Starting GraphQL Hive Usage plugin");
}
let schema = parse_schema(&init.supergraph_sdl)
.expect("Failed to parse schema")
.into_static();
let cancellation_token = Arc::new(CancellationToken::new());
let agent = if enabled {
let flush_interval = Duration::from_secs(flush_interval);
let agent = UsageAgent::try_new(
&token.expect("token is set"),
endpoint,
target_id,
buffer_size,
Duration::from_secs(connect_timeout),
Duration::from_secs(request_timeout),
accept_invalid_certs,
flush_interval,
format!("hive-apollo-router/{}", PLUGIN_VERSION),
)
.map_err(Box::new)?;
start_flush_interval(agent.clone());
let mut agent =
UsageAgent::builder().user_agent(format!("hive-apollo-router/{}", PLUGIN_VERSION));
if let Some(endpoint) = user_config.registry_usage_endpoint {
agent = agent.endpoint(endpoint);
} else if let Ok(env_endpoint) = env::var("HIVE_ENDPOINT") {
agent = agent.endpoint(env_endpoint);
}
if let Some(token) = user_config.registry_token {
agent = agent.token(token);
} else if let Ok(env_token) = env::var("HIVE_TOKEN") {
agent = agent.token(env_token);
}
if let Some(target_id) = user_config.target {
agent = agent.target_id(target_id);
} else if let Ok(env_target) = env::var("HIVE_TARGET_ID") {
agent = agent.target_id(env_target);
}
if let Some(buffer_size) = user_config.buffer_size {
agent = agent.buffer_size(buffer_size);
}
if let Some(connect_timeout) = user_config.connect_timeout {
agent = agent.connect_timeout(Duration::from_secs(connect_timeout));
}
if let Some(request_timeout) = user_config.request_timeout {
agent = agent.request_timeout(Duration::from_secs(request_timeout));
}
if let Some(accept_invalid_certs) = user_config.accept_invalid_certs {
agent = agent.accept_invalid_certs(accept_invalid_certs);
}
if let Some(flush_interval) = user_config.flush_interval {
agent = agent.flush_interval(Duration::from_secs(flush_interval));
}
let agent = agent.build().map_err(Box::new)?;
let cancellation_token_for_interval = cancellation_token.clone();
let agent_for_interval = agent.clone();
tokio::task::spawn(async move {
agent_for_interval
.start_flush_interval(&cancellation_token_for_interval)
.await;
});
Some(agent)
} else {
None
};
let schema = parse_schema(&init.supergraph_sdl)
.expect("Failed to parse schema")
.into_static();
Ok(UsagePlugin {
schema: Arc::new(schema),
config: OperationConfig {
sample_rate: user_config
.sample_rate
.or(default_config.sample_rate)
.expect("sample_rate has no default value"),
exclude: user_config.exclude.or(default_config.exclude),
sample_rate: user_config.sample_rate.unwrap_or(1.0),
exclude: user_config.exclude,
client_name_header: user_config
.client_name_header
.or(default_config.client_name_header)
.expect("client_name_header has no default value"),
.unwrap_or("graphql-client-name".to_string()),
client_version_header: user_config
.client_version_header
.or(default_config.client_version_header)
.expect("client_version_header has no default value"),
.unwrap_or("graphql-client-version".to_string()),
},
agent,
cancellation_token,
})
}
@ -342,58 +311,61 @@ impl Plugin for UsagePlugin {
match result {
Err(e) => {
agent
.add_report(ExecutionReport {
schema,
client_name,
client_version,
timestamp,
duration,
ok: false,
errors: 1,
operation_body,
operation_name,
persisted_document_hash,
})
.unwrap_or_else(|e| {
tokio::spawn(async move {
let res = agent
.add_report(ExecutionReport {
schema,
client_name,
client_version,
timestamp,
duration,
ok: false,
errors: 1,
operation_body,
operation_name,
persisted_document_hash,
})
.await;
if let Err(e) = res {
tracing::error!("Error adding report: {}", e);
});
}
});
Err(e)
}
Ok(router_response) => {
let is_failure =
!router_response.response.status().is_success();
Ok(router_response.map(move |response_stream| {
let client_name = client_name.clone();
let client_version = client_version.clone();
let operation_body = operation_body.clone();
let operation_name = operation_name.clone();
let res = response_stream
.map(move |response| {
// make sure we send a single report, not for each chunk
let response_has_errors =
!response.errors.is_empty();
agent
.add_report(ExecutionReport {
schema: schema.clone(),
client_name: client_name.clone(),
client_version: client_version.clone(),
timestamp,
duration,
ok: !is_failure && !response_has_errors,
errors: response.errors.len(),
operation_body: operation_body.clone(),
operation_name: operation_name.clone(),
persisted_document_hash:
persisted_document_hash.clone(),
})
.unwrap_or_else(|e| {
let agent = agent.clone();
let execution_report = ExecutionReport {
schema: schema.clone(),
client_name: client_name.clone(),
client_version: client_version.clone(),
timestamp,
duration,
ok: !is_failure && !response_has_errors,
errors: response.errors.len(),
operation_body: operation_body.clone(),
operation_name: operation_name.clone(),
persisted_document_hash:
persisted_document_hash.clone(),
};
tokio::spawn(async move {
let res = agent
.add_report(execution_report)
.await;
if let Err(e) = res {
tracing::error!(
"Error adding report: {}",
e
);
});
}
});
response
})
@ -415,17 +387,11 @@ impl Plugin for UsagePlugin {
impl Drop for UsagePlugin {
fn drop(&mut self) {
tracing::debug!("UsagePlugin has been dropped!");
// TODO: flush the buffer
self.cancellation_token.cancel();
// Flush already done by UsageAgent's Drop impl
}
}
pub fn start_flush_interval(agent_for_interval: Arc<UsageAgent>) {
tokio::task::spawn(async move {
agent_for_interval.start_flush_interval(None).await;
});
}
#[cfg(test)]
mod hive_usage_tests {
use apollo_router::{
@ -479,7 +445,7 @@ mod hive_usage_tests {
}
fn wait_for_processing(&self) -> tokio::time::Sleep {
tokio::time::sleep(tokio::time::Duration::from_secs(1))
tokio::time::sleep(tokio::time::Duration::from_secs(2))
}
fn activate_usage_mock(&'_ self) -> Mock<'_> {
@ -584,6 +550,7 @@ mod hive_usage_tests {
instance.execute_operation(req).await.next_response().await;
instance.wait_for_processing().await;
println!("Waiting done");
mock.assert();
mock.assert_hits(1);

View file

@ -20,7 +20,7 @@ reqwest = { version = "0.12.24", default-features = false, features = [
"blocking",
] }
reqwest-retry = "0.8.0"
reqwest-middleware = "0.4.2"
reqwest-middleware = { version = "0.4.2", features = ["json"]}
anyhow = "1"
tracing = "0.1"
serde = "1"
@ -32,8 +32,15 @@ serde_json = "1"
moka = { version = "0.12.10", features = ["future", "sync"] }
sha2 = { version = "0.10.8", features = ["std"] }
tokio-util = "0.7.16"
regex-automata = "0.4.10"
once_cell = "1.21.3"
retry-policies = "0.5.0"
recloser = "1.3.1"
futures-util = "0.3.31"
typify = "0.5.0"
regress = "0.10.5"
lazy_static = "1.5.0"
async-dropper-simple = { version = "0.2.6", features = ["tokio", "no-default-bound"] }
[dev-dependencies]
mockito = "1.7.0"

View file

@ -0,0 +1,39 @@
use std::collections::VecDeque;
use tokio::sync::Mutex;
/// A fixed-capacity, async-mutex-guarded FIFO buffer.
///
/// Batches items until `max_size` is reached, at which point `add` hands the
/// whole batch back to the caller via `AddStatus::Full`.
pub struct Buffer<T> {
    /// Number of items the buffer holds before `add` reports it as full.
    max_size: usize,
    /// Buffered items, guarded by an async mutex so `add`/`drain` can be
    /// called from concurrent tasks.
    queue: Mutex<VecDeque<T>>,
}

/// Outcome of [`Buffer::add`].
pub enum AddStatus<T> {
    /// The buffer was at capacity: `drained` contains every previously
    /// buffered item plus the item that was just added, in FIFO order.
    Full { drained: Vec<T> },
    /// The item was stored; the buffer still has room.
    Ok,
}

impl<T> Buffer<T> {
    /// Creates an empty buffer that holds at most `max_size` items.
    pub fn new(max_size: usize) -> Self {
        Self {
            queue: Mutex::new(VecDeque::with_capacity(max_size)),
            max_size,
        }
    }

    /// Adds `item` to the buffer.
    ///
    /// If the buffer is already at capacity, all buffered items are drained
    /// and returned together with `item` as `AddStatus::Full`; otherwise the
    /// item is queued and `AddStatus::Ok` is returned.
    pub async fn add(&self, item: T) -> AddStatus<T> {
        let mut queue = self.queue.lock().await;
        if queue.len() >= self.max_size {
            // Pre-size the batch for the drained items plus the new one,
            // avoiding re-allocation while the lock is held.
            let mut drained: Vec<T> = Vec::with_capacity(queue.len() + 1);
            drained.extend(queue.drain(..));
            drained.push(item);
            AddStatus::Full { drained }
        } else {
            queue.push_back(item);
            AddStatus::Ok
        }
    }

    /// Removes and returns all buffered items (possibly empty), preserving
    /// insertion order.
    pub async fn drain(&self) -> Vec<T> {
        let mut queue = self.queue.lock().await;
        queue.drain(..).collect()
    }
}

View file

@ -0,0 +1,229 @@
use std::{sync::Arc, time::Duration};
use async_dropper_simple::AsyncDropper;
use once_cell::sync::Lazy;
use recloser::AsyncRecloser;
use reqwest::header::{HeaderMap, HeaderValue};
use reqwest_middleware::ClientBuilder;
use reqwest_retry::RetryTransientMiddleware;
use crate::agent::buffer::Buffer;
use crate::agent::usage_agent::{non_empty_string, AgentError, UsageAgent, UsageAgentInner};
use crate::agent::utils::OperationProcessor;
use crate::circuit_breaker;
use retry_policies::policies::ExponentialBackoff;
/// Builder for the usage agent that reports GraphQL operation usage to
/// Hive Console. Construct via `Default`, chain the setter methods, then
/// build the agent.
pub struct UsageAgentBuilder {
    // Registry access token; required when building (build fails with
    // `AgentError::MissingToken` if absent).
    token: Option<String>,
    // Usage reporting endpoint; defaults to `DEFAULT_HIVE_USAGE_ENDPOINT`.
    endpoint: String,
    // Optional target slug or UUID, for organization access tokens.
    target_id: Option<String>,
    // Max buffered operations before a flush is triggered (default 1000).
    buffer_size: usize,
    // Connect-phase timeout for requests to Hive Console (default 5s).
    connect_timeout: Duration,
    // Whole-request timeout (default 15s).
    request_timeout: Duration,
    // Whether to accept invalid TLS certificates (default false).
    accept_invalid_certs: bool,
    // How often the buffer is flushed to the server (default 5s).
    flush_interval: Duration,
    // Retry policy for report submission (default: exponential backoff,
    // max 3 retries).
    retry_policy: ExponentialBackoff,
    // Optional User-Agent header value for outgoing requests.
    user_agent: Option<String>,
    // Optional circuit breaker state (recloser::AsyncRecloser) —
    // NOTE(review): how it wraps requests is not visible in this chunk.
    circuit_breaker: Option<AsyncRecloser>,
}
/// Default public Hive Console usage reporting endpoint.
pub static DEFAULT_HIVE_USAGE_ENDPOINT: &str = "https://app.graphql-hive.com/usage";

/// Default configuration: hosted Hive endpoint, 1000-report buffer,
/// 5s connect / 15s request timeouts, 5s flush interval, exponential
/// backoff capped at 3 retries, and no circuit breaker configured.
impl Default for UsageAgentBuilder {
    fn default() -> Self {
        Self {
            endpoint: DEFAULT_HIVE_USAGE_ENDPOINT.to_string(),
            token: None,
            target_id: None,
            buffer_size: 1000,
            connect_timeout: Duration::from_secs(5),
            request_timeout: Duration::from_secs(15),
            accept_invalid_certs: false,
            flush_interval: Duration::from_secs(5),
            retry_policy: ExponentialBackoff::builder().build_with_max_retries(3),
            user_agent: None,
            circuit_breaker: None,
        }
    }
}
/// Returns `true` when `token` carries none of the current Hive access-token
/// prefixes (`hvo1/`, `hvu1/`, `hvp1/`), i.e. it is a legacy registry token.
fn is_legacy_token(token: &str) -> bool {
    const MODERN_PREFIXES: [&str; 3] = ["hvo1/", "hvu1/", "hvp1/"];
    !MODERN_PREFIXES.iter().any(|prefix| token.starts_with(prefix))
}
impl UsageAgentBuilder {
/// Your [Registry Access Token](https://the-guild.dev/graphql/hive/docs/management/targets#registry-access-tokens) with write permission.
pub fn token(mut self, token: String) -> Self {
    // Blank values are ignored, keeping any previously configured token.
    self.token = non_empty_string(Some(token)).or(self.token);
    self
}
/// For self-hosting, you can override `/usage` endpoint (defaults to `https://app.graphql-hive.com/usage`).
pub fn endpoint(mut self, endpoint: String) -> Self {
    // A blank value leaves the current (default) endpoint untouched.
    self.endpoint = non_empty_string(Some(endpoint)).unwrap_or(self.endpoint);
    self
}
/// A target ID: either a slug of the form `$organizationSlug/$projectSlug/$targetSlug`
/// (e.g. `the-guild/graphql-hive/staging`) or a UUID
/// (e.g. `a0f4c605-6541-4350-8cfe-b31f21a4bf80`). Use this when the token is
/// an organization access token.
pub fn target_id(mut self, target_id: String) -> Self {
    // Blank values are ignored, keeping any previously configured target.
    self.target_id = non_empty_string(Some(target_id)).or(self.target_id);
    self
}
/// A maximum number of operations to hold in a buffer before sending to Hive Console
/// Default: 1000
pub fn buffer_size(mut self, buffer_size: usize) -> Self {
self.buffer_size = buffer_size;
self
}
/// A timeout for only the connect phase of a request to Hive Console
/// Default: 5 seconds
pub fn connect_timeout(mut self, connect_timeout: Duration) -> Self {
self.connect_timeout = connect_timeout;
self
}
/// A timeout for the entire request to Hive Console
/// Default: 15 seconds
pub fn request_timeout(mut self, request_timeout: Duration) -> Self {
self.request_timeout = request_timeout;
self
}
/// Accepts invalid SSL certificates
/// Default: false
pub fn accept_invalid_certs(mut self, accept_invalid_certs: bool) -> Self {
self.accept_invalid_certs = accept_invalid_certs;
self
}
/// Frequency of flushing the buffer to the server
/// Default: 5 seconds
pub fn flush_interval(mut self, flush_interval: Duration) -> Self {
self.flush_interval = flush_interval;
self
}
/// User-Agent header to be sent with each request
pub fn user_agent(mut self, user_agent: String) -> Self {
if let Some(user_agent) = non_empty_string(Some(user_agent)) {
self.user_agent = Some(user_agent);
}
self
}
/// Retry policy for sending reports
/// Default: ExponentialBackoff with max 3 retries
pub fn retry_policy(mut self, retry_policy: ExponentialBackoff) -> Self {
self.retry_policy = retry_policy;
self
}
/// Maximum number of retries for sending reports
/// Default: ExponentialBackoff with max 3 retries
pub fn max_retries(mut self, max_retries: u32) -> Self {
self.retry_policy = ExponentialBackoff::builder().build_with_max_retries(max_retries);
self
}
pub(crate) fn build_agent(self) -> Result<UsageAgentInner, AgentError> {
let mut default_headers = HeaderMap::new();
default_headers.insert("X-Usage-API-Version", HeaderValue::from_static("2"));
let token = match self.token {
Some(token) => token,
None => return Err(AgentError::MissingToken),
};
let mut authorization_header = HeaderValue::from_str(&format!("Bearer {}", token))
.map_err(|_| AgentError::InvalidToken)?;
authorization_header.set_sensitive(true);
default_headers.insert(reqwest::header::AUTHORIZATION, authorization_header);
default_headers.insert(
reqwest::header::CONTENT_TYPE,
HeaderValue::from_static("application/json"),
);
let mut reqwest_agent = reqwest::Client::builder()
.danger_accept_invalid_certs(self.accept_invalid_certs)
.connect_timeout(self.connect_timeout)
.timeout(self.request_timeout)
.default_headers(default_headers);
if let Some(user_agent) = &self.user_agent {
reqwest_agent = reqwest_agent.user_agent(user_agent);
}
let reqwest_agent = reqwest_agent
.build()
.map_err(AgentError::HTTPClientCreationError)?;
let client = ClientBuilder::new(reqwest_agent)
.with(RetryTransientMiddleware::new_with_policy(self.retry_policy))
.build();
let mut endpoint = self.endpoint;
match self.target_id {
Some(_) if is_legacy_token(&token) => return Err(AgentError::TargetIdWithLegacyToken),
Some(target_id) if !is_legacy_token(&token) => {
let target_id = validate_target_id(&target_id)?;
endpoint.push_str(&format!("/{}", target_id));
}
None if !is_legacy_token(&token) => return Err(AgentError::MissingTargetId),
_ => {}
}
let circuit_breaker = if let Some(cb) = self.circuit_breaker {
cb
} else {
circuit_breaker::CircuitBreakerBuilder::default()
.build_async()
.map_err(AgentError::CircuitBreakerCreationError)?
};
let buffer = Buffer::new(self.buffer_size);
Ok(UsageAgentInner {
endpoint,
buffer,
processor: OperationProcessor::new(),
client,
flush_interval: self.flush_interval,
circuit_breaker,
})
}
pub fn build(self) -> Result<UsageAgent, AgentError> {
let agent = self.build_agent()?;
Ok(Arc::new(AsyncDropper::new(agent)))
}
}
// Target ID regexp for validation: slug format
// "$organizationSlug/$projectSlug/$targetSlug" — three non-empty segments of
// [a-zA-Z0-9-_] separated by `/`. Compiled once, on first use.
static SLUG_REGEX: Lazy<regex_automata::meta::Regex> = Lazy::new(|| {
    regex_automata::meta::Regex::new(r"^[a-zA-Z0-9-_]+\/[a-zA-Z0-9-_]+\/[a-zA-Z0-9-_]+$").unwrap()
});
// Target ID regexp for validation: UUID format
// 8-4-4-4-12 hex digits, e.g. "a0f4c605-6541-4350-8cfe-b31f21a4bf80".
// Compiled once, on first use.
static UUID_REGEX: Lazy<regex_automata::meta::Regex> = Lazy::new(|| {
    regex_automata::meta::Regex::new(
        r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$",
    )
    .unwrap()
});
/// Validates a user-supplied target id, returning the trimmed id when it
/// matches either the slug format (`$organizationSlug/$projectSlug/$targetSlug`)
/// or the UUID format.
///
/// # Errors
/// [`AgentError::InvalidTargetId`] when the trimmed input is empty or matches
/// neither format.
fn validate_target_id(target_id: &str) -> Result<&str, AgentError> {
    let trimmed_s = target_id.trim();
    if trimmed_s.is_empty() {
        return Err(AgentError::InvalidTargetId("<empty>".to_string()));
    }
    if SLUG_REGEX.is_match(trimmed_s) || UUID_REGEX.is_match(trimmed_s) {
        return Ok(trimmed_s);
    }
    Err(AgentError::InvalidTargetId(format!(
        "Invalid target_id format: '{}'. It must be either in slug format '$organizationSlug/$projectSlug/$targetSlug' or UUID format 'a0f4c605-6541-4350-8cfe-b31f21a4bf80'",
        trimmed_s
    )))
}

View file

@ -0,0 +1,4 @@
pub mod buffer;
pub mod builder;
pub mod usage_agent;
pub mod utils;

View file

@ -1,16 +1,18 @@
use super::graphql::OperationProcessor;
use async_dropper_simple::{AsyncDrop, AsyncDropper};
use graphql_parser::schema::Document;
use reqwest::header::{HeaderMap, HeaderValue};
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
use recloser::AsyncRecloser;
use reqwest_middleware::ClientWithMiddleware;
use std::{
collections::{hash_map::Entry, HashMap, VecDeque},
sync::{Arc, Mutex},
collections::{hash_map::Entry, HashMap},
sync::Arc,
time::Duration,
};
use thiserror::Error;
use tokio_util::sync::CancellationToken;
use crate::agent::{buffer::AddStatus, utils::OperationProcessor};
use crate::agent::{buffer::Buffer, builder::UsageAgentBuilder};
#[derive(Debug, Clone)]
pub struct ExecutionReport {
pub schema: Arc<Document<'static, String>>,
@ -27,44 +29,16 @@ pub struct ExecutionReport {
typify::import_types!(schema = "./usage-report-v2.schema.json");
#[derive(Debug, Default)]
pub struct Buffer(Mutex<VecDeque<ExecutionReport>>);
impl Buffer {
fn new() -> Self {
Self(Mutex::new(VecDeque::new()))
}
fn lock_buffer(
&self,
) -> Result<std::sync::MutexGuard<'_, VecDeque<ExecutionReport>>, AgentError> {
let buffer: Result<std::sync::MutexGuard<'_, VecDeque<ExecutionReport>>, AgentError> =
self.0.lock().map_err(|e| AgentError::Lock(e.to_string()));
buffer
}
pub fn push(&self, report: ExecutionReport) -> Result<usize, AgentError> {
let mut buffer = self.lock_buffer()?;
buffer.push_back(report);
Ok(buffer.len())
}
pub fn drain(&self) -> Result<Vec<ExecutionReport>, AgentError> {
let mut buffer = self.lock_buffer()?;
let reports: Vec<ExecutionReport> = buffer.drain(..).collect();
Ok(reports)
}
}
pub struct UsageAgent {
buffer_size: usize,
endpoint: String,
buffer: Buffer,
processor: OperationProcessor,
client: ClientWithMiddleware,
flush_interval: Duration,
pub struct UsageAgentInner {
pub(crate) endpoint: String,
pub(crate) buffer: Buffer<ExecutionReport>,
pub(crate) processor: OperationProcessor,
pub(crate) client: ClientWithMiddleware,
pub(crate) flush_interval: Duration,
pub(crate) circuit_breaker: AsyncRecloser,
}
fn non_empty_string(value: Option<String>) -> Option<String> {
pub fn non_empty_string(value: Option<String>) -> Option<String> {
value.filter(|str| !str.is_empty())
}
@ -72,81 +46,45 @@ fn non_empty_string(value: Option<String>) -> Option<String> {
pub enum AgentError {
#[error("unable to acquire lock: {0}")]
Lock(String),
#[error("unable to send report: token is missing")]
#[error("unable to send report: unauthorized")]
Unauthorized,
#[error("unable to send report: no access")]
Forbidden,
#[error("unable to send report: rate limited")]
RateLimited,
#[error("invalid token provided: {0}")]
InvalidToken(String),
#[error("missing token")]
MissingToken,
#[error("your access token requires providing a 'target_id' option.")]
MissingTargetId,
#[error("using 'target_id' with legacy tokens is not supported")]
TargetIdWithLegacyToken,
#[error("invalid token provided")]
InvalidToken,
#[error("invalid target id provided: {0}, it should be either a slug like \"$organizationSlug/$projectSlug/$targetSlug\" or an UUID")]
InvalidTargetId(String),
#[error("unable to instantiate the http client for reports sending: {0}")]
HTTPClientCreationError(reqwest::Error),
#[error("unable to create circuit breaker: {0}")]
CircuitBreakerCreationError(#[from] crate::circuit_breaker::CircuitBreakerError),
#[error("rejected by the circuit breaker")]
CircuitBreakerRejected,
#[error("unable to send report: {0}")]
Unknown(String),
}
impl UsageAgent {
#[allow(clippy::too_many_arguments)]
pub fn try_new(
token: &str,
endpoint: String,
target_id: Option<String>,
buffer_size: usize,
connect_timeout: Duration,
request_timeout: Duration,
accept_invalid_certs: bool,
flush_interval: Duration,
user_agent: String,
) -> Result<Arc<Self>, AgentError> {
let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3);
pub type UsageAgent = Arc<AsyncDropper<UsageAgentInner>>;
let mut default_headers = HeaderMap::new();
default_headers.insert("X-Usage-API-Version", HeaderValue::from_static("2"));
let mut authorization_header = HeaderValue::from_str(&format!("Bearer {}", token))
.map_err(|_| AgentError::InvalidToken(token.to_string()))?;
authorization_header.set_sensitive(true);
default_headers.insert(reqwest::header::AUTHORIZATION, authorization_header);
default_headers.insert(
reqwest::header::CONTENT_TYPE,
HeaderValue::from_static("application/json"),
);
let reqwest_agent = reqwest::Client::builder()
.danger_accept_invalid_certs(accept_invalid_certs)
.connect_timeout(connect_timeout)
.timeout(request_timeout)
.user_agent(user_agent)
.default_headers(default_headers)
.build()
.map_err(AgentError::HTTPClientCreationError)?;
let client = ClientBuilder::new(reqwest_agent)
.with(RetryTransientMiddleware::new_with_policy(retry_policy))
.build();
let mut endpoint = endpoint;
if token.starts_with("hvo1/") || token.starts_with("hvu1/") || token.starts_with("hvp1/") {
if let Some(target_id) = target_id {
endpoint.push_str(&format!("/{}", target_id));
}
}
Ok(Arc::new(Self {
buffer_size,
endpoint,
buffer: Buffer::new(),
processor: OperationProcessor::new(),
client,
flush_interval,
}))
#[async_trait::async_trait]
pub trait UsageAgentExt {
fn builder() -> UsageAgentBuilder {
UsageAgentBuilder::default()
}
async fn flush(&self) -> Result<(), AgentError>;
async fn start_flush_interval(&self, token: &CancellationToken);
async fn add_report(&self, execution_report: ExecutionReport) -> Result<(), AgentError>;
}
impl UsageAgentInner {
fn produce_report(&self, reports: Vec<ExecutionReport>) -> Result<Report, AgentError> {
let mut report = Report {
size: 0,
@ -233,21 +171,21 @@ impl UsageAgent {
Ok(report)
}
pub async fn send_report(&self, report: Report) -> Result<(), AgentError> {
async fn send_report(&self, report: Report) -> Result<(), AgentError> {
if report.size == 0 {
return Ok(());
}
let report_body =
serde_json::to_vec(&report).map_err(|e| AgentError::Unknown(e.to_string()))?;
// Based on https://the-guild.dev/graphql/hive/docs/specs/usage-reports#data-structure
let resp_fut = self.client.post(&self.endpoint).json(&report).send();
let resp = self
.client
.post(&self.endpoint)
.header(reqwest::header::CONTENT_LENGTH, report_body.len())
.body(report_body)
.send()
.circuit_breaker
.call(resp_fut)
.await
.map_err(|e| AgentError::Unknown(e.to_string()))?;
.map_err(|e| match e {
recloser::Error::Inner(e) => AgentError::Unknown(e.to_string()),
recloser::Error::Rejected => AgentError::CircuitBreakerRejected,
})?;
match resp.status() {
reqwest::StatusCode::OK => Ok(()),
@ -262,67 +200,57 @@ impl UsageAgent {
}
}
pub async fn flush(&self) {
let execution_reports = match self.buffer.drain() {
Ok(res) => res,
Err(e) => {
tracing::error!("Unable to acquire lock for State in drain_reports: {}", e);
Vec::new()
}
};
let size = execution_reports.len();
if size > 0 {
match self.produce_report(execution_reports) {
Ok(report) => match self.send_report(report).await {
Ok(_) => tracing::debug!("Reported {} operations", size),
Err(e) => tracing::error!("{}", e),
},
Err(e) => tracing::error!("{}", e),
}
async fn handle_drained(&self, drained: Vec<ExecutionReport>) -> Result<(), AgentError> {
if drained.is_empty() {
return Ok(());
}
let report = self.produce_report(drained)?;
self.send_report(report).await
}
pub async fn start_flush_interval(&self, token: Option<CancellationToken>) {
let mut tokio_interval = tokio::time::interval(self.flush_interval);
match token {
Some(token) => loop {
tokio::select! {
_ = tokio_interval.tick() => { self.flush().await; }
_ = token.cancelled() => { println!("Shutting down."); return; }
}
},
None => loop {
tokio_interval.tick().await;
self.flush().await;
},
}
async fn flush(&self) -> Result<(), AgentError> {
let execution_reports = self.buffer.drain().await;
self.handle_drained(execution_reports).await?;
Ok(())
}
}
pub trait UsageAgentExt {
fn add_report(&self, execution_report: ExecutionReport) -> Result<(), AgentError>;
fn flush_if_full(&self, size: usize) -> Result<(), AgentError>;
}
#[async_trait::async_trait]
impl UsageAgentExt for UsageAgent {
async fn flush(&self) -> Result<(), AgentError> {
self.inner().flush().await
}
impl UsageAgentExt for Arc<UsageAgent> {
fn flush_if_full(&self, size: usize) -> Result<(), AgentError> {
if size >= self.buffer_size {
let cloned_self = self.clone();
tokio::task::spawn(async move {
cloned_self.flush().await;
});
async fn start_flush_interval(&self, token: &CancellationToken) {
loop {
tokio::time::sleep(self.inner().flush_interval).await;
if token.is_cancelled() {
println!("Shutting down.");
return;
}
self.flush()
.await
.unwrap_or_else(|e| tracing::error!("Failed to flush usage reports: {}", e));
}
}
async fn add_report(&self, execution_report: ExecutionReport) -> Result<(), AgentError> {
if let AddStatus::Full { drained } = self.inner().buffer.add(execution_report).await {
self.inner().handle_drained(drained).await?;
}
Ok(())
}
}
fn add_report(&self, execution_report: ExecutionReport) -> Result<(), AgentError> {
let size = self.buffer.push(execution_report)?;
self.flush_if_full(size)?;
Ok(())
#[async_trait::async_trait]
impl AsyncDrop for UsageAgentInner {
async fn async_drop(&mut self) {
if let Err(e) = self.flush().await {
tracing::error!("Failed to flush usage reports during drop: {}", e);
}
}
}
@ -333,14 +261,14 @@ mod tests {
use graphql_parser::{parse_query, parse_schema};
use reqwest::header::{AUTHORIZATION, CONTENT_TYPE, USER_AGENT};
use crate::agent::{ExecutionReport, Report, UsageAgent, UsageAgentExt};
use crate::agent::usage_agent::{ExecutionReport, Report, UsageAgent, UsageAgentExt};
const CONTENT_TYPE_VALUE: &'static str = "application/json";
const GRAPHQL_CLIENT_NAME: &'static str = "Hive Client";
const GRAPHQL_CLIENT_VERSION: &'static str = "1.0.0";
#[tokio::test]
async fn should_send_data_to_hive() {
#[tokio::test(flavor = "multi_thread")]
async fn should_send_data_to_hive() -> Result<(), Box<dyn std::error::Error>> {
let token = "Token";
let mut server = mockito::Server::new_async().await;
@ -349,13 +277,13 @@ mod tests {
let timestamp = 1625247600;
let duration = Duration::from_millis(20);
let user_agent = format!("hive-router-sdk-test");
let user_agent = "hive-router-sdk-test";
let mock = server
.mock("POST", "/200")
.match_header(AUTHORIZATION, format!("Bearer {}", token).as_str())
.match_header(CONTENT_TYPE, CONTENT_TYPE_VALUE)
.match_header(USER_AGENT, user_agent.as_str())
.match_header(USER_AGENT, user_agent)
.match_header("X-Usage-API-Version", "2")
.match_request(move |request| {
let request_body = request.body().expect("Failed to extract body");
@ -469,8 +397,7 @@ mod tests {
CUSTOM
}
"#,
)
.expect("Failed to parse schema");
)?;
let op: graphql_tools::static_graphql::query::Document = parse_query(
r#"
@ -493,39 +420,34 @@ mod tests {
type
}
"#,
)
.expect("Failed to parse query");
)?;
let usage_agent = UsageAgent::try_new(
token,
format!("{}/200", server_url),
None,
10,
Duration::from_millis(500),
Duration::from_millis(500),
false,
Duration::from_millis(10),
user_agent,
)
.expect("Failed to create UsageAgent");
// Testing async drop
{
let usage_agent = UsageAgent::builder()
.token(token.into())
.endpoint(format!("{}/200", server_url))
.user_agent(user_agent.into())
.build()?;
usage_agent
.add_report(ExecutionReport {
schema: Arc::new(schema),
operation_body: op.to_string(),
operation_name: Some("deleteProject".to_string()),
client_name: Some(GRAPHQL_CLIENT_NAME.to_string()),
client_version: Some(GRAPHQL_CLIENT_VERSION.to_string()),
timestamp: timestamp.try_into().unwrap(),
duration,
ok: true,
errors: 0,
persisted_document_hash: None,
})
.expect("Failed to add report");
usage_agent.flush().await;
usage_agent
.add_report(ExecutionReport {
schema: Arc::new(schema),
operation_body: op.to_string(),
operation_name: Some("deleteProject".to_string()),
client_name: Some(GRAPHQL_CLIENT_NAME.to_string()),
client_version: Some(GRAPHQL_CLIENT_VERSION.to_string()),
timestamp,
duration,
ok: true,
errors: 0,
persisted_document_hash: None,
})
.await?;
}
mock.assert_async().await;
Ok(())
}
}

View file

@ -0,0 +1,67 @@
use std::time::Duration;
use recloser::{AsyncRecloser, Recloser};
/// Builder for a `recloser`-backed circuit breaker; `Clone` so one
/// configuration can produce a breaker per endpoint.
#[derive(Clone)]
pub struct CircuitBreakerBuilder {
    // Error rate (0.0..=1.0) at which the breaker opens. Default: 0.5.
    error_threshold: f32,
    // Number of requests observed before the error rate is evaluated. Default: 5.
    volume_threshold: usize,
    // How long to wait before attempting requests again. Default: 30s.
    reset_timeout: Duration,
}
impl Default for CircuitBreakerBuilder {
    /// Defaults matching the documentation on the setter methods:
    /// 50% error threshold, 5-request volume threshold, 30s reset timeout.
    fn default() -> Self {
        Self {
            error_threshold: 0.5,
            volume_threshold: 5,
            reset_timeout: Duration::from_secs(30),
        }
    }
}
/// Errors produced while constructing a circuit breaker.
#[derive(Debug, thiserror::Error)]
pub enum CircuitBreakerError {
    /// Returned by `build_sync`/`build_async` when the configured error
    /// threshold lies outside the inclusive `0.0..=1.0` range.
    #[error("Invalid error threshold: {0}. It must be between 0.0 and 1.0")]
    InvalidErrorThreshold(f32),
}
impl CircuitBreakerBuilder {
    /// Percentage after what the circuit breaker should kick in.
    /// Must be within `0.0..=1.0`; validated in `build_sync`/`build_async`.
    /// Default: .5
    pub fn error_threshold(mut self, percentage: f32) -> Self {
        self.error_threshold = percentage;
        self
    }
    /// Count of requests before starting evaluating.
    /// Default: 5
    pub fn volume_threshold(mut self, threshold: usize) -> Self {
        self.volume_threshold = threshold;
        self
    }
    /// After what time the circuit breaker is attempting to retry sending requests.
    /// Default: 30s
    pub fn reset_timeout(mut self, timeout: Duration) -> Self {
        self.reset_timeout = timeout;
        self
    }
    /// Builds an async circuit breaker by wrapping the sync one.
    ///
    /// # Errors
    /// See [`Self::build_sync`].
    pub fn build_async(self) -> Result<AsyncRecloser, CircuitBreakerError> {
        let recloser = self.build_sync()?;
        Ok(AsyncRecloser::from(recloser))
    }
    /// Builds a sync circuit breaker from the configured thresholds.
    ///
    /// # Errors
    /// [`CircuitBreakerError::InvalidErrorThreshold`] when the threshold is
    /// outside `0.0..=1.0`.
    pub fn build_sync(self) -> Result<Recloser, CircuitBreakerError> {
        // `contains` also rejects NaN, which the previous `< 0.0 || > 1.0`
        // comparison silently accepted and passed through to `error_rate`.
        if !(0.0..=1.0).contains(&self.error_threshold) {
            return Err(CircuitBreakerError::InvalidErrorThreshold(
                self.error_threshold,
            ));
        }
        let recloser = Recloser::custom()
            .error_rate(self.error_threshold)
            .closed_len(self.volume_threshold)
            .open_wait(self.reset_timeout)
            .build();
        Ok(recloser)
    }
}

View file

@ -1,4 +1,4 @@
pub mod agent;
pub mod graphql;
pub mod circuit_breaker;
pub mod persisted_documents;
pub mod supergraph_fetcher;

View file

@ -1,18 +1,22 @@
use std::time::Duration;
use crate::agent::usage_agent::non_empty_string;
use crate::circuit_breaker::CircuitBreakerBuilder;
use moka::future::Cache;
use recloser::AsyncRecloser;
use reqwest::header::HeaderMap;
use reqwest::header::HeaderValue;
use reqwest_middleware::ClientBuilder;
use reqwest_middleware::ClientWithMiddleware;
use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
use reqwest_retry::RetryTransientMiddleware;
use retry_policies::policies::ExponentialBackoff;
use tracing::{debug, info, warn};
#[derive(Debug)]
pub struct PersistedDocumentsManager {
agent: ClientWithMiddleware,
client: ClientWithMiddleware,
cache: Cache<String, String>,
endpoint: String,
endpoints_with_circuit_breakers: Vec<(String, AsyncRecloser)>,
}
#[derive(Debug, thiserror::Error)]
@ -31,6 +35,18 @@ pub enum PersistedDocumentsError {
FailedToReadCDNResponse(reqwest::Error),
#[error("No persisted document provided, or document id cannot be resolved.")]
PersistedDocumentRequired,
#[error("Missing required configuration option: {0}")]
MissingConfigurationOption(String),
#[error("Invalid CDN key {0}")]
InvalidCDNKey(String),
#[error("Failed to create HTTP client: {0}")]
HTTPClientCreationError(reqwest::Error),
#[error("unable to create circuit breaker: {0}")]
CircuitBreakerCreationError(#[from] crate::circuit_breaker::CircuitBreakerError),
#[error("rejected by the circuit breaker")]
CircuitBreakerRejected,
#[error("unknown error")]
Unknown,
}
impl PersistedDocumentsError {
@ -51,47 +67,75 @@ impl PersistedDocumentsError {
PersistedDocumentsError::PersistedDocumentRequired => {
"PERSISTED_DOCUMENT_REQUIRED".into()
}
PersistedDocumentsError::MissingConfigurationOption(_) => {
"MISSING_CONFIGURATION_OPTION".into()
}
PersistedDocumentsError::InvalidCDNKey(_) => "INVALID_CDN_KEY".into(),
PersistedDocumentsError::HTTPClientCreationError(_) => {
"HTTP_CLIENT_CREATION_ERROR".into()
}
PersistedDocumentsError::CircuitBreakerCreationError(_) => {
"CIRCUIT_BREAKER_CREATION_ERROR".into()
}
PersistedDocumentsError::CircuitBreakerRejected => "CIRCUIT_BREAKER_REJECTED".into(),
PersistedDocumentsError::Unknown => "UNKNOWN_ERROR".into(),
}
}
}
impl PersistedDocumentsManager {
#[allow(clippy::too_many_arguments)]
pub fn new(
key: String,
endpoint: String,
accept_invalid_certs: bool,
connect_timeout: Duration,
request_timeout: Duration,
retry_count: u32,
cache_size: u64,
user_agent: String,
) -> Self {
let retry_policy = ExponentialBackoff::builder().build_with_max_retries(retry_count);
let mut default_headers = HeaderMap::new();
default_headers.insert("X-Hive-CDN-Key", HeaderValue::from_str(&key).unwrap());
let reqwest_agent = reqwest::Client::builder()
.danger_accept_invalid_certs(accept_invalid_certs)
.connect_timeout(connect_timeout)
.timeout(request_timeout)
.user_agent(user_agent)
.default_headers(default_headers)
.build()
.expect("Failed to create reqwest client");
let agent = ClientBuilder::new(reqwest_agent)
.with(RetryTransientMiddleware::new_with_policy(retry_policy))
.build();
let cache = Cache::<String, String>::new(cache_size);
Self {
agent,
cache,
endpoint,
}
pub fn builder() -> PersistedDocumentsManagerBuilder {
PersistedDocumentsManagerBuilder::default()
}
async fn resolve_from_endpoint(
&self,
endpoint: &str,
document_id: &str,
circuit_breaker: &AsyncRecloser,
) -> Result<String, PersistedDocumentsError> {
let cdn_document_id = str::replace(document_id, "~", "/");
let cdn_artifact_url = format!("{}/apps/{}", endpoint, cdn_document_id);
info!(
"Fetching document {} from CDN: {}",
document_id, cdn_artifact_url
);
let response_fut = self.client.get(cdn_artifact_url).send();
let response = circuit_breaker
.call(response_fut)
.await
.map_err(|e| match e {
recloser::Error::Inner(e) => PersistedDocumentsError::FailedToFetchFromCDN(e),
recloser::Error::Rejected => PersistedDocumentsError::CircuitBreakerRejected,
})?;
if response.status().is_success() {
let document = response
.text()
.await
.map_err(PersistedDocumentsError::FailedToReadCDNResponse)?;
debug!(
"Document fetched from CDN: {}, storing in local cache",
document
);
self.cache
.insert(document_id.into(), document.clone())
.await;
return Ok(document);
}
warn!(
"Document fetch from CDN failed: HTTP {}, Body: {:?}",
response.status(),
response
.text()
.await
.unwrap_or_else(|_| "Unavailable".to_string())
);
Err(PersistedDocumentsError::DocumentNotFound)
}
/// Resolves the document from the cache, or from the CDN
pub async fn resolve_document(
&self,
@ -110,50 +154,173 @@ impl PersistedDocumentsManager {
"Document {} not found in cache. Fetching from CDN",
document_id
);
let cdn_document_id = str::replace(document_id, "~", "/");
let cdn_artifact_url = format!("{}/apps/{}", &self.endpoint, cdn_document_id);
info!(
"Fetching document {} from CDN: {}",
document_id, cdn_artifact_url
);
let cdn_response = self.agent.get(cdn_artifact_url).send().await;
match cdn_response {
Ok(response) => {
if response.status().is_success() {
let document = response
.text()
.await
.map_err(PersistedDocumentsError::FailedToReadCDNResponse)?;
debug!(
"Document fetched from CDN: {}, storing in local cache",
document
);
self.cache
.insert(document_id.into(), document.clone())
.await;
return Ok(document);
let mut last_error: Option<PersistedDocumentsError> = None;
for (endpoint, circuit_breaker) in &self.endpoints_with_circuit_breakers {
let result = self
.resolve_from_endpoint(endpoint, document_id, circuit_breaker)
.await;
match result {
Ok(document) => return Ok(document),
Err(e) => {
last_error = Some(e);
}
warn!(
"Document fetch from CDN failed: HTTP {}, Body: {:?}",
response.status(),
response
.text()
.await
.unwrap_or_else(|_| "Unavailable".to_string())
);
Err(PersistedDocumentsError::DocumentNotFound)
}
Err(e) => {
warn!("Failed to fetch document from CDN: {:?}", e);
Err(PersistedDocumentsError::FailedToFetchFromCDN(e))
}
}
match last_error {
Some(e) => Err(e),
None => Err(PersistedDocumentsError::Unknown),
}
}
}
}
}
/// Builder for `PersistedDocumentsManager`; see each setter for semantics and
/// defaults.
pub struct PersistedDocumentsManagerBuilder {
    // CDN access key; required, checked in `build`.
    key: Option<String>,
    // CDN endpoints, tried in insertion order; at least one is required.
    endpoints: Vec<String>,
    // Whether to accept invalid SSL certificates. Default: false.
    accept_invalid_certs: bool,
    // Connect-phase timeout for CDN requests. Default: 5s.
    connect_timeout: Duration,
    // Whole-request timeout for CDN requests. Default: 15s.
    request_timeout: Duration,
    // Retry policy for CDN requests. Default: 3 retries.
    retry_policy: ExponentialBackoff,
    // Capacity (entries) of the in-memory document cache. Default: 10,000.
    cache_size: u64,
    // Optional User-Agent header value.
    user_agent: Option<String>,
    // Circuit-breaker configuration, cloned per endpoint in `build`.
    circuit_breaker: CircuitBreakerBuilder,
}
impl Default for PersistedDocumentsManagerBuilder {
    /// Defaults matching the documentation on the setter methods: no key and
    /// no endpoints (both must be supplied before `build` succeeds), 5s
    /// connect / 15s request timeouts, 3 retries, 10,000 cache entries, no
    /// custom User-Agent, default circuit breaker.
    fn default() -> Self {
        Self {
            key: None,
            endpoints: vec![],
            accept_invalid_certs: false,
            connect_timeout: Duration::from_secs(5),
            request_timeout: Duration::from_secs(15),
            retry_policy: ExponentialBackoff::builder().build_with_max_retries(3),
            cache_size: 10_000,
            user_agent: None,
            circuit_breaker: CircuitBreakerBuilder::default(),
        }
    }
}
impl PersistedDocumentsManagerBuilder {
/// The CDN Access Token with from the Hive Console target.
pub fn key(mut self, key: String) -> Self {
self.key = non_empty_string(Some(key));
self
}
/// The CDN endpoint from Hive Console target.
pub fn add_endpoint(mut self, endpoint: String) -> Self {
if let Some(endpoint) = non_empty_string(Some(endpoint)) {
self.endpoints.push(endpoint);
}
self
}
/// Accept invalid SSL certificates
/// default: false
pub fn accept_invalid_certs(mut self, accept_invalid_certs: bool) -> Self {
self.accept_invalid_certs = accept_invalid_certs;
self
}
/// Connection timeout for the Hive Console CDN requests.
/// Default: 5 seconds
pub fn connect_timeout(mut self, connect_timeout: Duration) -> Self {
self.connect_timeout = connect_timeout;
self
}
/// Request timeout for the Hive Console CDN requests.
/// Default: 15 seconds
pub fn request_timeout(mut self, request_timeout: Duration) -> Self {
self.request_timeout = request_timeout;
self
}
/// Retry policy for fetching persisted documents
/// Default: ExponentialBackoff with max 3 retries
pub fn retry_policy(mut self, retry_policy: ExponentialBackoff) -> Self {
self.retry_policy = retry_policy;
self
}
/// Maximum number of retries for fetching persisted documents
/// Default: ExponentialBackoff with max 3 retries
pub fn max_retries(mut self, max_retries: u32) -> Self {
self.retry_policy = ExponentialBackoff::builder().build_with_max_retries(max_retries);
self
}
/// Size of the in-memory cache for persisted documents
/// Default: 10,000 entries
pub fn cache_size(mut self, cache_size: u64) -> Self {
self.cache_size = cache_size;
self
}
/// User-Agent header to be sent with each request
pub fn user_agent(mut self, user_agent: String) -> Self {
self.user_agent = non_empty_string(Some(user_agent));
self
}
pub fn build(self) -> Result<PersistedDocumentsManager, PersistedDocumentsError> {
let mut default_headers = HeaderMap::new();
let key = match self.key {
Some(key) => key,
None => {
return Err(PersistedDocumentsError::MissingConfigurationOption(
"key".to_string(),
));
}
};
default_headers.insert(
"X-Hive-CDN-Key",
HeaderValue::from_str(&key)
.map_err(|e| PersistedDocumentsError::InvalidCDNKey(e.to_string()))?,
);
let mut reqwest_agent = reqwest::Client::builder()
.danger_accept_invalid_certs(self.accept_invalid_certs)
.connect_timeout(self.connect_timeout)
.timeout(self.request_timeout)
.default_headers(default_headers);
if let Some(user_agent) = self.user_agent {
reqwest_agent = reqwest_agent.user_agent(user_agent);
}
let reqwest_agent = reqwest_agent
.build()
.map_err(PersistedDocumentsError::HTTPClientCreationError)?;
let client = ClientBuilder::new(reqwest_agent)
.with(RetryTransientMiddleware::new_with_policy(self.retry_policy))
.build();
let cache = Cache::<String, String>::new(self.cache_size);
if self.endpoints.is_empty() {
return Err(PersistedDocumentsError::MissingConfigurationOption(
"endpoints".to_string(),
));
}
Ok(PersistedDocumentsManager {
client,
cache,
endpoints_with_circuit_breakers: self
.endpoints
.into_iter()
.map(move |endpoint| {
let circuit_breaker = self
.circuit_breaker
.clone()
.build_async()
.map_err(PersistedDocumentsError::CircuitBreakerCreationError)?;
Ok((endpoint, circuit_breaker))
})
.collect::<Result<Vec<(String, AsyncRecloser)>, PersistedDocumentsError>>()?,
})
}
}

View file

@ -1,260 +0,0 @@
use std::fmt::Display;
use std::sync::RwLock;
use std::time::Duration;
use std::time::SystemTime;
use reqwest::header::HeaderMap;
use reqwest::header::HeaderValue;
use reqwest::header::InvalidHeaderValue;
use reqwest::header::IF_NONE_MATCH;
use reqwest_middleware::ClientBuilder;
use reqwest_middleware::ClientWithMiddleware;
use reqwest_retry::policies::ExponentialBackoff;
use reqwest_retry::RetryDecision;
use reqwest_retry::RetryPolicy;
use reqwest_retry::RetryTransientMiddleware;
#[derive(Debug)]
pub struct SupergraphFetcher<AsyncOrSync> {
client: SupergraphFetcherAsyncOrSyncClient,
endpoint: String,
etag: RwLock<Option<HeaderValue>>,
state: std::marker::PhantomData<AsyncOrSync>,
}
#[derive(Debug)]
pub struct SupergraphFetcherAsyncState;
#[derive(Debug)]
pub struct SupergraphFetcherSyncState;
#[derive(Debug)]
enum SupergraphFetcherAsyncOrSyncClient {
Async {
reqwest_client: ClientWithMiddleware,
},
Sync {
reqwest_client: reqwest::blocking::Client,
retry_policy: ExponentialBackoff,
},
}
pub enum SupergraphFetcherError {
FetcherCreationError(reqwest::Error),
NetworkError(reqwest_middleware::Error),
NetworkResponseError(reqwest::Error),
Lock(String),
InvalidKey(InvalidHeaderValue),
}
// Human-readable rendering of fetcher errors, used by `format!`/logging.
impl Display for SupergraphFetcherError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            SupergraphFetcherError::FetcherCreationError(e) => {
                write!(f, "Creating fetcher failed: {}", e)
            }
            SupergraphFetcherError::NetworkError(e) => write!(f, "Network error: {}", e),
            SupergraphFetcherError::NetworkResponseError(e) => {
                write!(f, "Network response error: {}", e)
            }
            SupergraphFetcherError::Lock(e) => write!(f, "Lock error: {}", e),
            SupergraphFetcherError::InvalidKey(e) => write!(f, "Invalid CDN key: {}", e),
        }
    }
}
/// Normalizes the CDN endpoint and prepares shared HTTP client configuration
/// for both the sync and the async supergraph fetcher.
///
/// Appends `/supergraph` to `endpoint` unless it already ends with it
/// (handling a trailing slash), builds the default header map containing the
/// sensitive `X-Hive-CDN-Key` header, and derives an exponential-backoff
/// retry policy capped at `retry_count` retries.
///
/// # Errors
/// `SupergraphFetcherError::InvalidKey` when `key` is not a valid HTTP header
/// value.
fn prepare_client_config(
    mut endpoint: String,
    key: &str,
    retry_count: u32,
) -> Result<(String, HeaderMap, ExponentialBackoff), SupergraphFetcherError> {
    if !endpoint.ends_with("/supergraph") {
        if endpoint.ends_with("/") {
            endpoint.push_str("supergraph");
        } else {
            endpoint.push_str("/supergraph");
        }
    }
    let mut headers = HeaderMap::new();
    let mut cdn_key_header =
        HeaderValue::from_str(key).map_err(SupergraphFetcherError::InvalidKey)?;
    // Redact the CDN key when the header map is printed via Debug.
    cdn_key_header.set_sensitive(true);
    headers.insert("X-Hive-CDN-Key", cdn_key_header);
    let retry_policy = ExponentialBackoff::builder().build_with_max_retries(retry_count);
    Ok((endpoint, headers, retry_policy))
}
impl SupergraphFetcher<SupergraphFetcherSyncState> {
    /// Creates a blocking (synchronous) supergraph fetcher.
    ///
    /// The endpoint is normalized to end in `/supergraph` and `key` becomes
    /// the sensitive `X-Hive-CDN-Key` default header (see
    /// `prepare_client_config`). `retry_count` bounds the hand-rolled retry
    /// loop in `fetch_supergraph`.
    ///
    /// # Errors
    /// `InvalidKey` for a key that is not a valid header value;
    /// `FetcherCreationError` when the HTTP client cannot be built.
    #[allow(clippy::too_many_arguments)]
    pub fn try_new_sync(
        endpoint: String,
        key: &str,
        user_agent: String,
        connect_timeout: Duration,
        request_timeout: Duration,
        accept_invalid_certs: bool,
        retry_count: u32,
    ) -> Result<Self, SupergraphFetcherError> {
        let (endpoint, headers, retry_policy) = prepare_client_config(endpoint, key, retry_count)?;
        Ok(Self {
            client: SupergraphFetcherAsyncOrSyncClient::Sync {
                reqwest_client: reqwest::blocking::Client::builder()
                    .danger_accept_invalid_certs(accept_invalid_certs)
                    .connect_timeout(connect_timeout)
                    .timeout(request_timeout)
                    .user_agent(user_agent)
                    .default_headers(headers)
                    .build()
                    .map_err(SupergraphFetcherError::FetcherCreationError)?,
                retry_policy,
            },
            endpoint,
            etag: RwLock::new(None),
            state: std::marker::PhantomData,
        })
    }
    /// Fetches the supergraph SDL, returning `Ok(None)` when the server
    /// replies `304 Not Modified` for the stored ETag.
    ///
    /// The retry middleware only wraps the async client, so the blocking path
    /// re-implements retries by hand using the stored `ExponentialBackoff`
    /// policy.
    pub fn fetch_supergraph(&self) -> Result<Option<String>, SupergraphFetcherError> {
        let request_start_time = SystemTime::now();
        // Implementing retry logic for sync client
        let mut n_past_retries = 0;
        let (reqwest_client, retry_policy) = match &self.client {
            SupergraphFetcherAsyncOrSyncClient::Sync {
                reqwest_client,
                retry_policy,
            } => (reqwest_client, retry_policy),
            // The `SupergraphFetcherSyncState` type parameter guarantees the
            // Sync variant here.
            _ => unreachable!(),
        };
        let resp = loop {
            let mut req = reqwest_client.get(&self.endpoint);
            // Send the last seen ETag so the CDN can answer 304 cheaply.
            let etag = self.get_latest_etag()?;
            if let Some(etag) = etag {
                req = req.header(IF_NONE_MATCH, etag);
            }
            let response = req.send();
            match response {
                Ok(resp) => break resp,
                Err(e) => match retry_policy.should_retry(request_start_time, n_past_retries) {
                    RetryDecision::DoNotRetry => {
                        return Err(SupergraphFetcherError::NetworkError(
                            reqwest_middleware::Error::Reqwest(e),
                        ));
                    }
                    RetryDecision::Retry { execute_after } => {
                        n_past_retries += 1;
                        // NOTE(review): `SystemTime::elapsed()` is `Err` while
                        // `execute_after` is still in the future, so this only
                        // sleeps when the scheduled retry time has already
                        // passed — likely inverted; `execute_after
                        // .duration_since(SystemTime::now())` would wait until
                        // the scheduled time. Confirm intended behavior.
                        if let Ok(duration) = execute_after.elapsed() {
                            std::thread::sleep(duration);
                        }
                    }
                },
            }
        };
        if resp.status().as_u16() == 304 {
            // Schema unchanged since the cached ETag.
            return Ok(None);
        }
        // Remember the new ETag for the next conditional request.
        let etag = resp.headers().get("etag");
        self.update_latest_etag(etag)?;
        let text = resp
            .text()
            .map_err(SupergraphFetcherError::NetworkResponseError)?;
        Ok(Some(text))
    }
}
impl SupergraphFetcher<SupergraphFetcherAsyncState> {
    /// Creates a fetcher backed by an async `reqwest` client wrapped with
    /// retry middleware.
    #[allow(clippy::too_many_arguments)]
    pub fn try_new_async(
        endpoint: String,
        key: &str,
        user_agent: String,
        connect_timeout: Duration,
        request_timeout: Duration,
        accept_invalid_certs: bool,
        retry_count: u32,
    ) -> Result<Self, SupergraphFetcherError> {
        let (endpoint, headers, retry_policy) = prepare_client_config(endpoint, key, retry_count)?;
        let inner_client = reqwest::Client::builder()
            .danger_accept_invalid_certs(accept_invalid_certs)
            .connect_timeout(connect_timeout)
            .timeout(request_timeout)
            .default_headers(headers)
            .user_agent(user_agent)
            .build()
            .map_err(SupergraphFetcherError::FetcherCreationError)?;
        let client_with_retries = ClientBuilder::new(inner_client)
            .with(RetryTransientMiddleware::new_with_policy(retry_policy))
            .build();
        Ok(Self {
            client: SupergraphFetcherAsyncOrSyncClient::Async {
                reqwest_client: client_with_retries,
            },
            endpoint,
            etag: RwLock::new(None),
            state: std::marker::PhantomData,
        })
    }
    /// Fetches the supergraph SDL; `Ok(None)` means the cached copy
    /// (matched by ETag) is still current.
    pub async fn fetch_supergraph(&self) -> Result<Option<String>, SupergraphFetcherError> {
        let reqwest_client = match &self.client {
            SupergraphFetcherAsyncOrSyncClient::Async { reqwest_client } => reqwest_client,
            _ => unreachable!(),
        };
        let mut request = reqwest_client.get(&self.endpoint);
        if let Some(cached_etag) = self.get_latest_etag()? {
            request = request.header(IF_NONE_MATCH, cached_etag);
        }
        let response = request
            .send()
            .await
            .map_err(SupergraphFetcherError::NetworkError)?;
        if response.status().as_u16() == 304 {
            return Ok(None);
        }
        self.update_latest_etag(response.headers().get("etag"))?;
        response
            .text()
            .await
            .map(Some)
            .map_err(SupergraphFetcherError::NetworkResponseError)
    }
}
impl<AsyncOrSync> SupergraphFetcher<AsyncOrSync> {
    /// Returns a clone of the currently cached ETag, if any.
    fn get_latest_etag(&self) -> Result<Option<HeaderValue>, SupergraphFetcherError> {
        self.etag
            .try_read()
            .map(|guard| guard.clone())
            .map_err(|e| {
                SupergraphFetcherError::Lock(format!("Failed to read the etag record: {:?}", e))
            })
    }
    /// Overwrites the cached ETag (clearing it when `etag` is `None`).
    fn update_latest_etag(&self, etag: Option<&HeaderValue>) -> Result<(), SupergraphFetcherError> {
        let mut guard = self.etag.try_write().map_err(|e| {
            SupergraphFetcherError::Lock(format!("Failed to update the etag record: {:?}", e))
        })?;
        *guard = etag.cloned();
        Ok(())
    }
}

View file

@ -0,0 +1,141 @@
use futures_util::TryFutureExt;
use recloser::AsyncRecloser;
use reqwest::header::{HeaderValue, IF_NONE_MATCH};
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
use reqwest_retry::RetryTransientMiddleware;
use tokio::sync::RwLock;
use crate::supergraph_fetcher::{
builder::SupergraphFetcherBuilder, SupergraphFetcher, SupergraphFetcherError,
};
#[derive(Debug)]
pub struct SupergraphFetcherAsyncState {
    // Each CDN endpoint is paired with its own circuit breaker, so one
    // unhealthy endpoint does not block failover to the others.
    endpoints_with_circuit_breakers: Vec<(String, AsyncRecloser)>,
    // Async reqwest client wrapped with the retry middleware configured in the builder.
    reqwest_client: ClientWithMiddleware,
}
impl SupergraphFetcher<SupergraphFetcherAsyncState> {
    /// Fetches the supergraph SDL, trying each configured endpoint in order.
    ///
    /// Each endpoint sits behind its own circuit breaker; an endpoint that
    /// fails (network error, 5xx, or open breaker) is skipped and the next
    /// one is tried. Returns `Ok(Some(sdl))` on success, `Ok(None)` when the
    /// CDN answered `304 Not Modified` for the cached ETag, or the last
    /// endpoint's error when every endpoint failed.
    pub async fn fetch_supergraph(&self) -> Result<Option<String>, SupergraphFetcherError> {
        let mut last_error: Option<SupergraphFetcherError> = None;
        let mut last_resp = None;
        for (endpoint, circuit_breaker) in &self.state.endpoints_with_circuit_breakers {
            let mut req = self.state.reqwest_client.get(endpoint);
            let etag = self.get_latest_etag().await;
            if let Some(etag) = etag {
                req = req.header(IF_NONE_MATCH, etag);
            }
            let resp_fut = async {
                let resp = req.send().await.map_err(SupergraphFetcherError::Network)?;
                // Server errors (5xx) are considered errors so that they trip
                // the circuit breaker and trigger endpoint failover.
                if resp.status().is_server_error() {
                    return Err(SupergraphFetcherError::Network(
                        reqwest_middleware::Error::Middleware(anyhow::anyhow!(
                            "Server error: {}",
                            resp.status()
                        )),
                    ));
                }
                Ok(resp)
            };
            let resp = circuit_breaker
                .call(resp_fut)
                // Map recloser errors to SupergraphFetcherError
                .map_err(|e| match e {
                    recloser::Error::Inner(e) => e,
                    recloser::Error::Rejected => SupergraphFetcherError::RejectedByCircuitBreaker,
                })
                .await;
            match resp {
                Err(err) => {
                    last_error = Some(err);
                    continue;
                }
                Ok(resp) => {
                    last_resp = Some(resp);
                    break;
                }
            }
        }
        if let Some(last_resp) = last_resp {
            // 304 Not Modified: the cached schema (matched via If-None-Match)
            // is still current. Mirror the sync implementation and report
            // "no change" instead of returning the empty 304 body.
            if last_resp.status().as_u16() == 304 {
                return Ok(None);
            }
            let etag = last_resp.headers().get("etag");
            self.update_latest_etag(etag).await;
            let text = last_resp
                .text()
                .await
                .map_err(SupergraphFetcherError::ResponseParse)?;
            Ok(Some(text))
        } else if let Some(error) = last_error {
            Err(error)
        } else {
            Ok(None)
        }
    }
    /// Returns a clone of the currently cached ETag, if any.
    async fn get_latest_etag(&self) -> Option<HeaderValue> {
        let guard = self.etag.read().await;
        guard.clone()
    }
    /// Replaces the cached ETag (clears it when the response carried none).
    async fn update_latest_etag(&self, etag: Option<&HeaderValue>) {
        let mut guard = self.etag.write().await;
        *guard = etag.cloned();
    }
}
impl SupergraphFetcherBuilder {
    /// Builds an asynchronous SupergraphFetcher
    pub fn build_async(
        self,
    ) -> Result<SupergraphFetcher<SupergraphFetcherAsyncState>, SupergraphFetcherError> {
        self.validate_endpoints()?;
        let headers = self.prepare_headers()?;
        let mut client_builder = reqwest::Client::builder()
            .danger_accept_invalid_certs(self.accept_invalid_certs)
            .connect_timeout(self.connect_timeout)
            .timeout(self.request_timeout)
            .default_headers(headers);
        if let Some(user_agent) = self.user_agent {
            client_builder = client_builder.user_agent(user_agent);
        }
        let base_client = client_builder
            .build()
            .map_err(SupergraphFetcherError::HTTPClientCreation)?;
        let reqwest_client = ClientBuilder::new(base_client)
            .with(RetryTransientMiddleware::new_with_policy(self.retry_policy))
            .build();
        // Give every endpoint its own circuit breaker, each built from the
        // shared (or default) breaker configuration; fail fast on the first
        // breaker that cannot be constructed.
        let mut endpoints_with_circuit_breakers = Vec::with_capacity(self.endpoints.len());
        for endpoint in self.endpoints {
            let breaker = self
                .circuit_breaker
                .clone()
                .unwrap_or_default()
                .build_async()
                .map_err(SupergraphFetcherError::CircuitBreakerCreation)?;
            endpoints_with_circuit_breakers.push((endpoint, breaker));
        }
        Ok(SupergraphFetcher {
            state: SupergraphFetcherAsyncState {
                reqwest_client,
                endpoints_with_circuit_breakers,
            },
            etag: RwLock::new(None),
        })
    }
}

View file

@ -0,0 +1,135 @@
use std::time::Duration;
use reqwest::header::{HeaderMap, HeaderValue};
use retry_policies::policies::ExponentialBackoff;
use crate::{
agent::usage_agent::non_empty_string, circuit_breaker::CircuitBreakerBuilder,
supergraph_fetcher::SupergraphFetcherError,
};
/// Builder for `SupergraphFetcher`; construct via `new()`/`default()` and
/// finish with `build_sync()` or `build_async()`.
pub struct SupergraphFetcherBuilder {
    // CDN endpoints, normalized to end with "/supergraph"; tried in order.
    pub(crate) endpoints: Vec<String>,
    // CDN access key; required — build fails when `None`.
    pub(crate) key: Option<String>,
    // Optional User-Agent header for outgoing requests.
    pub(crate) user_agent: Option<String>,
    pub(crate) connect_timeout: Duration,
    pub(crate) request_timeout: Duration,
    // Disables TLS certificate verification when true.
    pub(crate) accept_invalid_certs: bool,
    pub(crate) retry_policy: ExponentialBackoff,
    // Circuit breaker template applied per endpoint; `None` means use the default.
    pub(crate) circuit_breaker: Option<CircuitBreakerBuilder>,
}
impl Default for SupergraphFetcherBuilder {
    /// Defaults: no endpoints or key, 5s connect / 60s request timeouts,
    /// certificate validation enabled, exponential backoff with 3 retries,
    /// and no circuit-breaker override.
    fn default() -> Self {
        let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3);
        Self {
            endpoints: Vec::new(),
            key: None,
            user_agent: None,
            connect_timeout: Duration::from_secs(5),
            request_timeout: Duration::from_secs(60),
            accept_invalid_certs: false,
            retry_policy,
            circuit_breaker: None,
        }
    }
}
impl SupergraphFetcherBuilder {
    /// Creates a builder with the defaults from [`Default`].
    pub fn new() -> Self {
        Self::default()
    }
    /// Adds a CDN endpoint from the Hive Console target.
    ///
    /// Empty strings are ignored. The endpoint is normalized to end with
    /// `/supergraph`. May be called multiple times; endpoints act as
    /// fallbacks and are tried in insertion order.
    pub fn add_endpoint(mut self, endpoint: String) -> Self {
        if let Some(mut endpoint) = non_empty_string(Some(endpoint)) {
            if !endpoint.ends_with("/supergraph") {
                if endpoint.ends_with("/") {
                    endpoint.push_str("supergraph");
                } else {
                    endpoint.push_str("/supergraph");
                }
            }
            self.endpoints.push(endpoint);
        }
        self
    }
    /// The CDN Access Token from the Hive Console target.
    pub fn key(mut self, key: String) -> Self {
        self.key = Some(key);
        self
    }
    /// User-Agent header to be sent with each request
    pub fn user_agent(mut self, user_agent: String) -> Self {
        self.user_agent = Some(user_agent);
        self
    }
    /// Connection timeout for the Hive Console CDN requests.
    /// Default: 5 seconds
    pub fn connect_timeout(mut self, timeout: Duration) -> Self {
        self.connect_timeout = timeout;
        self
    }
    /// Request timeout for the Hive Console CDN requests.
    /// Default: 60 seconds
    pub fn request_timeout(mut self, timeout: Duration) -> Self {
        self.request_timeout = timeout;
        self
    }
    /// Whether to accept invalid TLS certificates (e.g. self-signed).
    /// Default: false
    pub fn accept_invalid_certs(mut self, accept: bool) -> Self {
        self.accept_invalid_certs = accept;
        self
    }
    /// Policy for retrying failed requests.
    ///
    /// By default, an exponential backoff retry policy is used, with 3 retries
    /// (see [`Default`]).
    pub fn retry_policy(mut self, retry_policy: ExponentialBackoff) -> Self {
        self.retry_policy = retry_policy;
        self
    }
    /// Maximum number of retries for failed requests; replaces the current
    /// retry policy with an exponential backoff capped at `max_retries`.
    ///
    /// Default: 3 retries (see [`Default`]).
    pub fn max_retries(mut self, max_retries: u32) -> Self {
        self.retry_policy = ExponentialBackoff::builder().build_with_max_retries(max_retries);
        self
    }
    /// Circuit breaker configuration applied to every endpoint.
    ///
    /// NOTE(review): unlike the other builder methods this takes `&mut self`
    /// and returns `&mut Self`, so it cannot end a consuming builder chain —
    /// confirm whether this asymmetry is intentional.
    pub fn circuit_breaker(&mut self, builder: CircuitBreakerBuilder) -> &mut Self {
        self.circuit_breaker = Some(builder);
        self
    }
    /// Fails with `MissingConfigurationOption("endpoint")` unless at least
    /// one endpoint was added.
    pub(crate) fn validate_endpoints(&self) -> Result<(), SupergraphFetcherError> {
        if self.endpoints.is_empty() {
            return Err(SupergraphFetcherError::MissingConfigurationOption(
                "endpoint".to_string(),
            ));
        }
        Ok(())
    }
    /// Builds the default header map carrying the CDN key.
    ///
    /// Fails with `MissingConfigurationOption("key")` when no key was set,
    /// or `InvalidKey` when the key is not a valid HTTP header value.
    pub(crate) fn prepare_headers(&self) -> Result<HeaderMap, SupergraphFetcherError> {
        let key = match &self.key {
            Some(key) => key,
            None => {
                return Err(SupergraphFetcherError::MissingConfigurationOption(
                    "key".to_string(),
                ))
            }
        };
        let mut headers = HeaderMap::new();
        let mut cdn_key_header =
            HeaderValue::from_str(key).map_err(SupergraphFetcherError::InvalidKey)?;
        // Mark the key sensitive so it is redacted from Debug output.
        cdn_key_header.set_sensitive(true);
        headers.insert("X-Hive-CDN-Key", cdn_key_header);
        Ok(headers)
    }
}

View file

@ -0,0 +1,51 @@
use tokio::sync::RwLock;
use tokio::sync::TryLockError;
use crate::circuit_breaker::CircuitBreakerError;
use crate::supergraph_fetcher::async_::SupergraphFetcherAsyncState;
use reqwest::header::HeaderValue;
use reqwest::header::InvalidHeaderValue;
pub mod async_;
pub mod builder;
pub mod sync;
/// Fetches the supergraph SDL from the Hive CDN; `State` selects the sync or
/// async client implementation.
#[derive(Debug)]
pub struct SupergraphFetcher<State> {
    // Sync- or async-specific client state (see the `sync` / `async_` modules).
    state: State,
    // Last ETag returned by the CDN; replayed via If-None-Match so an
    // unchanged schema is not downloaded again.
    etag: RwLock<Option<HeaderValue>>,
}
// Doesn't matter which one we implement this for, both have the same builder
impl SupergraphFetcher<SupergraphFetcherAsyncState> {
    /// Convenience entry point: returns a fresh `SupergraphFetcherBuilder`.
    pub fn builder() -> builder::SupergraphFetcherBuilder {
        builder::SupergraphFetcherBuilder::new()
    }
}
// NOTE(review): not referenced by the error variants below — they wrap
// `TryLockError` directly. Confirm whether this enum is still used elsewhere
// or is a leftover from an earlier error design.
pub enum LockErrorType {
    Read,
    Write,
}
/// Errors produced while building or using a `SupergraphFetcher`.
#[derive(Debug, thiserror::Error)]
pub enum SupergraphFetcherError {
    #[error("Creating HTTP Client failed: {0}")]
    HTTPClientCreation(reqwest::Error),
    #[error("Network error: {0}")]
    Network(reqwest_middleware::Error),
    #[error("Parsing response failed: {0}")]
    ResponseParse(reqwest::Error),
    #[error("Reading the etag record failed: {0:?}")]
    ETagRead(TryLockError),
    #[error("Updating the etag record failed: {0:?}")]
    ETagWrite(TryLockError),
    #[error("Invalid CDN key: {0}")]
    InvalidKey(InvalidHeaderValue),
    // Raised by the builder when a required option (e.g. "endpoint", "key")
    // was never set.
    #[error("Missing configuration option: {0}")]
    MissingConfigurationOption(String),
    // The endpoint's circuit breaker is open and rejected the call without
    // sending a request.
    #[error("Request rejected by circuit breaker")]
    RejectedByCircuitBreaker,
    #[error("Creating circuit breaker failed: {0}")]
    CircuitBreakerCreation(CircuitBreakerError),
}

View file

@ -0,0 +1,191 @@
use std::time::SystemTime;
use recloser::Recloser;
use reqwest::header::{HeaderValue, IF_NONE_MATCH};
use reqwest_retry::{RetryDecision, RetryPolicy};
use retry_policies::policies::ExponentialBackoff;
use tokio::sync::RwLock;
use crate::supergraph_fetcher::{
builder::SupergraphFetcherBuilder, SupergraphFetcher, SupergraphFetcherError,
};
#[derive(Debug)]
pub struct SupergraphFetcherSyncState {
    // Each CDN endpoint is paired with its own circuit breaker, so one
    // unhealthy endpoint does not block failover to the others.
    endpoints_with_circuit_breakers: Vec<(String, Recloser)>,
    // Blocking reqwest client; retries are implemented by hand in
    // `fetch_supergraph` because the retry middleware is async-only.
    reqwest_client: reqwest::blocking::Client,
    retry_policy: ExponentialBackoff,
}
impl SupergraphFetcher<SupergraphFetcherSyncState> {
pub fn fetch_supergraph(&self) -> Result<Option<String>, SupergraphFetcherError> {
let mut last_error: Option<SupergraphFetcherError> = None;
let mut last_resp = None;
for (endpoint, circuit_breaker) in &self.state.endpoints_with_circuit_breakers {
let resp = {
circuit_breaker
.call(|| {
let request_start_time = SystemTime::now();
// Implementing retry logic for sync client
let mut n_past_retries = 0;
loop {
let mut req = self.state.reqwest_client.get(endpoint);
let etag = self.get_latest_etag()?;
if let Some(etag) = etag {
req = req.header(IF_NONE_MATCH, etag);
}
let mut response = req.send().map_err(|err| {
SupergraphFetcherError::Network(reqwest_middleware::Error::Reqwest(
err,
))
});
// Server errors (5xx) are considered retryable
if let Ok(ok_res) = response {
response = if ok_res.status().is_server_error() {
Err(SupergraphFetcherError::Network(
reqwest_middleware::Error::Middleware(anyhow::anyhow!(
"Server error: {}",
ok_res.status()
)),
))
} else {
Ok(ok_res)
}
}
match response {
Ok(resp) => break Ok(resp),
Err(e) => {
match self
.state
.retry_policy
.should_retry(request_start_time, n_past_retries)
{
RetryDecision::DoNotRetry => {
return Err(e);
}
RetryDecision::Retry { execute_after } => {
n_past_retries += 1;
match execute_after.elapsed() {
Ok(duration) => {
std::thread::sleep(duration);
}
Err(err) => {
tracing::error!(
"Error determining sleep duration for retry: {}",
err
);
// If elapsed time cannot be determined, do not wait
return Err(e);
}
}
}
}
}
}
}
})
// Map recloser errors to SupergraphFetcherError
.map_err(|e| match e {
recloser::Error::Inner(e) => e,
recloser::Error::Rejected => {
SupergraphFetcherError::RejectedByCircuitBreaker
}
})
};
match resp {
Err(e) => {
last_error = Some(e);
continue;
}
Ok(resp) => {
last_resp = Some(resp);
break;
}
}
}
if let Some(last_resp) = last_resp {
if last_resp.status().as_u16() == 304 {
return Ok(None);
}
self.update_latest_etag(last_resp.headers().get("etag"))?;
let text = last_resp
.text()
.map_err(SupergraphFetcherError::ResponseParse)?;
Ok(Some(text))
} else if let Some(error) = last_error {
Err(error)
} else {
Ok(None)
}
}
fn get_latest_etag(&self) -> Result<Option<HeaderValue>, SupergraphFetcherError> {
let guard = self
.etag
.try_read()
.map_err(SupergraphFetcherError::ETagRead)?;
Ok(guard.clone())
}
fn update_latest_etag(&self, etag: Option<&HeaderValue>) -> Result<(), SupergraphFetcherError> {
let mut guard = self
.etag
.try_write()
.map_err(SupergraphFetcherError::ETagWrite)?;
if let Some(etag_value) = etag {
*guard = Some(etag_value.clone());
} else {
*guard = None;
}
Ok(())
}
}
impl SupergraphFetcherBuilder {
    /// Builds a synchronous SupergraphFetcher
    pub fn build_sync(
        self,
    ) -> Result<SupergraphFetcher<SupergraphFetcherSyncState>, SupergraphFetcherError> {
        self.validate_endpoints()?;
        let headers = self.prepare_headers()?;
        let mut client_builder = reqwest::blocking::Client::builder()
            .danger_accept_invalid_certs(self.accept_invalid_certs)
            .connect_timeout(self.connect_timeout)
            .timeout(self.request_timeout)
            .default_headers(headers);
        if let Some(user_agent) = &self.user_agent {
            client_builder = client_builder.user_agent(user_agent);
        }
        let reqwest_client = client_builder
            .build()
            .map_err(SupergraphFetcherError::HTTPClientCreation)?;
        // Give every endpoint its own circuit breaker, each built from the
        // shared (or default) breaker configuration; fail fast on the first
        // breaker that cannot be constructed.
        let mut endpoints_with_circuit_breakers = Vec::with_capacity(self.endpoints.len());
        for endpoint in self.endpoints {
            let breaker = self
                .circuit_breaker
                .clone()
                .unwrap_or_default()
                .build_sync()
                .map_err(SupergraphFetcherError::CircuitBreakerCreation)?;
            endpoints_with_circuit_breakers.push((endpoint, breaker));
        }
        Ok(SupergraphFetcher {
            state: SupergraphFetcherSyncState {
                reqwest_client,
                retry_policy: self.retry_policy,
                endpoints_with_circuit_breakers,
            },
            etag: RwLock::new(None),
        })
    }
}