mirror of
https://github.com/graphql-hive/console
synced 2026-04-21 14:37:17 +00:00
feat(sdk-rs): builder pattern, multiple endpoints & circuit breaker (#7379)
This commit is contained in:
parent
b90f215213
commit
b13446109d
25 changed files with 1619 additions and 750 deletions
56
.changeset/busy-cloths-search.md
Normal file
56
.changeset/busy-cloths-search.md
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
---
|
||||
'hive-console-sdk-rs': minor
|
||||
---
|
||||
|
||||
Breaking Changes to avoid future breaking changes;
|
||||
|
||||
Switch to [Builder](https://rust-unofficial.github.io/patterns/patterns/creational/builder.html) pattern for `SupergraphFetcher`, `PersistedDocumentsManager` and `UsageAgent` structs.
|
||||
|
||||
No more `try_new` or `try_new_async` or `try_new_sync` functions, instead use `SupergraphFetcherBuilder`, `PersistedDocumentsManagerBuilder` and `UsageAgentBuilder` structs to create instances.
|
||||
|
||||
Benefits;
|
||||
|
||||
- No need to provide all parameters at once when creating an instance even for default values.
|
||||
|
||||
Example;
|
||||
```rust
|
||||
// Before
|
||||
let fetcher = SupergraphFetcher::try_new_async(
|
||||
"SOME_ENDPOINT", // endpoint
|
||||
"SOME_KEY",
|
||||
"MyUserAgent/1.0".to_string(),
|
||||
Duration::from_secs(5), // connect_timeout
|
||||
Duration::from_secs(10), // request_timeout
|
||||
false, // accept_invalid_certs
|
||||
3, // retry_count
|
||||
)?;
|
||||
|
||||
// After
|
||||
// No need to provide all parameters at once, can use default values
|
||||
let fetcher = SupergraphFetcherBuilder::new()
|
||||
.endpoint("SOME_ENDPOINT".to_string())
|
||||
.key("SOME_KEY".to_string())
|
||||
.build_async()?;
|
||||
```
|
||||
|
||||
- Easier to add new configuration options in the future without breaking existing code.
|
||||
|
||||
Example;
|
||||
|
||||
```rust
|
||||
let fetcher = SupergraphFetcher::try_new_async(
|
||||
"SOME_ENDPOINT", // endpoint
|
||||
"SOME_KEY",
|
||||
"MyUserAgent/1.0".to_string(),
|
||||
Duration::from_secs(5), // connect_timeout
|
||||
Duration::from_secs(10), // request_timeout
|
||||
false, // accept_invalid_certs
|
||||
3, // retry_count
|
||||
circuit_breaker_config, // Breaking Change -> new parameter added
|
||||
)?;
|
||||
|
||||
let fetcher = SupergraphFetcherBuilder::new()
|
||||
.endpoint("SOME_ENDPOINT".to_string())
|
||||
.key("SOME_KEY".to_string())
|
||||
.build_async()?; // No breaking change, circuit_breaker_config can be added later if needed
|
||||
```
|
||||
20
.changeset/light-walls-vanish.md
Normal file
20
.changeset/light-walls-vanish.md
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
---
|
||||
'hive-console-sdk-rs': patch
|
||||
---
|
||||
|
||||
Circuit Breaker Implementation and Multiple Endpoints Support
|
||||
|
||||
Implementation of Circuit Breakers in Hive Console Rust SDK, you can learn more [here](https://the-guild.dev/graphql/hive/product-updates/2025-12-04-cdn-mirror-and-circuit-breaker)
|
||||
|
||||
Breaking Changes:
|
||||
|
||||
Now `endpoint` configuration accepts multiple endpoints as an array for `SupergraphFetcherBuilder` and `PersistedDocumentsManager`.
|
||||
|
||||
```diff
|
||||
SupergraphFetcherBuilder::default()
|
||||
- .endpoint(endpoint)
|
||||
+ .add_endpoint(endpoint1)
|
||||
+ .add_endpoint(endpoint2)
|
||||
```
|
||||
|
||||
This change requires updating the configuration structure to accommodate multiple endpoints.
|
||||
17
.changeset/violet-waves-happen.md
Normal file
17
.changeset/violet-waves-happen.md
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
---
|
||||
'hive-apollo-router-plugin': major
|
||||
---
|
||||
|
||||
- Multiple endpoints support for `HiveRegistry` and `PersistedOperationsPlugin`
|
||||
|
||||
Breaking Changes:
|
||||
- Now there is no `endpoint` field in the configuration, it has been replaced with `endpoints`, which is an array of strings. You are not affected if you use environment variables to set the endpoint.
|
||||
|
||||
```diff
|
||||
HiveRegistry::new(
|
||||
Some(
|
||||
HiveRegistryConfig {
|
||||
- endpoint: String::from("CDN_ENDPOINT"),
|
||||
+ endpoints: vec![String::from("CDN_ENDPOINT1"), String::from("CDN_ENDPOINT2")],
|
||||
)
|
||||
)
|
||||
1
.github/workflows/tests-integration.yaml
vendored
1
.github/workflows/tests-integration.yaml
vendored
|
|
@ -23,6 +23,7 @@ jobs:
|
|||
integration:
|
||||
runs-on: ubuntu-22.04
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
# Divide integration tests into 3 shards, to run them in parallel.
|
||||
shardIndex: [1, 2, 3, 'apollo-router']
|
||||
|
|
|
|||
62
configs/cargo/Cargo.lock
generated
62
configs/cargo/Cargo.lock
generated
|
|
@ -122,9 +122,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "apollo-federation"
|
||||
version = "2.10.0"
|
||||
version = "2.9.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "034b556462798d1f076d865791410c088fd549c920f2a5c0d83a0611141559b4"
|
||||
checksum = "6b26138e298ecff0fb0b972175a93174d7d87d74cdc646af6a7932dc435d86c8"
|
||||
dependencies = [
|
||||
"apollo-compiler",
|
||||
"derive_more",
|
||||
|
|
@ -171,9 +171,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "apollo-router"
|
||||
version = "2.10.0"
|
||||
version = "2.9.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "db37c08af3b5844565e98567e92e68ae69ecd6bb4d3e674bb2a2f27d865d3680"
|
||||
checksum = "4d5b4da5efb2fb56a7076ded5912eb829005d1c3c2fec35359855f689db3fc57"
|
||||
dependencies = [
|
||||
"addr2line",
|
||||
"ahash",
|
||||
|
|
@ -292,6 +292,7 @@ dependencies = [
|
|||
"similar",
|
||||
"static_assertions",
|
||||
"strum",
|
||||
"strum_macros",
|
||||
"sys-info",
|
||||
"sysinfo",
|
||||
"thiserror 2.0.17",
|
||||
|
|
@ -436,6 +437,19 @@ dependencies = [
|
|||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-dropper-simple"
|
||||
version = "0.2.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a7c4748dfe8cd3d625ec68fc424fa80c134319881185866f9e173af9e5d8add8"
|
||||
dependencies = [
|
||||
"async-scoped",
|
||||
"async-trait",
|
||||
"futures",
|
||||
"rustc_version",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-executor"
|
||||
version = "1.13.3"
|
||||
|
|
@ -521,6 +535,17 @@ dependencies = [
|
|||
"rustix",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-scoped"
|
||||
version = "0.9.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4042078ea593edffc452eef14e99fdb2b120caa4ad9618bcdeabc4a023b98740"
|
||||
dependencies = [
|
||||
"futures",
|
||||
"pin-project",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-signal"
|
||||
version = "0.2.13"
|
||||
|
|
@ -2596,6 +2621,7 @@ dependencies = [
|
|||
"serde_json",
|
||||
"sha2",
|
||||
"tokio",
|
||||
"tokio-util",
|
||||
"tower 0.5.2",
|
||||
"tracing",
|
||||
]
|
||||
|
|
@ -2605,17 +2631,24 @@ name = "hive-console-sdk"
|
|||
version = "0.2.3"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"async-dropper-simple",
|
||||
"async-trait",
|
||||
"axum-core 0.5.5",
|
||||
"futures-util",
|
||||
"graphql-parser",
|
||||
"graphql-tools",
|
||||
"lazy_static",
|
||||
"md5",
|
||||
"mockito",
|
||||
"moka",
|
||||
"once_cell",
|
||||
"recloser",
|
||||
"regex-automata",
|
||||
"regress",
|
||||
"reqwest",
|
||||
"reqwest-middleware",
|
||||
"reqwest-retry",
|
||||
"retry-policies",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"sha2",
|
||||
|
|
@ -3531,9 +3564,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "mockall"
|
||||
version = "0.14.0"
|
||||
version = "0.13.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f58d964098a5f9c6b63d0798e5372fd04708193510a7af313c22e9f29b7b620b"
|
||||
checksum = "39a6bfcc6c8c7eed5ee98b9c3e33adc726054389233e201c95dab2d41a3839d2"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"downcast",
|
||||
|
|
@ -3545,9 +3578,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "mockall_derive"
|
||||
version = "0.14.0"
|
||||
version = "0.13.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ca41ce716dda6a9be188b385aa78ee5260fc25cd3802cb2a8afdc6afbe6b6dbf"
|
||||
checksum = "25ca3004c2efe9011bd4e461bd8256445052b9615405b4f7ea43fc8ca5c20898"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"proc-macro2",
|
||||
|
|
@ -4589,6 +4622,16 @@ dependencies = [
|
|||
"getrandom 0.3.4",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "recloser"
|
||||
version = "1.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "40ac0d06281c3556fea72cef9e5372d9ac172335be0d71c3b4f3db900483e0eb"
|
||||
dependencies = [
|
||||
"crossbeam-epoch",
|
||||
"pin-project",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "redis-protocol"
|
||||
version = "6.0.0"
|
||||
|
|
@ -5533,9 +5576,6 @@ name = "strum"
|
|||
version = "0.27.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "af23d6f6c1a224baef9d3f61e287d2761385a5b88fdab4eb4c6f11aeb54c4bcf"
|
||||
dependencies = [
|
||||
"strum_macros",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "strum_macros"
|
||||
|
|
|
|||
|
|
@ -2,14 +2,13 @@ import { existsSync, rmSync, writeFileSync } from 'node:fs';
|
|||
import { createServer } from 'node:http';
|
||||
import { tmpdir } from 'node:os';
|
||||
import { join } from 'node:path';
|
||||
import { MaybePromise } from 'slonik/dist/src/types';
|
||||
import { ProjectType } from 'testkit/gql/graphql';
|
||||
import { initSeed } from 'testkit/seed';
|
||||
import { getServiceHost } from 'testkit/utils';
|
||||
import { execa } from '@esm2cjs/execa';
|
||||
|
||||
describe('Apollo Router Integration', () => {
|
||||
const getBaseEndpoint = () =>
|
||||
getServiceHost('server', 8082).then(v => `http://${v}/artifacts/v1/`);
|
||||
const getAvailablePort = () =>
|
||||
new Promise<number>(resolve => {
|
||||
const server = createServer();
|
||||
|
|
@ -18,12 +17,19 @@ describe('Apollo Router Integration', () => {
|
|||
if (address && typeof address === 'object') {
|
||||
const port = address.port;
|
||||
server.close(() => resolve(port));
|
||||
} else {
|
||||
throw new Error('Could not get available port');
|
||||
}
|
||||
});
|
||||
});
|
||||
function defer(deferFn: () => MaybePromise<void>) {
|
||||
return {
|
||||
async [Symbol.asyncDispose]() {
|
||||
return deferFn();
|
||||
},
|
||||
};
|
||||
}
|
||||
it('fetches the supergraph and sends usage reports', async () => {
|
||||
const routerConfigPath = join(tmpdir(), `apollo-router-config-${Date.now()}.yaml`);
|
||||
const endpointBaseUrl = await getBaseEndpoint();
|
||||
const { createOrg } = await initSeed().createOwner();
|
||||
const { createProject } = await createOrg();
|
||||
const { createTargetAccessToken, createCdnAccess, target, waitForOperationsCollected } =
|
||||
|
|
@ -50,6 +56,7 @@ describe('Apollo Router Integration', () => {
|
|||
.then(r => r.expectNoGraphQLErrors());
|
||||
|
||||
expect(publishSchemaResult.schemaPublish.__typename).toBe('SchemaPublishSuccess');
|
||||
|
||||
const cdnAccessResult = await createCdnAccess();
|
||||
|
||||
const usageAddress = await getServiceHost('usage', 8081);
|
||||
|
|
@ -60,6 +67,7 @@ describe('Apollo Router Integration', () => {
|
|||
`Apollo Router binary not found at path: ${routerBinPath}, make sure to build it first with 'cargo build'`,
|
||||
);
|
||||
}
|
||||
|
||||
const routerPort = await getAvailablePort();
|
||||
const routerConfigContent = `
|
||||
supergraph:
|
||||
|
|
@ -67,17 +75,22 @@ supergraph:
|
|||
plugins:
|
||||
hive.usage: {}
|
||||
`.trim();
|
||||
const routerConfigPath = join(tmpdir(), `apollo-router-config-${Date.now()}.yaml`);
|
||||
writeFileSync(routerConfigPath, routerConfigContent, 'utf-8');
|
||||
|
||||
const cdnEndpoint = await getServiceHost('server', 8082).then(
|
||||
v => `http://${v}/artifacts/v1/${target.id}`,
|
||||
);
|
||||
const routerProc = execa(routerBinPath, ['--dev', '--config', routerConfigPath], {
|
||||
all: true,
|
||||
env: {
|
||||
HIVE_CDN_ENDPOINT: endpointBaseUrl + target.id,
|
||||
HIVE_CDN_ENDPOINT: cdnEndpoint,
|
||||
HIVE_CDN_KEY: cdnAccessResult.secretAccessToken,
|
||||
HIVE_ENDPOINT: `http://${usageAddress}`,
|
||||
HIVE_TOKEN: writeToken.secret,
|
||||
HIVE_TARGET_ID: target.id,
|
||||
},
|
||||
});
|
||||
let log = '';
|
||||
await new Promise((resolve, reject) => {
|
||||
routerProc.catch(err => {
|
||||
if (!err.isCanceled) {
|
||||
|
|
@ -88,17 +101,22 @@ plugins:
|
|||
if (!routerProcOut) {
|
||||
return reject(new Error('No stdout from Apollo Router process'));
|
||||
}
|
||||
let log = '';
|
||||
routerProcOut.on('data', data => {
|
||||
log += data.toString();
|
||||
if (log.includes('GraphQL endpoint exposed at')) {
|
||||
resolve(true);
|
||||
}
|
||||
process.stdout.write(log);
|
||||
});
|
||||
});
|
||||
|
||||
try {
|
||||
const url = `http://localhost:${routerPort}/`;
|
||||
await using _ = defer(() => {
|
||||
rmSync(routerConfigPath);
|
||||
routerProc.cancel();
|
||||
});
|
||||
|
||||
const url = `http://localhost:${routerPort}/`;
|
||||
async function sendOperation(i: number) {
|
||||
const response = await fetch(url, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
|
|
@ -107,13 +125,13 @@ plugins:
|
|||
},
|
||||
body: JSON.stringify({
|
||||
query: `
|
||||
query TestQuery {
|
||||
me {
|
||||
id
|
||||
name
|
||||
}
|
||||
}
|
||||
`,
|
||||
query Query${i} {
|
||||
me {
|
||||
id
|
||||
name
|
||||
}
|
||||
}
|
||||
`,
|
||||
}),
|
||||
});
|
||||
|
||||
|
|
@ -127,10 +145,16 @@ plugins:
|
|||
},
|
||||
},
|
||||
});
|
||||
await waitForOperationsCollected(1);
|
||||
} finally {
|
||||
routerProc.cancel();
|
||||
rmSync(routerConfigPath);
|
||||
}
|
||||
const cnt = 1000;
|
||||
const jobs = [];
|
||||
for (let i = 0; i < cnt; i++) {
|
||||
if (i % 100 === 0) {
|
||||
await new Promise(res => setTimeout(res, 500));
|
||||
}
|
||||
jobs.push(sendOperation(i));
|
||||
}
|
||||
await Promise.all(jobs);
|
||||
await waitForOperationsCollected(cnt);
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -35,6 +35,7 @@ http = "1"
|
|||
http-body-util = "0.1"
|
||||
graphql-parser = "0.4.1"
|
||||
rand = "0.9.0"
|
||||
tokio-util = "0.7.16"
|
||||
|
||||
[dev-dependencies]
|
||||
httpmock = "0.7.0"
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ fn main() {
|
|||
register_plugins();
|
||||
|
||||
// Initialize the Hive Registry and start the Apollo Router
|
||||
// TODO: Look at builder pattern in Executable::builder().start()
|
||||
match HiveRegistry::new(None).and(apollo_router::main()) {
|
||||
Ok(_) => {}
|
||||
Err(e) => {
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ pub static PERSISTED_DOCUMENT_HASH_KEY: &str = "hive::persisted_document_hash";
|
|||
pub struct Config {
|
||||
pub enabled: Option<bool>,
|
||||
/// GraphQL Hive persisted documents CDN endpoint URL.
|
||||
pub endpoint: Option<String>,
|
||||
pub endpoint: Option<EndpointConfig>,
|
||||
/// GraphQL Hive persisted documents CDN access token.
|
||||
pub key: Option<String>,
|
||||
/// Whether arbitrary documents should be allowed along-side persisted documents.
|
||||
|
|
@ -57,6 +57,25 @@ pub struct Config {
|
|||
pub cache_size: Option<u64>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, JsonSchema)]
|
||||
#[serde(untagged)]
|
||||
pub enum EndpointConfig {
|
||||
Single(String),
|
||||
Multiple(Vec<String>),
|
||||
}
|
||||
|
||||
impl From<&str> for EndpointConfig {
|
||||
fn from(value: &str) -> Self {
|
||||
EndpointConfig::Single(value.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&[&str]> for EndpointConfig {
|
||||
fn from(value: &[&str]) -> Self {
|
||||
EndpointConfig::Multiple(value.iter().map(|s| s.to_string()).collect())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PersistedDocumentsPlugin {
|
||||
persisted_documents_manager: Option<Arc<PersistedDocumentsManager>>,
|
||||
allow_arbitrary_documents: bool,
|
||||
|
|
@ -72,11 +91,14 @@ impl PersistedDocumentsPlugin {
|
|||
allow_arbitrary_documents,
|
||||
});
|
||||
}
|
||||
let endpoint = match &config.endpoint {
|
||||
Some(ep) => ep.clone(),
|
||||
let endpoints = match &config.endpoint {
|
||||
Some(ep) => match ep {
|
||||
EndpointConfig::Single(url) => vec![url.clone()],
|
||||
EndpointConfig::Multiple(urls) => urls.clone(),
|
||||
},
|
||||
None => {
|
||||
if let Ok(ep) = env::var("HIVE_CDN_ENDPOINT") {
|
||||
ep
|
||||
vec![ep]
|
||||
} else {
|
||||
return Err(
|
||||
"Endpoint for persisted documents CDN is not configured. Please set it via the plugin configuration or HIVE_CDN_ENDPOINT environment variable."
|
||||
|
|
@ -100,17 +122,41 @@ impl PersistedDocumentsPlugin {
|
|||
}
|
||||
};
|
||||
|
||||
let mut persisted_documents_manager = PersistedDocumentsManager::builder()
|
||||
.key(key)
|
||||
.user_agent(format!("hive-apollo-router/{}", PLUGIN_VERSION));
|
||||
|
||||
for endpoint in endpoints {
|
||||
persisted_documents_manager = persisted_documents_manager.add_endpoint(endpoint);
|
||||
}
|
||||
|
||||
if let Some(connect_timeout) = config.connect_timeout {
|
||||
persisted_documents_manager =
|
||||
persisted_documents_manager.connect_timeout(Duration::from_secs(connect_timeout));
|
||||
}
|
||||
|
||||
if let Some(request_timeout) = config.request_timeout {
|
||||
persisted_documents_manager =
|
||||
persisted_documents_manager.request_timeout(Duration::from_secs(request_timeout));
|
||||
}
|
||||
|
||||
if let Some(retry_count) = config.retry_count {
|
||||
persisted_documents_manager = persisted_documents_manager.max_retries(retry_count);
|
||||
}
|
||||
|
||||
if let Some(accept_invalid_certs) = config.accept_invalid_certs {
|
||||
persisted_documents_manager =
|
||||
persisted_documents_manager.accept_invalid_certs(accept_invalid_certs);
|
||||
}
|
||||
|
||||
if let Some(cache_size) = config.cache_size {
|
||||
persisted_documents_manager = persisted_documents_manager.cache_size(cache_size);
|
||||
}
|
||||
|
||||
let persisted_documents_manager = persisted_documents_manager.build()?;
|
||||
|
||||
Ok(PersistedDocumentsPlugin {
|
||||
persisted_documents_manager: Some(Arc::new(PersistedDocumentsManager::new(
|
||||
key,
|
||||
endpoint,
|
||||
config.accept_invalid_certs.unwrap_or(false),
|
||||
Duration::from_secs(config.connect_timeout.unwrap_or(5)),
|
||||
Duration::from_secs(config.request_timeout.unwrap_or(15)),
|
||||
config.retry_count.unwrap_or(3),
|
||||
config.cache_size.unwrap_or(1000),
|
||||
format!("hive-apollo-router/{}", PLUGIN_VERSION),
|
||||
))),
|
||||
persisted_documents_manager: Some(Arc::new(persisted_documents_manager)),
|
||||
allow_arbitrary_documents,
|
||||
})
|
||||
}
|
||||
|
|
@ -344,8 +390,8 @@ mod hive_persisted_documents_tests {
|
|||
Self { server }
|
||||
}
|
||||
|
||||
fn endpoint(&self) -> String {
|
||||
self.server.url("")
|
||||
fn endpoint(&self) -> EndpointConfig {
|
||||
EndpointConfig::Single(self.server.url(""))
|
||||
}
|
||||
|
||||
/// Registers a valid artifact URL with an actual GraphQL document
|
||||
|
|
|
|||
|
|
@ -1,14 +1,13 @@
|
|||
use crate::consts::PLUGIN_VERSION;
|
||||
use crate::registry_logger::Logger;
|
||||
use anyhow::{anyhow, Result};
|
||||
use hive_console_sdk::supergraph_fetcher::sync::SupergraphFetcherSyncState;
|
||||
use hive_console_sdk::supergraph_fetcher::SupergraphFetcher;
|
||||
use hive_console_sdk::supergraph_fetcher::SupergraphFetcherSyncState;
|
||||
use sha2::Digest;
|
||||
use sha2::Sha256;
|
||||
use std::env;
|
||||
use std::io::Write;
|
||||
use std::thread;
|
||||
use std::time::Duration;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct HiveRegistry {
|
||||
|
|
@ -18,7 +17,7 @@ pub struct HiveRegistry {
|
|||
}
|
||||
|
||||
pub struct HiveRegistryConfig {
|
||||
endpoint: Option<String>,
|
||||
endpoints: Vec<String>,
|
||||
key: Option<String>,
|
||||
poll_interval: Option<u64>,
|
||||
accept_invalid_certs: Option<bool>,
|
||||
|
|
@ -29,7 +28,7 @@ impl HiveRegistry {
|
|||
#[allow(clippy::new_ret_no_self)]
|
||||
pub fn new(user_config: Option<HiveRegistryConfig>) -> Result<()> {
|
||||
let mut config = HiveRegistryConfig {
|
||||
endpoint: None,
|
||||
endpoints: vec![],
|
||||
key: None,
|
||||
poll_interval: None,
|
||||
accept_invalid_certs: Some(true),
|
||||
|
|
@ -38,7 +37,7 @@ impl HiveRegistry {
|
|||
|
||||
// Pass values from user's config
|
||||
if let Some(user_config) = user_config {
|
||||
config.endpoint = user_config.endpoint;
|
||||
config.endpoints = user_config.endpoints;
|
||||
config.key = user_config.key;
|
||||
config.poll_interval = user_config.poll_interval;
|
||||
config.accept_invalid_certs = user_config.accept_invalid_certs;
|
||||
|
|
@ -47,9 +46,9 @@ impl HiveRegistry {
|
|||
|
||||
// Pass values from environment variables if they are not set in the user's config
|
||||
|
||||
if config.endpoint.is_none() {
|
||||
if config.endpoints.is_empty() {
|
||||
if let Ok(endpoint) = env::var("HIVE_CDN_ENDPOINT") {
|
||||
config.endpoint = Some(endpoint);
|
||||
config.endpoints.push(endpoint);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -86,7 +85,7 @@ impl HiveRegistry {
|
|||
}
|
||||
|
||||
// Resolve values
|
||||
let endpoint = config.endpoint.unwrap_or_default();
|
||||
let endpoint = config.endpoints;
|
||||
let key = config.key.unwrap_or_default();
|
||||
let poll_interval: u64 = config.poll_interval.unwrap_or(10);
|
||||
let accept_invalid_certs = config.accept_invalid_certs.unwrap_or(false);
|
||||
|
|
@ -120,19 +119,23 @@ impl HiveRegistry {
|
|||
.to_string_lossy()
|
||||
.to_string(),
|
||||
);
|
||||
env::set_var("APOLLO_ROUTER_SUPERGRAPH_PATH", file_name.clone());
|
||||
env::set_var("APOLLO_ROUTER_HOT_RELOAD", "true");
|
||||
unsafe {
|
||||
env::set_var("APOLLO_ROUTER_SUPERGRAPH_PATH", file_name.clone());
|
||||
env::set_var("APOLLO_ROUTER_HOT_RELOAD", "true");
|
||||
}
|
||||
|
||||
let fetcher = SupergraphFetcher::try_new_sync(
|
||||
endpoint,
|
||||
&key,
|
||||
format!("hive-apollo-router/{}", PLUGIN_VERSION),
|
||||
Duration::from_secs(5),
|
||||
Duration::from_secs(60),
|
||||
accept_invalid_certs,
|
||||
3,
|
||||
)
|
||||
.map_err(|e| anyhow!("Failed to create SupergraphFetcher: {}", e))?;
|
||||
let mut fetcher = SupergraphFetcher::builder()
|
||||
.key(key)
|
||||
.user_agent(format!("hive-apollo-router/{}", PLUGIN_VERSION))
|
||||
.accept_invalid_certs(accept_invalid_certs);
|
||||
|
||||
for ep in endpoint {
|
||||
fetcher = fetcher.add_endpoint(ep);
|
||||
}
|
||||
|
||||
let fetcher = fetcher
|
||||
.build_sync()
|
||||
.map_err(|e| anyhow!("Failed to create SupergraphFetcher: {}", e))?;
|
||||
|
||||
let registry = HiveRegistry {
|
||||
fetcher,
|
||||
|
|
|
|||
|
|
@ -8,8 +8,8 @@ use core::ops::Drop;
|
|||
use futures::StreamExt;
|
||||
use graphql_parser::parse_schema;
|
||||
use graphql_parser::schema::Document;
|
||||
use hive_console_sdk::agent::UsageAgentExt;
|
||||
use hive_console_sdk::agent::{ExecutionReport, UsageAgent};
|
||||
use hive_console_sdk::agent::usage_agent::UsageAgentExt;
|
||||
use hive_console_sdk::agent::usage_agent::{ExecutionReport, UsageAgent};
|
||||
use http::HeaderValue;
|
||||
use rand::Rng;
|
||||
use schemars::JsonSchema;
|
||||
|
|
@ -19,6 +19,7 @@ use std::env;
|
|||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tower::BoxError;
|
||||
use tower::ServiceBuilder;
|
||||
use tower::ServiceExt;
|
||||
|
|
@ -47,11 +48,12 @@ struct OperationConfig {
|
|||
|
||||
pub struct UsagePlugin {
|
||||
config: OperationConfig,
|
||||
agent: Option<Arc<UsageAgent>>,
|
||||
agent: Option<UsageAgent>,
|
||||
schema: Arc<Document<'static, String>>,
|
||||
cancellation_token: Arc<CancellationToken>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, JsonSchema)]
|
||||
#[derive(Clone, Debug, Deserialize, JsonSchema, Default)]
|
||||
pub struct Config {
|
||||
/// Default: true
|
||||
enabled: Option<bool>,
|
||||
|
|
@ -95,26 +97,6 @@ pub struct Config {
|
|||
flush_interval: Option<u64>,
|
||||
}
|
||||
|
||||
impl Default for Config {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: Some(true),
|
||||
registry_token: None,
|
||||
registry_usage_endpoint: Some(DEFAULT_HIVE_USAGE_ENDPOINT.into()),
|
||||
sample_rate: Some(1.0),
|
||||
exclude: None,
|
||||
client_name_header: Some(String::from("graphql-client-name")),
|
||||
client_version_header: Some(String::from("graphql-client-version")),
|
||||
accept_invalid_certs: Some(false),
|
||||
buffer_size: Some(1000),
|
||||
connect_timeout: Some(5),
|
||||
request_timeout: Some(15),
|
||||
flush_interval: Some(5),
|
||||
target: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl UsagePlugin {
|
||||
fn populate_context(config: OperationConfig, req: &supergraph::Request) {
|
||||
let context = &req.context;
|
||||
|
|
@ -179,108 +161,95 @@ impl UsagePlugin {
|
|||
}
|
||||
}
|
||||
|
||||
static DEFAULT_HIVE_USAGE_ENDPOINT: &str = "https://app.graphql-hive.com/usage";
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Plugin for UsagePlugin {
|
||||
type Config = Config;
|
||||
|
||||
async fn new(init: PluginInit<Config>) -> Result<Self, BoxError> {
|
||||
let token = init
|
||||
.config
|
||||
.registry_token
|
||||
.clone()
|
||||
.or_else(|| env::var("HIVE_TOKEN").ok());
|
||||
|
||||
if token.is_none() {
|
||||
return Err("Hive token is required".into());
|
||||
}
|
||||
|
||||
let endpoint = init
|
||||
.config
|
||||
.registry_usage_endpoint
|
||||
.clone()
|
||||
.unwrap_or_else(|| {
|
||||
env::var("HIVE_ENDPOINT").unwrap_or(DEFAULT_HIVE_USAGE_ENDPOINT.to_string())
|
||||
});
|
||||
|
||||
let target_id = init
|
||||
.config
|
||||
.target
|
||||
.clone()
|
||||
.or_else(|| env::var("HIVE_TARGET_ID").ok());
|
||||
|
||||
let default_config = Config::default();
|
||||
let user_config = init.config;
|
||||
let enabled = user_config
|
||||
.enabled
|
||||
.or(default_config.enabled)
|
||||
.expect("enabled has default value");
|
||||
let buffer_size = user_config
|
||||
.buffer_size
|
||||
.or(default_config.buffer_size)
|
||||
.expect("buffer_size has no default value");
|
||||
let accept_invalid_certs = user_config
|
||||
.accept_invalid_certs
|
||||
.or(default_config.accept_invalid_certs)
|
||||
.expect("accept_invalid_certs has no default value");
|
||||
let connect_timeout = user_config
|
||||
.connect_timeout
|
||||
.or(default_config.connect_timeout)
|
||||
.expect("connect_timeout has no default value");
|
||||
let request_timeout = user_config
|
||||
.request_timeout
|
||||
.or(default_config.request_timeout)
|
||||
.expect("request_timeout has no default value");
|
||||
let flush_interval = user_config
|
||||
.flush_interval
|
||||
.or(default_config.flush_interval)
|
||||
.expect("request_timeout has no default value");
|
||||
|
||||
let enabled = user_config.enabled.unwrap_or(true);
|
||||
|
||||
if enabled {
|
||||
tracing::info!("Starting GraphQL Hive Usage plugin");
|
||||
}
|
||||
let schema = parse_schema(&init.supergraph_sdl)
|
||||
.expect("Failed to parse schema")
|
||||
.into_static();
|
||||
|
||||
let cancellation_token = Arc::new(CancellationToken::new());
|
||||
|
||||
let agent = if enabled {
|
||||
let flush_interval = Duration::from_secs(flush_interval);
|
||||
let agent = UsageAgent::try_new(
|
||||
&token.expect("token is set"),
|
||||
endpoint,
|
||||
target_id,
|
||||
buffer_size,
|
||||
Duration::from_secs(connect_timeout),
|
||||
Duration::from_secs(request_timeout),
|
||||
accept_invalid_certs,
|
||||
flush_interval,
|
||||
format!("hive-apollo-router/{}", PLUGIN_VERSION),
|
||||
)
|
||||
.map_err(Box::new)?;
|
||||
start_flush_interval(agent.clone());
|
||||
let mut agent =
|
||||
UsageAgent::builder().user_agent(format!("hive-apollo-router/{}", PLUGIN_VERSION));
|
||||
|
||||
if let Some(endpoint) = user_config.registry_usage_endpoint {
|
||||
agent = agent.endpoint(endpoint);
|
||||
} else if let Ok(env_endpoint) = env::var("HIVE_ENDPOINT") {
|
||||
agent = agent.endpoint(env_endpoint);
|
||||
}
|
||||
|
||||
if let Some(token) = user_config.registry_token {
|
||||
agent = agent.token(token);
|
||||
} else if let Ok(env_token) = env::var("HIVE_TOKEN") {
|
||||
agent = agent.token(env_token);
|
||||
}
|
||||
|
||||
if let Some(target_id) = user_config.target {
|
||||
agent = agent.target_id(target_id);
|
||||
} else if let Ok(env_target) = env::var("HIVE_TARGET_ID") {
|
||||
agent = agent.target_id(env_target);
|
||||
}
|
||||
|
||||
if let Some(buffer_size) = user_config.buffer_size {
|
||||
agent = agent.buffer_size(buffer_size);
|
||||
}
|
||||
|
||||
if let Some(connect_timeout) = user_config.connect_timeout {
|
||||
agent = agent.connect_timeout(Duration::from_secs(connect_timeout));
|
||||
}
|
||||
|
||||
if let Some(request_timeout) = user_config.request_timeout {
|
||||
agent = agent.request_timeout(Duration::from_secs(request_timeout));
|
||||
}
|
||||
|
||||
if let Some(accept_invalid_certs) = user_config.accept_invalid_certs {
|
||||
agent = agent.accept_invalid_certs(accept_invalid_certs);
|
||||
}
|
||||
|
||||
if let Some(flush_interval) = user_config.flush_interval {
|
||||
agent = agent.flush_interval(Duration::from_secs(flush_interval));
|
||||
}
|
||||
|
||||
let agent = agent.build().map_err(Box::new)?;
|
||||
|
||||
let cancellation_token_for_interval = cancellation_token.clone();
|
||||
let agent_for_interval = agent.clone();
|
||||
tokio::task::spawn(async move {
|
||||
agent_for_interval
|
||||
.start_flush_interval(&cancellation_token_for_interval)
|
||||
.await;
|
||||
});
|
||||
Some(agent)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let schema = parse_schema(&init.supergraph_sdl)
|
||||
.expect("Failed to parse schema")
|
||||
.into_static();
|
||||
|
||||
Ok(UsagePlugin {
|
||||
schema: Arc::new(schema),
|
||||
config: OperationConfig {
|
||||
sample_rate: user_config
|
||||
.sample_rate
|
||||
.or(default_config.sample_rate)
|
||||
.expect("sample_rate has no default value"),
|
||||
exclude: user_config.exclude.or(default_config.exclude),
|
||||
sample_rate: user_config.sample_rate.unwrap_or(1.0),
|
||||
exclude: user_config.exclude,
|
||||
client_name_header: user_config
|
||||
.client_name_header
|
||||
.or(default_config.client_name_header)
|
||||
.expect("client_name_header has no default value"),
|
||||
.unwrap_or("graphql-client-name".to_string()),
|
||||
client_version_header: user_config
|
||||
.client_version_header
|
||||
.or(default_config.client_version_header)
|
||||
.expect("client_version_header has no default value"),
|
||||
.unwrap_or("graphql-client-version".to_string()),
|
||||
},
|
||||
agent,
|
||||
cancellation_token,
|
||||
})
|
||||
}
|
||||
|
||||
|
|
@ -342,58 +311,61 @@ impl Plugin for UsagePlugin {
|
|||
|
||||
match result {
|
||||
Err(e) => {
|
||||
agent
|
||||
.add_report(ExecutionReport {
|
||||
schema,
|
||||
client_name,
|
||||
client_version,
|
||||
timestamp,
|
||||
duration,
|
||||
ok: false,
|
||||
errors: 1,
|
||||
operation_body,
|
||||
operation_name,
|
||||
persisted_document_hash,
|
||||
})
|
||||
.unwrap_or_else(|e| {
|
||||
tokio::spawn(async move {
|
||||
let res = agent
|
||||
.add_report(ExecutionReport {
|
||||
schema,
|
||||
client_name,
|
||||
client_version,
|
||||
timestamp,
|
||||
duration,
|
||||
ok: false,
|
||||
errors: 1,
|
||||
operation_body,
|
||||
operation_name,
|
||||
persisted_document_hash,
|
||||
})
|
||||
.await;
|
||||
if let Err(e) = res {
|
||||
tracing::error!("Error adding report: {}", e);
|
||||
});
|
||||
}
|
||||
});
|
||||
Err(e)
|
||||
}
|
||||
Ok(router_response) => {
|
||||
let is_failure =
|
||||
!router_response.response.status().is_success();
|
||||
Ok(router_response.map(move |response_stream| {
|
||||
let client_name = client_name.clone();
|
||||
let client_version = client_version.clone();
|
||||
let operation_body = operation_body.clone();
|
||||
let operation_name = operation_name.clone();
|
||||
|
||||
let res = response_stream
|
||||
.map(move |response| {
|
||||
// make sure we send a single report, not for each chunk
|
||||
let response_has_errors =
|
||||
!response.errors.is_empty();
|
||||
agent
|
||||
.add_report(ExecutionReport {
|
||||
schema: schema.clone(),
|
||||
client_name: client_name.clone(),
|
||||
client_version: client_version.clone(),
|
||||
timestamp,
|
||||
duration,
|
||||
ok: !is_failure && !response_has_errors,
|
||||
errors: response.errors.len(),
|
||||
operation_body: operation_body.clone(),
|
||||
operation_name: operation_name.clone(),
|
||||
persisted_document_hash:
|
||||
persisted_document_hash.clone(),
|
||||
})
|
||||
.unwrap_or_else(|e| {
|
||||
let agent = agent.clone();
|
||||
let execution_report = ExecutionReport {
|
||||
schema: schema.clone(),
|
||||
client_name: client_name.clone(),
|
||||
client_version: client_version.clone(),
|
||||
timestamp,
|
||||
duration,
|
||||
ok: !is_failure && !response_has_errors,
|
||||
errors: response.errors.len(),
|
||||
operation_body: operation_body.clone(),
|
||||
operation_name: operation_name.clone(),
|
||||
persisted_document_hash:
|
||||
persisted_document_hash.clone(),
|
||||
};
|
||||
tokio::spawn(async move {
|
||||
let res = agent
|
||||
.add_report(execution_report)
|
||||
.await;
|
||||
if let Err(e) = res {
|
||||
tracing::error!(
|
||||
"Error adding report: {}",
|
||||
e
|
||||
);
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
response
|
||||
})
|
||||
|
|
@ -415,17 +387,11 @@ impl Plugin for UsagePlugin {
|
|||
|
||||
impl Drop for UsagePlugin {
|
||||
fn drop(&mut self) {
|
||||
tracing::debug!("UsagePlugin has been dropped!");
|
||||
// TODO: flush the buffer
|
||||
self.cancellation_token.cancel();
|
||||
// Flush already done by UsageAgent's Drop impl
|
||||
}
|
||||
}
|
||||
|
||||
pub fn start_flush_interval(agent_for_interval: Arc<UsageAgent>) {
|
||||
tokio::task::spawn(async move {
|
||||
agent_for_interval.start_flush_interval(None).await;
|
||||
});
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod hive_usage_tests {
|
||||
use apollo_router::{
|
||||
|
|
@ -479,7 +445,7 @@ mod hive_usage_tests {
|
|||
}
|
||||
|
||||
fn wait_for_processing(&self) -> tokio::time::Sleep {
|
||||
tokio::time::sleep(tokio::time::Duration::from_secs(1))
|
||||
tokio::time::sleep(tokio::time::Duration::from_secs(2))
|
||||
}
|
||||
|
||||
fn activate_usage_mock(&'_ self) -> Mock<'_> {
|
||||
|
|
@ -584,6 +550,7 @@ mod hive_usage_tests {
|
|||
instance.execute_operation(req).await.next_response().await;
|
||||
|
||||
instance.wait_for_processing().await;
|
||||
println!("Waiting done");
|
||||
|
||||
mock.assert();
|
||||
mock.assert_hits(1);
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ reqwest = { version = "0.12.24", default-features = false, features = [
|
|||
"blocking",
|
||||
] }
|
||||
reqwest-retry = "0.8.0"
|
||||
reqwest-middleware = "0.4.2"
|
||||
reqwest-middleware = { version = "0.4.2", features = ["json"]}
|
||||
anyhow = "1"
|
||||
tracing = "0.1"
|
||||
serde = "1"
|
||||
|
|
@ -32,8 +32,15 @@ serde_json = "1"
|
|||
moka = { version = "0.12.10", features = ["future", "sync"] }
|
||||
sha2 = { version = "0.10.8", features = ["std"] }
|
||||
tokio-util = "0.7.16"
|
||||
regex-automata = "0.4.10"
|
||||
once_cell = "1.21.3"
|
||||
retry-policies = "0.5.0"
|
||||
recloser = "1.3.1"
|
||||
futures-util = "0.3.31"
|
||||
typify = "0.5.0"
|
||||
regress = "0.10.5"
|
||||
lazy_static = "1.5.0"
|
||||
async-dropper-simple = { version = "0.2.6", features = ["tokio", "no-default-bound"] }
|
||||
|
||||
[dev-dependencies]
|
||||
mockito = "1.7.0"
|
||||
|
|
|
|||
39
packages/libraries/sdk-rs/src/agent/buffer.rs
Normal file
39
packages/libraries/sdk-rs/src/agent/buffer.rs
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
use std::collections::VecDeque;
|
||||
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
pub struct Buffer<T> {
|
||||
max_size: usize,
|
||||
queue: Mutex<VecDeque<T>>,
|
||||
}
|
||||
|
||||
pub enum AddStatus<T> {
|
||||
Full { drained: Vec<T> },
|
||||
Ok,
|
||||
}
|
||||
|
||||
impl<T> Buffer<T> {
|
||||
pub fn new(max_size: usize) -> Self {
|
||||
Self {
|
||||
queue: Mutex::new(VecDeque::with_capacity(max_size)),
|
||||
max_size,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn add(&self, item: T) -> AddStatus<T> {
|
||||
let mut queue = self.queue.lock().await;
|
||||
if queue.len() >= self.max_size {
|
||||
let mut drained: Vec<T> = queue.drain(..).collect();
|
||||
drained.push(item);
|
||||
AddStatus::Full { drained }
|
||||
} else {
|
||||
queue.push_back(item);
|
||||
AddStatus::Ok
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn drain(&self) -> Vec<T> {
|
||||
let mut queue = self.queue.lock().await;
|
||||
queue.drain(..).collect()
|
||||
}
|
||||
}
|
||||
229
packages/libraries/sdk-rs/src/agent/builder.rs
Normal file
229
packages/libraries/sdk-rs/src/agent/builder.rs
Normal file
|
|
@ -0,0 +1,229 @@
|
|||
use std::{sync::Arc, time::Duration};
|
||||
|
||||
use async_dropper_simple::AsyncDropper;
|
||||
use once_cell::sync::Lazy;
|
||||
use recloser::AsyncRecloser;
|
||||
use reqwest::header::{HeaderMap, HeaderValue};
|
||||
use reqwest_middleware::ClientBuilder;
|
||||
use reqwest_retry::RetryTransientMiddleware;
|
||||
|
||||
use crate::agent::buffer::Buffer;
|
||||
use crate::agent::usage_agent::{non_empty_string, AgentError, UsageAgent, UsageAgentInner};
|
||||
use crate::agent::utils::OperationProcessor;
|
||||
use crate::circuit_breaker;
|
||||
use retry_policies::policies::ExponentialBackoff;
|
||||
|
||||
pub struct UsageAgentBuilder {
|
||||
token: Option<String>,
|
||||
endpoint: String,
|
||||
target_id: Option<String>,
|
||||
buffer_size: usize,
|
||||
connect_timeout: Duration,
|
||||
request_timeout: Duration,
|
||||
accept_invalid_certs: bool,
|
||||
flush_interval: Duration,
|
||||
retry_policy: ExponentialBackoff,
|
||||
user_agent: Option<String>,
|
||||
circuit_breaker: Option<AsyncRecloser>,
|
||||
}
|
||||
|
||||
pub static DEFAULT_HIVE_USAGE_ENDPOINT: &str = "https://app.graphql-hive.com/usage";
|
||||
|
||||
impl Default for UsageAgentBuilder {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
endpoint: DEFAULT_HIVE_USAGE_ENDPOINT.to_string(),
|
||||
token: None,
|
||||
target_id: None,
|
||||
buffer_size: 1000,
|
||||
connect_timeout: Duration::from_secs(5),
|
||||
request_timeout: Duration::from_secs(15),
|
||||
accept_invalid_certs: false,
|
||||
flush_interval: Duration::from_secs(5),
|
||||
retry_policy: ExponentialBackoff::builder().build_with_max_retries(3),
|
||||
user_agent: None,
|
||||
circuit_breaker: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn is_legacy_token(token: &str) -> bool {
|
||||
!token.starts_with("hvo1/") && !token.starts_with("hvu1/") && !token.starts_with("hvp1/")
|
||||
}
|
||||
|
||||
impl UsageAgentBuilder {
|
||||
/// Your [Registry Access Token](https://the-guild.dev/graphql/hive/docs/management/targets#registry-access-tokens) with write permission.
|
||||
pub fn token(mut self, token: String) -> Self {
|
||||
if let Some(token) = non_empty_string(Some(token)) {
|
||||
self.token = Some(token);
|
||||
}
|
||||
self
|
||||
}
|
||||
/// For self-hosting, you can override `/usage` endpoint (defaults to `https://app.graphql-hive.com/usage`).
|
||||
pub fn endpoint(mut self, endpoint: String) -> Self {
|
||||
if let Some(endpoint) = non_empty_string(Some(endpoint)) {
|
||||
self.endpoint = endpoint;
|
||||
}
|
||||
self
|
||||
}
|
||||
/// A target ID, this can either be a slug following the format “$organizationSlug/$projectSlug/$targetSlug” (e.g “the-guild/graphql-hive/staging”) or an UUID (e.g. “a0f4c605-6541-4350-8cfe-b31f21a4bf80”). To be used when the token is configured with an organization access token.
|
||||
pub fn target_id(mut self, target_id: String) -> Self {
|
||||
if let Some(target_id) = non_empty_string(Some(target_id)) {
|
||||
self.target_id = Some(target_id);
|
||||
}
|
||||
self
|
||||
}
|
||||
/// A maximum number of operations to hold in a buffer before sending to Hive Console
|
||||
/// Default: 1000
|
||||
pub fn buffer_size(mut self, buffer_size: usize) -> Self {
|
||||
self.buffer_size = buffer_size;
|
||||
self
|
||||
}
|
||||
/// A timeout for only the connect phase of a request to Hive Console
|
||||
/// Default: 5 seconds
|
||||
pub fn connect_timeout(mut self, connect_timeout: Duration) -> Self {
|
||||
self.connect_timeout = connect_timeout;
|
||||
self
|
||||
}
|
||||
/// A timeout for the entire request to Hive Console
|
||||
/// Default: 15 seconds
|
||||
pub fn request_timeout(mut self, request_timeout: Duration) -> Self {
|
||||
self.request_timeout = request_timeout;
|
||||
self
|
||||
}
|
||||
/// Accepts invalid SSL certificates
|
||||
/// Default: false
|
||||
pub fn accept_invalid_certs(mut self, accept_invalid_certs: bool) -> Self {
|
||||
self.accept_invalid_certs = accept_invalid_certs;
|
||||
self
|
||||
}
|
||||
/// Frequency of flushing the buffer to the server
|
||||
/// Default: 5 seconds
|
||||
pub fn flush_interval(mut self, flush_interval: Duration) -> Self {
|
||||
self.flush_interval = flush_interval;
|
||||
self
|
||||
}
|
||||
/// User-Agent header to be sent with each request
|
||||
pub fn user_agent(mut self, user_agent: String) -> Self {
|
||||
if let Some(user_agent) = non_empty_string(Some(user_agent)) {
|
||||
self.user_agent = Some(user_agent);
|
||||
}
|
||||
self
|
||||
}
|
||||
/// Retry policy for sending reports
|
||||
/// Default: ExponentialBackoff with max 3 retries
|
||||
pub fn retry_policy(mut self, retry_policy: ExponentialBackoff) -> Self {
|
||||
self.retry_policy = retry_policy;
|
||||
self
|
||||
}
|
||||
/// Maximum number of retries for sending reports
|
||||
/// Default: ExponentialBackoff with max 3 retries
|
||||
pub fn max_retries(mut self, max_retries: u32) -> Self {
|
||||
self.retry_policy = ExponentialBackoff::builder().build_with_max_retries(max_retries);
|
||||
self
|
||||
}
|
||||
pub(crate) fn build_agent(self) -> Result<UsageAgentInner, AgentError> {
|
||||
let mut default_headers = HeaderMap::new();
|
||||
|
||||
default_headers.insert("X-Usage-API-Version", HeaderValue::from_static("2"));
|
||||
|
||||
let token = match self.token {
|
||||
Some(token) => token,
|
||||
None => return Err(AgentError::MissingToken),
|
||||
};
|
||||
|
||||
let mut authorization_header = HeaderValue::from_str(&format!("Bearer {}", token))
|
||||
.map_err(|_| AgentError::InvalidToken)?;
|
||||
|
||||
authorization_header.set_sensitive(true);
|
||||
|
||||
default_headers.insert(reqwest::header::AUTHORIZATION, authorization_header);
|
||||
|
||||
default_headers.insert(
|
||||
reqwest::header::CONTENT_TYPE,
|
||||
HeaderValue::from_static("application/json"),
|
||||
);
|
||||
|
||||
let mut reqwest_agent = reqwest::Client::builder()
|
||||
.danger_accept_invalid_certs(self.accept_invalid_certs)
|
||||
.connect_timeout(self.connect_timeout)
|
||||
.timeout(self.request_timeout)
|
||||
.default_headers(default_headers);
|
||||
|
||||
if let Some(user_agent) = &self.user_agent {
|
||||
reqwest_agent = reqwest_agent.user_agent(user_agent);
|
||||
}
|
||||
|
||||
let reqwest_agent = reqwest_agent
|
||||
.build()
|
||||
.map_err(AgentError::HTTPClientCreationError)?;
|
||||
let client = ClientBuilder::new(reqwest_agent)
|
||||
.with(RetryTransientMiddleware::new_with_policy(self.retry_policy))
|
||||
.build();
|
||||
|
||||
let mut endpoint = self.endpoint;
|
||||
|
||||
match self.target_id {
|
||||
Some(_) if is_legacy_token(&token) => return Err(AgentError::TargetIdWithLegacyToken),
|
||||
Some(target_id) if !is_legacy_token(&token) => {
|
||||
let target_id = validate_target_id(&target_id)?;
|
||||
endpoint.push_str(&format!("/{}", target_id));
|
||||
}
|
||||
None if !is_legacy_token(&token) => return Err(AgentError::MissingTargetId),
|
||||
_ => {}
|
||||
}
|
||||
|
||||
let circuit_breaker = if let Some(cb) = self.circuit_breaker {
|
||||
cb
|
||||
} else {
|
||||
circuit_breaker::CircuitBreakerBuilder::default()
|
||||
.build_async()
|
||||
.map_err(AgentError::CircuitBreakerCreationError)?
|
||||
};
|
||||
|
||||
let buffer = Buffer::new(self.buffer_size);
|
||||
|
||||
Ok(UsageAgentInner {
|
||||
endpoint,
|
||||
buffer,
|
||||
processor: OperationProcessor::new(),
|
||||
client,
|
||||
flush_interval: self.flush_interval,
|
||||
circuit_breaker,
|
||||
})
|
||||
}
|
||||
pub fn build(self) -> Result<UsageAgent, AgentError> {
|
||||
let agent = self.build_agent()?;
|
||||
Ok(Arc::new(AsyncDropper::new(agent)))
|
||||
}
|
||||
}
|
||||
|
||||
// Target ID regexp for validation: slug format
|
||||
static SLUG_REGEX: Lazy<regex_automata::meta::Regex> = Lazy::new(|| {
|
||||
regex_automata::meta::Regex::new(r"^[a-zA-Z0-9-_]+\/[a-zA-Z0-9-_]+\/[a-zA-Z0-9-_]+$").unwrap()
|
||||
});
|
||||
// Target ID regexp for validation: UUID format
|
||||
static UUID_REGEX: Lazy<regex_automata::meta::Regex> = Lazy::new(|| {
|
||||
regex_automata::meta::Regex::new(
|
||||
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$",
|
||||
)
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
fn validate_target_id(target_id: &str) -> Result<&str, AgentError> {
|
||||
let trimmed_s = target_id.trim();
|
||||
if trimmed_s.is_empty() {
|
||||
Err(AgentError::InvalidTargetId("<empty>".to_string()))
|
||||
} else {
|
||||
if SLUG_REGEX.is_match(trimmed_s) {
|
||||
return Ok(trimmed_s);
|
||||
}
|
||||
if UUID_REGEX.is_match(trimmed_s) {
|
||||
return Ok(trimmed_s);
|
||||
}
|
||||
Err(AgentError::InvalidTargetId(format!(
|
||||
"Invalid target_id format: '{}'. It must be either in slug format '$organizationSlug/$projectSlug/$targetSlug' or UUID format 'a0f4c605-6541-4350-8cfe-b31f21a4bf80'",
|
||||
trimmed_s
|
||||
)))
|
||||
}
|
||||
}
|
||||
4
packages/libraries/sdk-rs/src/agent/mod.rs
Normal file
4
packages/libraries/sdk-rs/src/agent/mod.rs
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
pub mod buffer;
|
||||
pub mod builder;
|
||||
pub mod usage_agent;
|
||||
pub mod utils;
|
||||
|
|
@ -1,16 +1,18 @@
|
|||
use super::graphql::OperationProcessor;
|
||||
use async_dropper_simple::{AsyncDrop, AsyncDropper};
|
||||
use graphql_parser::schema::Document;
|
||||
use reqwest::header::{HeaderMap, HeaderValue};
|
||||
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
|
||||
use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
|
||||
use recloser::AsyncRecloser;
|
||||
use reqwest_middleware::ClientWithMiddleware;
|
||||
use std::{
|
||||
collections::{hash_map::Entry, HashMap, VecDeque},
|
||||
sync::{Arc, Mutex},
|
||||
collections::{hash_map::Entry, HashMap},
|
||||
sync::Arc,
|
||||
time::Duration,
|
||||
};
|
||||
use thiserror::Error;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
use crate::agent::{buffer::AddStatus, utils::OperationProcessor};
|
||||
use crate::agent::{buffer::Buffer, builder::UsageAgentBuilder};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ExecutionReport {
|
||||
pub schema: Arc<Document<'static, String>>,
|
||||
|
|
@ -27,44 +29,16 @@ pub struct ExecutionReport {
|
|||
|
||||
typify::import_types!(schema = "./usage-report-v2.schema.json");
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct Buffer(Mutex<VecDeque<ExecutionReport>>);
|
||||
|
||||
impl Buffer {
|
||||
fn new() -> Self {
|
||||
Self(Mutex::new(VecDeque::new()))
|
||||
}
|
||||
|
||||
fn lock_buffer(
|
||||
&self,
|
||||
) -> Result<std::sync::MutexGuard<'_, VecDeque<ExecutionReport>>, AgentError> {
|
||||
let buffer: Result<std::sync::MutexGuard<'_, VecDeque<ExecutionReport>>, AgentError> =
|
||||
self.0.lock().map_err(|e| AgentError::Lock(e.to_string()));
|
||||
buffer
|
||||
}
|
||||
|
||||
pub fn push(&self, report: ExecutionReport) -> Result<usize, AgentError> {
|
||||
let mut buffer = self.lock_buffer()?;
|
||||
buffer.push_back(report);
|
||||
Ok(buffer.len())
|
||||
}
|
||||
|
||||
pub fn drain(&self) -> Result<Vec<ExecutionReport>, AgentError> {
|
||||
let mut buffer = self.lock_buffer()?;
|
||||
let reports: Vec<ExecutionReport> = buffer.drain(..).collect();
|
||||
Ok(reports)
|
||||
}
|
||||
}
|
||||
pub struct UsageAgent {
|
||||
buffer_size: usize,
|
||||
endpoint: String,
|
||||
buffer: Buffer,
|
||||
processor: OperationProcessor,
|
||||
client: ClientWithMiddleware,
|
||||
flush_interval: Duration,
|
||||
pub struct UsageAgentInner {
|
||||
pub(crate) endpoint: String,
|
||||
pub(crate) buffer: Buffer<ExecutionReport>,
|
||||
pub(crate) processor: OperationProcessor,
|
||||
pub(crate) client: ClientWithMiddleware,
|
||||
pub(crate) flush_interval: Duration,
|
||||
pub(crate) circuit_breaker: AsyncRecloser,
|
||||
}
|
||||
|
||||
fn non_empty_string(value: Option<String>) -> Option<String> {
|
||||
pub fn non_empty_string(value: Option<String>) -> Option<String> {
|
||||
value.filter(|str| !str.is_empty())
|
||||
}
|
||||
|
||||
|
|
@ -72,81 +46,45 @@ fn non_empty_string(value: Option<String>) -> Option<String> {
|
|||
pub enum AgentError {
|
||||
#[error("unable to acquire lock: {0}")]
|
||||
Lock(String),
|
||||
#[error("unable to send report: token is missing")]
|
||||
#[error("unable to send report: unauthorized")]
|
||||
Unauthorized,
|
||||
#[error("unable to send report: no access")]
|
||||
Forbidden,
|
||||
#[error("unable to send report: rate limited")]
|
||||
RateLimited,
|
||||
#[error("invalid token provided: {0}")]
|
||||
InvalidToken(String),
|
||||
#[error("missing token")]
|
||||
MissingToken,
|
||||
#[error("your access token requires providing a 'target_id' option.")]
|
||||
MissingTargetId,
|
||||
#[error("using 'target_id' with legacy tokens is not supported")]
|
||||
TargetIdWithLegacyToken,
|
||||
#[error("invalid token provided")]
|
||||
InvalidToken,
|
||||
#[error("invalid target id provided: {0}, it should be either a slug like \"$organizationSlug/$projectSlug/$targetSlug\" or an UUID")]
|
||||
InvalidTargetId(String),
|
||||
#[error("unable to instantiate the http client for reports sending: {0}")]
|
||||
HTTPClientCreationError(reqwest::Error),
|
||||
#[error("unable to create circuit breaker: {0}")]
|
||||
CircuitBreakerCreationError(#[from] crate::circuit_breaker::CircuitBreakerError),
|
||||
#[error("rejected by the circuit breaker")]
|
||||
CircuitBreakerRejected,
|
||||
#[error("unable to send report: {0}")]
|
||||
Unknown(String),
|
||||
}
|
||||
|
||||
impl UsageAgent {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn try_new(
|
||||
token: &str,
|
||||
endpoint: String,
|
||||
target_id: Option<String>,
|
||||
buffer_size: usize,
|
||||
connect_timeout: Duration,
|
||||
request_timeout: Duration,
|
||||
accept_invalid_certs: bool,
|
||||
flush_interval: Duration,
|
||||
user_agent: String,
|
||||
) -> Result<Arc<Self>, AgentError> {
|
||||
let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3);
|
||||
pub type UsageAgent = Arc<AsyncDropper<UsageAgentInner>>;
|
||||
|
||||
let mut default_headers = HeaderMap::new();
|
||||
|
||||
default_headers.insert("X-Usage-API-Version", HeaderValue::from_static("2"));
|
||||
|
||||
let mut authorization_header = HeaderValue::from_str(&format!("Bearer {}", token))
|
||||
.map_err(|_| AgentError::InvalidToken(token.to_string()))?;
|
||||
|
||||
authorization_header.set_sensitive(true);
|
||||
|
||||
default_headers.insert(reqwest::header::AUTHORIZATION, authorization_header);
|
||||
|
||||
default_headers.insert(
|
||||
reqwest::header::CONTENT_TYPE,
|
||||
HeaderValue::from_static("application/json"),
|
||||
);
|
||||
|
||||
let reqwest_agent = reqwest::Client::builder()
|
||||
.danger_accept_invalid_certs(accept_invalid_certs)
|
||||
.connect_timeout(connect_timeout)
|
||||
.timeout(request_timeout)
|
||||
.user_agent(user_agent)
|
||||
.default_headers(default_headers)
|
||||
.build()
|
||||
.map_err(AgentError::HTTPClientCreationError)?;
|
||||
let client = ClientBuilder::new(reqwest_agent)
|
||||
.with(RetryTransientMiddleware::new_with_policy(retry_policy))
|
||||
.build();
|
||||
|
||||
let mut endpoint = endpoint;
|
||||
|
||||
if token.starts_with("hvo1/") || token.starts_with("hvu1/") || token.starts_with("hvp1/") {
|
||||
if let Some(target_id) = target_id {
|
||||
endpoint.push_str(&format!("/{}", target_id));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Arc::new(Self {
|
||||
buffer_size,
|
||||
endpoint,
|
||||
buffer: Buffer::new(),
|
||||
processor: OperationProcessor::new(),
|
||||
client,
|
||||
flush_interval,
|
||||
}))
|
||||
#[async_trait::async_trait]
|
||||
pub trait UsageAgentExt {
|
||||
fn builder() -> UsageAgentBuilder {
|
||||
UsageAgentBuilder::default()
|
||||
}
|
||||
async fn flush(&self) -> Result<(), AgentError>;
|
||||
async fn start_flush_interval(&self, token: &CancellationToken);
|
||||
async fn add_report(&self, execution_report: ExecutionReport) -> Result<(), AgentError>;
|
||||
}
|
||||
|
||||
impl UsageAgentInner {
|
||||
fn produce_report(&self, reports: Vec<ExecutionReport>) -> Result<Report, AgentError> {
|
||||
let mut report = Report {
|
||||
size: 0,
|
||||
|
|
@ -233,21 +171,21 @@ impl UsageAgent {
|
|||
Ok(report)
|
||||
}
|
||||
|
||||
pub async fn send_report(&self, report: Report) -> Result<(), AgentError> {
|
||||
async fn send_report(&self, report: Report) -> Result<(), AgentError> {
|
||||
if report.size == 0 {
|
||||
return Ok(());
|
||||
}
|
||||
let report_body =
|
||||
serde_json::to_vec(&report).map_err(|e| AgentError::Unknown(e.to_string()))?;
|
||||
// Based on https://the-guild.dev/graphql/hive/docs/specs/usage-reports#data-structure
|
||||
let resp_fut = self.client.post(&self.endpoint).json(&report).send();
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
.post(&self.endpoint)
|
||||
.header(reqwest::header::CONTENT_LENGTH, report_body.len())
|
||||
.body(report_body)
|
||||
.send()
|
||||
.circuit_breaker
|
||||
.call(resp_fut)
|
||||
.await
|
||||
.map_err(|e| AgentError::Unknown(e.to_string()))?;
|
||||
.map_err(|e| match e {
|
||||
recloser::Error::Inner(e) => AgentError::Unknown(e.to_string()),
|
||||
recloser::Error::Rejected => AgentError::CircuitBreakerRejected,
|
||||
})?;
|
||||
|
||||
match resp.status() {
|
||||
reqwest::StatusCode::OK => Ok(()),
|
||||
|
|
@ -262,67 +200,57 @@ impl UsageAgent {
|
|||
}
|
||||
}
|
||||
|
||||
pub async fn flush(&self) {
|
||||
let execution_reports = match self.buffer.drain() {
|
||||
Ok(res) => res,
|
||||
Err(e) => {
|
||||
tracing::error!("Unable to acquire lock for State in drain_reports: {}", e);
|
||||
Vec::new()
|
||||
}
|
||||
};
|
||||
let size = execution_reports.len();
|
||||
|
||||
if size > 0 {
|
||||
match self.produce_report(execution_reports) {
|
||||
Ok(report) => match self.send_report(report).await {
|
||||
Ok(_) => tracing::debug!("Reported {} operations", size),
|
||||
Err(e) => tracing::error!("{}", e),
|
||||
},
|
||||
Err(e) => tracing::error!("{}", e),
|
||||
}
|
||||
async fn handle_drained(&self, drained: Vec<ExecutionReport>) -> Result<(), AgentError> {
|
||||
if drained.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
let report = self.produce_report(drained)?;
|
||||
self.send_report(report).await
|
||||
}
|
||||
pub async fn start_flush_interval(&self, token: Option<CancellationToken>) {
|
||||
let mut tokio_interval = tokio::time::interval(self.flush_interval);
|
||||
|
||||
match token {
|
||||
Some(token) => loop {
|
||||
tokio::select! {
|
||||
_ = tokio_interval.tick() => { self.flush().await; }
|
||||
_ = token.cancelled() => { println!("Shutting down."); return; }
|
||||
}
|
||||
},
|
||||
None => loop {
|
||||
tokio_interval.tick().await;
|
||||
self.flush().await;
|
||||
},
|
||||
}
|
||||
async fn flush(&self) -> Result<(), AgentError> {
|
||||
let execution_reports = self.buffer.drain().await;
|
||||
|
||||
self.handle_drained(execution_reports).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub trait UsageAgentExt {
|
||||
fn add_report(&self, execution_report: ExecutionReport) -> Result<(), AgentError>;
|
||||
fn flush_if_full(&self, size: usize) -> Result<(), AgentError>;
|
||||
}
|
||||
#[async_trait::async_trait]
|
||||
impl UsageAgentExt for UsageAgent {
|
||||
async fn flush(&self) -> Result<(), AgentError> {
|
||||
self.inner().flush().await
|
||||
}
|
||||
|
||||
impl UsageAgentExt for Arc<UsageAgent> {
|
||||
fn flush_if_full(&self, size: usize) -> Result<(), AgentError> {
|
||||
if size >= self.buffer_size {
|
||||
let cloned_self = self.clone();
|
||||
tokio::task::spawn(async move {
|
||||
cloned_self.flush().await;
|
||||
});
|
||||
async fn start_flush_interval(&self, token: &CancellationToken) {
|
||||
loop {
|
||||
tokio::time::sleep(self.inner().flush_interval).await;
|
||||
if token.is_cancelled() {
|
||||
println!("Shutting down.");
|
||||
return;
|
||||
}
|
||||
self.flush()
|
||||
.await
|
||||
.unwrap_or_else(|e| tracing::error!("Failed to flush usage reports: {}", e));
|
||||
}
|
||||
}
|
||||
|
||||
async fn add_report(&self, execution_report: ExecutionReport) -> Result<(), AgentError> {
|
||||
if let AddStatus::Full { drained } = self.inner().buffer.add(execution_report).await {
|
||||
self.inner().handle_drained(drained).await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn add_report(&self, execution_report: ExecutionReport) -> Result<(), AgentError> {
|
||||
let size = self.buffer.push(execution_report)?;
|
||||
|
||||
self.flush_if_full(size)?;
|
||||
|
||||
Ok(())
|
||||
#[async_trait::async_trait]
|
||||
impl AsyncDrop for UsageAgentInner {
|
||||
async fn async_drop(&mut self) {
|
||||
if let Err(e) = self.flush().await {
|
||||
tracing::error!("Failed to flush usage reports during drop: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -333,14 +261,14 @@ mod tests {
|
|||
use graphql_parser::{parse_query, parse_schema};
|
||||
use reqwest::header::{AUTHORIZATION, CONTENT_TYPE, USER_AGENT};
|
||||
|
||||
use crate::agent::{ExecutionReport, Report, UsageAgent, UsageAgentExt};
|
||||
use crate::agent::usage_agent::{ExecutionReport, Report, UsageAgent, UsageAgentExt};
|
||||
|
||||
const CONTENT_TYPE_VALUE: &'static str = "application/json";
|
||||
const GRAPHQL_CLIENT_NAME: &'static str = "Hive Client";
|
||||
const GRAPHQL_CLIENT_VERSION: &'static str = "1.0.0";
|
||||
|
||||
#[tokio::test]
|
||||
async fn should_send_data_to_hive() {
|
||||
#[tokio::test(flavor = "multi_thread")]
|
||||
async fn should_send_data_to_hive() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let token = "Token";
|
||||
|
||||
let mut server = mockito::Server::new_async().await;
|
||||
|
|
@ -349,13 +277,13 @@ mod tests {
|
|||
|
||||
let timestamp = 1625247600;
|
||||
let duration = Duration::from_millis(20);
|
||||
let user_agent = format!("hive-router-sdk-test");
|
||||
let user_agent = "hive-router-sdk-test";
|
||||
|
||||
let mock = server
|
||||
.mock("POST", "/200")
|
||||
.match_header(AUTHORIZATION, format!("Bearer {}", token).as_str())
|
||||
.match_header(CONTENT_TYPE, CONTENT_TYPE_VALUE)
|
||||
.match_header(USER_AGENT, user_agent.as_str())
|
||||
.match_header(USER_AGENT, user_agent)
|
||||
.match_header("X-Usage-API-Version", "2")
|
||||
.match_request(move |request| {
|
||||
let request_body = request.body().expect("Failed to extract body");
|
||||
|
|
@ -469,8 +397,7 @@ mod tests {
|
|||
CUSTOM
|
||||
}
|
||||
"#,
|
||||
)
|
||||
.expect("Failed to parse schema");
|
||||
)?;
|
||||
|
||||
let op: graphql_tools::static_graphql::query::Document = parse_query(
|
||||
r#"
|
||||
|
|
@ -493,39 +420,34 @@ mod tests {
|
|||
type
|
||||
}
|
||||
"#,
|
||||
)
|
||||
.expect("Failed to parse query");
|
||||
)?;
|
||||
|
||||
let usage_agent = UsageAgent::try_new(
|
||||
token,
|
||||
format!("{}/200", server_url),
|
||||
None,
|
||||
10,
|
||||
Duration::from_millis(500),
|
||||
Duration::from_millis(500),
|
||||
false,
|
||||
Duration::from_millis(10),
|
||||
user_agent,
|
||||
)
|
||||
.expect("Failed to create UsageAgent");
|
||||
// Testing async drop
|
||||
{
|
||||
let usage_agent = UsageAgent::builder()
|
||||
.token(token.into())
|
||||
.endpoint(format!("{}/200", server_url))
|
||||
.user_agent(user_agent.into())
|
||||
.build()?;
|
||||
|
||||
usage_agent
|
||||
.add_report(ExecutionReport {
|
||||
schema: Arc::new(schema),
|
||||
operation_body: op.to_string(),
|
||||
operation_name: Some("deleteProject".to_string()),
|
||||
client_name: Some(GRAPHQL_CLIENT_NAME.to_string()),
|
||||
client_version: Some(GRAPHQL_CLIENT_VERSION.to_string()),
|
||||
timestamp: timestamp.try_into().unwrap(),
|
||||
duration,
|
||||
ok: true,
|
||||
errors: 0,
|
||||
persisted_document_hash: None,
|
||||
})
|
||||
.expect("Failed to add report");
|
||||
|
||||
usage_agent.flush().await;
|
||||
usage_agent
|
||||
.add_report(ExecutionReport {
|
||||
schema: Arc::new(schema),
|
||||
operation_body: op.to_string(),
|
||||
operation_name: Some("deleteProject".to_string()),
|
||||
client_name: Some(GRAPHQL_CLIENT_NAME.to_string()),
|
||||
client_version: Some(GRAPHQL_CLIENT_VERSION.to_string()),
|
||||
timestamp,
|
||||
duration,
|
||||
ok: true,
|
||||
errors: 0,
|
||||
persisted_document_hash: None,
|
||||
})
|
||||
.await?;
|
||||
}
|
||||
|
||||
mock.assert_async().await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
67
packages/libraries/sdk-rs/src/circuit_breaker.rs
Normal file
67
packages/libraries/sdk-rs/src/circuit_breaker.rs
Normal file
|
|
@ -0,0 +1,67 @@
|
|||
use std::time::Duration;
|
||||
|
||||
use recloser::{AsyncRecloser, Recloser};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct CircuitBreakerBuilder {
|
||||
error_threshold: f32,
|
||||
volume_threshold: usize,
|
||||
reset_timeout: Duration,
|
||||
}
|
||||
|
||||
impl Default for CircuitBreakerBuilder {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
error_threshold: 0.5,
|
||||
volume_threshold: 5,
|
||||
reset_timeout: Duration::from_secs(30),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum CircuitBreakerError {
|
||||
#[error("Invalid error threshold: {0}. It must be between 0.0 and 1.0")]
|
||||
InvalidErrorThreshold(f32),
|
||||
}
|
||||
|
||||
impl CircuitBreakerBuilder {
|
||||
/// Percentage after what the circuit breaker should kick in.
|
||||
/// Default: .5
|
||||
pub fn error_threshold(mut self, percentage: f32) -> Self {
|
||||
self.error_threshold = percentage;
|
||||
self
|
||||
}
|
||||
/// Count of requests before starting evaluating.
|
||||
/// Default: 5
|
||||
pub fn volume_threshold(mut self, threshold: usize) -> Self {
|
||||
self.volume_threshold = threshold;
|
||||
self
|
||||
}
|
||||
/// After what time the circuit breaker is attempting to retry sending requests in milliseconds.
|
||||
/// Default: 30s
|
||||
pub fn reset_timeout(mut self, timeout: Duration) -> Self {
|
||||
self.reset_timeout = timeout;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build_async(self) -> Result<AsyncRecloser, CircuitBreakerError> {
|
||||
let recloser = self.build_sync()?;
|
||||
Ok(AsyncRecloser::from(recloser))
|
||||
}
|
||||
pub fn build_sync(self) -> Result<Recloser, CircuitBreakerError> {
|
||||
let error_threshold = if self.error_threshold < 0.0 || self.error_threshold > 1.0 {
|
||||
return Err(CircuitBreakerError::InvalidErrorThreshold(
|
||||
self.error_threshold,
|
||||
));
|
||||
} else {
|
||||
self.error_threshold
|
||||
};
|
||||
let recloser = Recloser::custom()
|
||||
.error_rate(error_threshold)
|
||||
.closed_len(self.volume_threshold)
|
||||
.open_wait(self.reset_timeout)
|
||||
.build();
|
||||
Ok(recloser)
|
||||
}
|
||||
}
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
pub mod agent;
|
||||
pub mod graphql;
|
||||
pub mod circuit_breaker;
|
||||
pub mod persisted_documents;
|
||||
pub mod supergraph_fetcher;
|
||||
|
|
|
|||
|
|
@ -1,18 +1,22 @@
|
|||
use std::time::Duration;
|
||||
|
||||
use crate::agent::usage_agent::non_empty_string;
|
||||
use crate::circuit_breaker::CircuitBreakerBuilder;
|
||||
use moka::future::Cache;
|
||||
use recloser::AsyncRecloser;
|
||||
use reqwest::header::HeaderMap;
|
||||
use reqwest::header::HeaderValue;
|
||||
use reqwest_middleware::ClientBuilder;
|
||||
use reqwest_middleware::ClientWithMiddleware;
|
||||
use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
|
||||
use reqwest_retry::RetryTransientMiddleware;
|
||||
use retry_policies::policies::ExponentialBackoff;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct PersistedDocumentsManager {
|
||||
agent: ClientWithMiddleware,
|
||||
client: ClientWithMiddleware,
|
||||
cache: Cache<String, String>,
|
||||
endpoint: String,
|
||||
endpoints_with_circuit_breakers: Vec<(String, AsyncRecloser)>,
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
|
|
@ -31,6 +35,18 @@ pub enum PersistedDocumentsError {
|
|||
FailedToReadCDNResponse(reqwest::Error),
|
||||
#[error("No persisted document provided, or document id cannot be resolved.")]
|
||||
PersistedDocumentRequired,
|
||||
#[error("Missing required configuration option: {0}")]
|
||||
MissingConfigurationOption(String),
|
||||
#[error("Invalid CDN key {0}")]
|
||||
InvalidCDNKey(String),
|
||||
#[error("Failed to create HTTP client: {0}")]
|
||||
HTTPClientCreationError(reqwest::Error),
|
||||
#[error("unable to create circuit breaker: {0}")]
|
||||
CircuitBreakerCreationError(#[from] crate::circuit_breaker::CircuitBreakerError),
|
||||
#[error("rejected by the circuit breaker")]
|
||||
CircuitBreakerRejected,
|
||||
#[error("unknown error")]
|
||||
Unknown,
|
||||
}
|
||||
|
||||
impl PersistedDocumentsError {
|
||||
|
|
@ -51,47 +67,75 @@ impl PersistedDocumentsError {
|
|||
PersistedDocumentsError::PersistedDocumentRequired => {
|
||||
"PERSISTED_DOCUMENT_REQUIRED".into()
|
||||
}
|
||||
PersistedDocumentsError::MissingConfigurationOption(_) => {
|
||||
"MISSING_CONFIGURATION_OPTION".into()
|
||||
}
|
||||
PersistedDocumentsError::InvalidCDNKey(_) => "INVALID_CDN_KEY".into(),
|
||||
PersistedDocumentsError::HTTPClientCreationError(_) => {
|
||||
"HTTP_CLIENT_CREATION_ERROR".into()
|
||||
}
|
||||
PersistedDocumentsError::CircuitBreakerCreationError(_) => {
|
||||
"CIRCUIT_BREAKER_CREATION_ERROR".into()
|
||||
}
|
||||
PersistedDocumentsError::CircuitBreakerRejected => "CIRCUIT_BREAKER_REJECTED".into(),
|
||||
PersistedDocumentsError::Unknown => "UNKNOWN_ERROR".into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl PersistedDocumentsManager {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn new(
|
||||
key: String,
|
||||
endpoint: String,
|
||||
accept_invalid_certs: bool,
|
||||
connect_timeout: Duration,
|
||||
request_timeout: Duration,
|
||||
retry_count: u32,
|
||||
cache_size: u64,
|
||||
user_agent: String,
|
||||
) -> Self {
|
||||
let retry_policy = ExponentialBackoff::builder().build_with_max_retries(retry_count);
|
||||
|
||||
let mut default_headers = HeaderMap::new();
|
||||
default_headers.insert("X-Hive-CDN-Key", HeaderValue::from_str(&key).unwrap());
|
||||
let reqwest_agent = reqwest::Client::builder()
|
||||
.danger_accept_invalid_certs(accept_invalid_certs)
|
||||
.connect_timeout(connect_timeout)
|
||||
.timeout(request_timeout)
|
||||
.user_agent(user_agent)
|
||||
.default_headers(default_headers)
|
||||
.build()
|
||||
.expect("Failed to create reqwest client");
|
||||
let agent = ClientBuilder::new(reqwest_agent)
|
||||
.with(RetryTransientMiddleware::new_with_policy(retry_policy))
|
||||
.build();
|
||||
|
||||
let cache = Cache::<String, String>::new(cache_size);
|
||||
|
||||
Self {
|
||||
agent,
|
||||
cache,
|
||||
endpoint,
|
||||
}
|
||||
pub fn builder() -> PersistedDocumentsManagerBuilder {
|
||||
PersistedDocumentsManagerBuilder::default()
|
||||
}
|
||||
async fn resolve_from_endpoint(
|
||||
&self,
|
||||
endpoint: &str,
|
||||
document_id: &str,
|
||||
circuit_breaker: &AsyncRecloser,
|
||||
) -> Result<String, PersistedDocumentsError> {
|
||||
let cdn_document_id = str::replace(document_id, "~", "/");
|
||||
let cdn_artifact_url = format!("{}/apps/{}", endpoint, cdn_document_id);
|
||||
info!(
|
||||
"Fetching document {} from CDN: {}",
|
||||
document_id, cdn_artifact_url
|
||||
);
|
||||
let response_fut = self.client.get(cdn_artifact_url).send();
|
||||
|
||||
let response = circuit_breaker
|
||||
.call(response_fut)
|
||||
.await
|
||||
.map_err(|e| match e {
|
||||
recloser::Error::Inner(e) => PersistedDocumentsError::FailedToFetchFromCDN(e),
|
||||
recloser::Error::Rejected => PersistedDocumentsError::CircuitBreakerRejected,
|
||||
})?;
|
||||
|
||||
if response.status().is_success() {
|
||||
let document = response
|
||||
.text()
|
||||
.await
|
||||
.map_err(PersistedDocumentsError::FailedToReadCDNResponse)?;
|
||||
debug!(
|
||||
"Document fetched from CDN: {}, storing in local cache",
|
||||
document
|
||||
);
|
||||
self.cache
|
||||
.insert(document_id.into(), document.clone())
|
||||
.await;
|
||||
|
||||
return Ok(document);
|
||||
}
|
||||
|
||||
warn!(
|
||||
"Document fetch from CDN failed: HTTP {}, Body: {:?}",
|
||||
response.status(),
|
||||
response
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "Unavailable".to_string())
|
||||
);
|
||||
|
||||
Err(PersistedDocumentsError::DocumentNotFound)
|
||||
}
|
||||
/// Resolves the document from the cache, or from the CDN
|
||||
pub async fn resolve_document(
|
||||
&self,
|
||||
|
|
@ -110,50 +154,173 @@ impl PersistedDocumentsManager {
|
|||
"Document {} not found in cache. Fetching from CDN",
|
||||
document_id
|
||||
);
|
||||
let cdn_document_id = str::replace(document_id, "~", "/");
|
||||
let cdn_artifact_url = format!("{}/apps/{}", &self.endpoint, cdn_document_id);
|
||||
info!(
|
||||
"Fetching document {} from CDN: {}",
|
||||
document_id, cdn_artifact_url
|
||||
);
|
||||
let cdn_response = self.agent.get(cdn_artifact_url).send().await;
|
||||
|
||||
match cdn_response {
|
||||
Ok(response) => {
|
||||
if response.status().is_success() {
|
||||
let document = response
|
||||
.text()
|
||||
.await
|
||||
.map_err(PersistedDocumentsError::FailedToReadCDNResponse)?;
|
||||
debug!(
|
||||
"Document fetched from CDN: {}, storing in local cache",
|
||||
document
|
||||
);
|
||||
self.cache
|
||||
.insert(document_id.into(), document.clone())
|
||||
.await;
|
||||
|
||||
return Ok(document);
|
||||
let mut last_error: Option<PersistedDocumentsError> = None;
|
||||
for (endpoint, circuit_breaker) in &self.endpoints_with_circuit_breakers {
|
||||
let result = self
|
||||
.resolve_from_endpoint(endpoint, document_id, circuit_breaker)
|
||||
.await;
|
||||
match result {
|
||||
Ok(document) => return Ok(document),
|
||||
Err(e) => {
|
||||
last_error = Some(e);
|
||||
}
|
||||
|
||||
warn!(
|
||||
"Document fetch from CDN failed: HTTP {}, Body: {:?}",
|
||||
response.status(),
|
||||
response
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "Unavailable".to_string())
|
||||
);
|
||||
|
||||
Err(PersistedDocumentsError::DocumentNotFound)
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to fetch document from CDN: {:?}", e);
|
||||
|
||||
Err(PersistedDocumentsError::FailedToFetchFromCDN(e))
|
||||
}
|
||||
}
|
||||
match last_error {
|
||||
Some(e) => Err(e),
|
||||
None => Err(PersistedDocumentsError::Unknown),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PersistedDocumentsManagerBuilder {
|
||||
key: Option<String>,
|
||||
endpoints: Vec<String>,
|
||||
accept_invalid_certs: bool,
|
||||
connect_timeout: Duration,
|
||||
request_timeout: Duration,
|
||||
retry_policy: ExponentialBackoff,
|
||||
cache_size: u64,
|
||||
user_agent: Option<String>,
|
||||
circuit_breaker: CircuitBreakerBuilder,
|
||||
}
|
||||
|
||||
impl Default for PersistedDocumentsManagerBuilder {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
key: None,
|
||||
endpoints: vec![],
|
||||
accept_invalid_certs: false,
|
||||
connect_timeout: Duration::from_secs(5),
|
||||
request_timeout: Duration::from_secs(15),
|
||||
retry_policy: ExponentialBackoff::builder().build_with_max_retries(3),
|
||||
cache_size: 10_000,
|
||||
user_agent: None,
|
||||
circuit_breaker: CircuitBreakerBuilder::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl PersistedDocumentsManagerBuilder {
|
||||
/// The CDN Access Token with from the Hive Console target.
|
||||
pub fn key(mut self, key: String) -> Self {
|
||||
self.key = non_empty_string(Some(key));
|
||||
self
|
||||
}
|
||||
|
||||
/// The CDN endpoint from Hive Console target.
|
||||
pub fn add_endpoint(mut self, endpoint: String) -> Self {
|
||||
if let Some(endpoint) = non_empty_string(Some(endpoint)) {
|
||||
self.endpoints.push(endpoint);
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
/// Accept invalid SSL certificates
|
||||
/// default: false
|
||||
pub fn accept_invalid_certs(mut self, accept_invalid_certs: bool) -> Self {
|
||||
self.accept_invalid_certs = accept_invalid_certs;
|
||||
self
|
||||
}
|
||||
|
||||
/// Connection timeout for the Hive Console CDN requests.
|
||||
/// Default: 5 seconds
|
||||
pub fn connect_timeout(mut self, connect_timeout: Duration) -> Self {
|
||||
self.connect_timeout = connect_timeout;
|
||||
self
|
||||
}
|
||||
|
||||
/// Request timeout for the Hive Console CDN requests.
|
||||
/// Default: 15 seconds
|
||||
pub fn request_timeout(mut self, request_timeout: Duration) -> Self {
|
||||
self.request_timeout = request_timeout;
|
||||
self
|
||||
}
|
||||
|
||||
/// Retry policy for fetching persisted documents
|
||||
/// Default: ExponentialBackoff with max 3 retries
|
||||
pub fn retry_policy(mut self, retry_policy: ExponentialBackoff) -> Self {
|
||||
self.retry_policy = retry_policy;
|
||||
self
|
||||
}
|
||||
|
||||
/// Maximum number of retries for fetching persisted documents
|
||||
/// Default: ExponentialBackoff with max 3 retries
|
||||
pub fn max_retries(mut self, max_retries: u32) -> Self {
|
||||
self.retry_policy = ExponentialBackoff::builder().build_with_max_retries(max_retries);
|
||||
self
|
||||
}
|
||||
|
||||
/// Size of the in-memory cache for persisted documents
|
||||
/// Default: 10,000 entries
|
||||
pub fn cache_size(mut self, cache_size: u64) -> Self {
|
||||
self.cache_size = cache_size;
|
||||
self
|
||||
}
|
||||
|
||||
/// User-Agent header to be sent with each request
|
||||
pub fn user_agent(mut self, user_agent: String) -> Self {
|
||||
self.user_agent = non_empty_string(Some(user_agent));
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(self) -> Result<PersistedDocumentsManager, PersistedDocumentsError> {
|
||||
let mut default_headers = HeaderMap::new();
|
||||
let key = match self.key {
|
||||
Some(key) => key,
|
||||
None => {
|
||||
return Err(PersistedDocumentsError::MissingConfigurationOption(
|
||||
"key".to_string(),
|
||||
));
|
||||
}
|
||||
};
|
||||
default_headers.insert(
|
||||
"X-Hive-CDN-Key",
|
||||
HeaderValue::from_str(&key)
|
||||
.map_err(|e| PersistedDocumentsError::InvalidCDNKey(e.to_string()))?,
|
||||
);
|
||||
let mut reqwest_agent = reqwest::Client::builder()
|
||||
.danger_accept_invalid_certs(self.accept_invalid_certs)
|
||||
.connect_timeout(self.connect_timeout)
|
||||
.timeout(self.request_timeout)
|
||||
.default_headers(default_headers);
|
||||
|
||||
if let Some(user_agent) = self.user_agent {
|
||||
reqwest_agent = reqwest_agent.user_agent(user_agent);
|
||||
}
|
||||
|
||||
let reqwest_agent = reqwest_agent
|
||||
.build()
|
||||
.map_err(PersistedDocumentsError::HTTPClientCreationError)?;
|
||||
let client = ClientBuilder::new(reqwest_agent)
|
||||
.with(RetryTransientMiddleware::new_with_policy(self.retry_policy))
|
||||
.build();
|
||||
|
||||
let cache = Cache::<String, String>::new(self.cache_size);
|
||||
|
||||
if self.endpoints.is_empty() {
|
||||
return Err(PersistedDocumentsError::MissingConfigurationOption(
|
||||
"endpoints".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
Ok(PersistedDocumentsManager {
|
||||
client,
|
||||
cache,
|
||||
endpoints_with_circuit_breakers: self
|
||||
.endpoints
|
||||
.into_iter()
|
||||
.map(move |endpoint| {
|
||||
let circuit_breaker = self
|
||||
.circuit_breaker
|
||||
.clone()
|
||||
.build_async()
|
||||
.map_err(PersistedDocumentsError::CircuitBreakerCreationError)?;
|
||||
Ok((endpoint, circuit_breaker))
|
||||
})
|
||||
.collect::<Result<Vec<(String, AsyncRecloser)>, PersistedDocumentsError>>()?,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,260 +0,0 @@
|
|||
use std::fmt::Display;
|
||||
use std::sync::RwLock;
|
||||
use std::time::Duration;
|
||||
use std::time::SystemTime;
|
||||
|
||||
use reqwest::header::HeaderMap;
|
||||
use reqwest::header::HeaderValue;
|
||||
use reqwest::header::InvalidHeaderValue;
|
||||
use reqwest::header::IF_NONE_MATCH;
|
||||
use reqwest_middleware::ClientBuilder;
|
||||
use reqwest_middleware::ClientWithMiddleware;
|
||||
use reqwest_retry::policies::ExponentialBackoff;
|
||||
use reqwest_retry::RetryDecision;
|
||||
use reqwest_retry::RetryPolicy;
|
||||
use reqwest_retry::RetryTransientMiddleware;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct SupergraphFetcher<AsyncOrSync> {
|
||||
client: SupergraphFetcherAsyncOrSyncClient,
|
||||
endpoint: String,
|
||||
etag: RwLock<Option<HeaderValue>>,
|
||||
state: std::marker::PhantomData<AsyncOrSync>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct SupergraphFetcherAsyncState;
|
||||
#[derive(Debug)]
|
||||
pub struct SupergraphFetcherSyncState;
|
||||
|
||||
#[derive(Debug)]
|
||||
enum SupergraphFetcherAsyncOrSyncClient {
|
||||
Async {
|
||||
reqwest_client: ClientWithMiddleware,
|
||||
},
|
||||
Sync {
|
||||
reqwest_client: reqwest::blocking::Client,
|
||||
retry_policy: ExponentialBackoff,
|
||||
},
|
||||
}
|
||||
|
||||
pub enum SupergraphFetcherError {
|
||||
FetcherCreationError(reqwest::Error),
|
||||
NetworkError(reqwest_middleware::Error),
|
||||
NetworkResponseError(reqwest::Error),
|
||||
Lock(String),
|
||||
InvalidKey(InvalidHeaderValue),
|
||||
}
|
||||
|
||||
impl Display for SupergraphFetcherError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
SupergraphFetcherError::FetcherCreationError(e) => {
|
||||
write!(f, "Creating fetcher failed: {}", e)
|
||||
}
|
||||
SupergraphFetcherError::NetworkError(e) => write!(f, "Network error: {}", e),
|
||||
SupergraphFetcherError::NetworkResponseError(e) => {
|
||||
write!(f, "Network response error: {}", e)
|
||||
}
|
||||
SupergraphFetcherError::Lock(e) => write!(f, "Lock error: {}", e),
|
||||
SupergraphFetcherError::InvalidKey(e) => write!(f, "Invalid CDN key: {}", e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn prepare_client_config(
|
||||
mut endpoint: String,
|
||||
key: &str,
|
||||
retry_count: u32,
|
||||
) -> Result<(String, HeaderMap, ExponentialBackoff), SupergraphFetcherError> {
|
||||
if !endpoint.ends_with("/supergraph") {
|
||||
if endpoint.ends_with("/") {
|
||||
endpoint.push_str("supergraph");
|
||||
} else {
|
||||
endpoint.push_str("/supergraph");
|
||||
}
|
||||
}
|
||||
|
||||
let mut headers = HeaderMap::new();
|
||||
let mut cdn_key_header =
|
||||
HeaderValue::from_str(key).map_err(SupergraphFetcherError::InvalidKey)?;
|
||||
cdn_key_header.set_sensitive(true);
|
||||
headers.insert("X-Hive-CDN-Key", cdn_key_header);
|
||||
|
||||
let retry_policy = ExponentialBackoff::builder().build_with_max_retries(retry_count);
|
||||
|
||||
Ok((endpoint, headers, retry_policy))
|
||||
}
|
||||
|
||||
impl SupergraphFetcher<SupergraphFetcherSyncState> {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn try_new_sync(
|
||||
endpoint: String,
|
||||
key: &str,
|
||||
user_agent: String,
|
||||
connect_timeout: Duration,
|
||||
request_timeout: Duration,
|
||||
accept_invalid_certs: bool,
|
||||
retry_count: u32,
|
||||
) -> Result<Self, SupergraphFetcherError> {
|
||||
let (endpoint, headers, retry_policy) = prepare_client_config(endpoint, key, retry_count)?;
|
||||
|
||||
Ok(Self {
|
||||
client: SupergraphFetcherAsyncOrSyncClient::Sync {
|
||||
reqwest_client: reqwest::blocking::Client::builder()
|
||||
.danger_accept_invalid_certs(accept_invalid_certs)
|
||||
.connect_timeout(connect_timeout)
|
||||
.timeout(request_timeout)
|
||||
.user_agent(user_agent)
|
||||
.default_headers(headers)
|
||||
.build()
|
||||
.map_err(SupergraphFetcherError::FetcherCreationError)?,
|
||||
retry_policy,
|
||||
},
|
||||
endpoint,
|
||||
etag: RwLock::new(None),
|
||||
state: std::marker::PhantomData,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn fetch_supergraph(&self) -> Result<Option<String>, SupergraphFetcherError> {
|
||||
let request_start_time = SystemTime::now();
|
||||
// Implementing retry logic for sync client
|
||||
let mut n_past_retries = 0;
|
||||
let (reqwest_client, retry_policy) = match &self.client {
|
||||
SupergraphFetcherAsyncOrSyncClient::Sync {
|
||||
reqwest_client,
|
||||
retry_policy,
|
||||
} => (reqwest_client, retry_policy),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let resp = loop {
|
||||
let mut req = reqwest_client.get(&self.endpoint);
|
||||
let etag = self.get_latest_etag()?;
|
||||
if let Some(etag) = etag {
|
||||
req = req.header(IF_NONE_MATCH, etag);
|
||||
}
|
||||
let response = req.send();
|
||||
|
||||
match response {
|
||||
Ok(resp) => break resp,
|
||||
Err(e) => match retry_policy.should_retry(request_start_time, n_past_retries) {
|
||||
RetryDecision::DoNotRetry => {
|
||||
return Err(SupergraphFetcherError::NetworkError(
|
||||
reqwest_middleware::Error::Reqwest(e),
|
||||
));
|
||||
}
|
||||
RetryDecision::Retry { execute_after } => {
|
||||
n_past_retries += 1;
|
||||
if let Ok(duration) = execute_after.elapsed() {
|
||||
std::thread::sleep(duration);
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
};
|
||||
|
||||
if resp.status().as_u16() == 304 {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let etag = resp.headers().get("etag");
|
||||
self.update_latest_etag(etag)?;
|
||||
|
||||
let text = resp
|
||||
.text()
|
||||
.map_err(SupergraphFetcherError::NetworkResponseError)?;
|
||||
|
||||
Ok(Some(text))
|
||||
}
|
||||
}
|
||||
|
||||
impl SupergraphFetcher<SupergraphFetcherAsyncState> {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn try_new_async(
|
||||
endpoint: String,
|
||||
key: &str,
|
||||
user_agent: String,
|
||||
connect_timeout: Duration,
|
||||
request_timeout: Duration,
|
||||
accept_invalid_certs: bool,
|
||||
retry_count: u32,
|
||||
) -> Result<Self, SupergraphFetcherError> {
|
||||
let (endpoint, headers, retry_policy) = prepare_client_config(endpoint, key, retry_count)?;
|
||||
|
||||
let reqwest_agent = reqwest::Client::builder()
|
||||
.danger_accept_invalid_certs(accept_invalid_certs)
|
||||
.connect_timeout(connect_timeout)
|
||||
.timeout(request_timeout)
|
||||
.default_headers(headers)
|
||||
.user_agent(user_agent)
|
||||
.build()
|
||||
.map_err(SupergraphFetcherError::FetcherCreationError)?;
|
||||
let reqwest_client = ClientBuilder::new(reqwest_agent)
|
||||
.with(RetryTransientMiddleware::new_with_policy(retry_policy))
|
||||
.build();
|
||||
|
||||
Ok(Self {
|
||||
client: SupergraphFetcherAsyncOrSyncClient::Async { reqwest_client },
|
||||
endpoint,
|
||||
etag: RwLock::new(None),
|
||||
state: std::marker::PhantomData,
|
||||
})
|
||||
}
|
||||
pub async fn fetch_supergraph(&self) -> Result<Option<String>, SupergraphFetcherError> {
|
||||
let reqwest_client = match &self.client {
|
||||
SupergraphFetcherAsyncOrSyncClient::Async { reqwest_client } => reqwest_client,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let mut req = reqwest_client.get(&self.endpoint);
|
||||
let etag = self.get_latest_etag()?;
|
||||
if let Some(etag) = etag {
|
||||
req = req.header(IF_NONE_MATCH, etag);
|
||||
}
|
||||
|
||||
let resp = req
|
||||
.send()
|
||||
.await
|
||||
.map_err(SupergraphFetcherError::NetworkError)?;
|
||||
|
||||
if resp.status().as_u16() == 304 {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let etag = resp.headers().get("etag");
|
||||
self.update_latest_etag(etag)?;
|
||||
|
||||
let text = resp
|
||||
.text()
|
||||
.await
|
||||
.map_err(SupergraphFetcherError::NetworkResponseError)?;
|
||||
|
||||
Ok(Some(text))
|
||||
}
|
||||
}
|
||||
|
||||
impl<AsyncOrSync> SupergraphFetcher<AsyncOrSync> {
|
||||
fn get_latest_etag(&self) -> Result<Option<HeaderValue>, SupergraphFetcherError> {
|
||||
let guard: std::sync::RwLockReadGuard<'_, Option<HeaderValue>> =
|
||||
self.etag.try_read().map_err(|e| {
|
||||
SupergraphFetcherError::Lock(format!("Failed to read the etag record: {:?}", e))
|
||||
})?;
|
||||
|
||||
Ok(guard.clone())
|
||||
}
|
||||
|
||||
fn update_latest_etag(&self, etag: Option<&HeaderValue>) -> Result<(), SupergraphFetcherError> {
|
||||
let mut guard: std::sync::RwLockWriteGuard<'_, Option<HeaderValue>> =
|
||||
self.etag.try_write().map_err(|e| {
|
||||
SupergraphFetcherError::Lock(format!("Failed to update the etag record: {:?}", e))
|
||||
})?;
|
||||
|
||||
if let Some(etag_value) = etag {
|
||||
*guard = Some(etag_value.clone());
|
||||
} else {
|
||||
*guard = None;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
141
packages/libraries/sdk-rs/src/supergraph_fetcher/async_.rs
Normal file
141
packages/libraries/sdk-rs/src/supergraph_fetcher/async_.rs
Normal file
|
|
@ -0,0 +1,141 @@
|
|||
use futures_util::TryFutureExt;
|
||||
use recloser::AsyncRecloser;
|
||||
use reqwest::header::{HeaderValue, IF_NONE_MATCH};
|
||||
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
|
||||
use reqwest_retry::RetryTransientMiddleware;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
use crate::supergraph_fetcher::{
|
||||
builder::SupergraphFetcherBuilder, SupergraphFetcher, SupergraphFetcherError,
|
||||
};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct SupergraphFetcherAsyncState {
|
||||
endpoints_with_circuit_breakers: Vec<(String, AsyncRecloser)>,
|
||||
reqwest_client: ClientWithMiddleware,
|
||||
}
|
||||
|
||||
impl SupergraphFetcher<SupergraphFetcherAsyncState> {
|
||||
pub async fn fetch_supergraph(&self) -> Result<Option<String>, SupergraphFetcherError> {
|
||||
let mut last_error: Option<SupergraphFetcherError> = None;
|
||||
let mut last_resp = None;
|
||||
for (endpoint, circuit_breaker) in &self.state.endpoints_with_circuit_breakers {
|
||||
let mut req = self.state.reqwest_client.get(endpoint);
|
||||
let etag = self.get_latest_etag().await;
|
||||
if let Some(etag) = etag {
|
||||
req = req.header(IF_NONE_MATCH, etag);
|
||||
}
|
||||
let resp_fut = async {
|
||||
let mut resp = req.send().await.map_err(SupergraphFetcherError::Network);
|
||||
// Server errors (5xx) are considered errors
|
||||
if let Ok(ok_res) = resp {
|
||||
resp = if ok_res.status().is_server_error() {
|
||||
return Err(SupergraphFetcherError::Network(
|
||||
reqwest_middleware::Error::Middleware(anyhow::anyhow!(
|
||||
"Server error: {}",
|
||||
ok_res.status()
|
||||
)),
|
||||
));
|
||||
} else {
|
||||
Ok(ok_res)
|
||||
}
|
||||
}
|
||||
resp
|
||||
};
|
||||
let resp = circuit_breaker
|
||||
.call(resp_fut)
|
||||
// Map recloser errors to SupergraphFetcherError
|
||||
.map_err(|e| match e {
|
||||
recloser::Error::Inner(e) => e,
|
||||
recloser::Error::Rejected => SupergraphFetcherError::RejectedByCircuitBreaker,
|
||||
})
|
||||
.await;
|
||||
match resp {
|
||||
Err(err) => {
|
||||
last_error = Some(err);
|
||||
continue;
|
||||
}
|
||||
Ok(resp) => {
|
||||
last_resp = Some(resp);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(last_resp) = last_resp {
|
||||
let etag = last_resp.headers().get("etag");
|
||||
self.update_latest_etag(etag).await;
|
||||
let text = last_resp
|
||||
.text()
|
||||
.await
|
||||
.map_err(SupergraphFetcherError::ResponseParse)?;
|
||||
Ok(Some(text))
|
||||
} else if let Some(error) = last_error {
|
||||
Err(error)
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
async fn get_latest_etag(&self) -> Option<HeaderValue> {
|
||||
let guard = self.etag.read().await;
|
||||
|
||||
guard.clone()
|
||||
}
|
||||
async fn update_latest_etag(&self, etag: Option<&HeaderValue>) -> () {
|
||||
let mut guard = self.etag.write().await;
|
||||
|
||||
if let Some(etag_value) = etag {
|
||||
*guard = Some(etag_value.clone());
|
||||
} else {
|
||||
*guard = None;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl SupergraphFetcherBuilder {
|
||||
/// Builds an asynchronous SupergraphFetcher
|
||||
pub fn build_async(
|
||||
self,
|
||||
) -> Result<SupergraphFetcher<SupergraphFetcherAsyncState>, SupergraphFetcherError> {
|
||||
self.validate_endpoints()?;
|
||||
|
||||
let headers = self.prepare_headers()?;
|
||||
|
||||
let mut reqwest_agent = reqwest::Client::builder()
|
||||
.danger_accept_invalid_certs(self.accept_invalid_certs)
|
||||
.connect_timeout(self.connect_timeout)
|
||||
.timeout(self.request_timeout)
|
||||
.default_headers(headers);
|
||||
|
||||
if let Some(user_agent) = self.user_agent {
|
||||
reqwest_agent = reqwest_agent.user_agent(user_agent);
|
||||
}
|
||||
|
||||
let reqwest_agent = reqwest_agent
|
||||
.build()
|
||||
.map_err(SupergraphFetcherError::HTTPClientCreation)?;
|
||||
let reqwest_client = ClientBuilder::new(reqwest_agent)
|
||||
.with(RetryTransientMiddleware::new_with_policy(self.retry_policy))
|
||||
.build();
|
||||
|
||||
Ok(SupergraphFetcher {
|
||||
state: SupergraphFetcherAsyncState {
|
||||
reqwest_client,
|
||||
endpoints_with_circuit_breakers: self
|
||||
.endpoints
|
||||
.into_iter()
|
||||
.map(|endpoint| {
|
||||
let circuit_breaker = self
|
||||
.circuit_breaker
|
||||
.clone()
|
||||
.unwrap_or_default()
|
||||
.build_async()
|
||||
.map_err(SupergraphFetcherError::CircuitBreakerCreation);
|
||||
circuit_breaker.map(|cb| (endpoint, cb))
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?,
|
||||
},
|
||||
etag: RwLock::new(None),
|
||||
})
|
||||
}
|
||||
}
|
||||
135
packages/libraries/sdk-rs/src/supergraph_fetcher/builder.rs
Normal file
135
packages/libraries/sdk-rs/src/supergraph_fetcher/builder.rs
Normal file
|
|
@ -0,0 +1,135 @@
|
|||
use std::time::Duration;
|
||||
|
||||
use reqwest::header::{HeaderMap, HeaderValue};
|
||||
use retry_policies::policies::ExponentialBackoff;
|
||||
|
||||
use crate::{
|
||||
agent::usage_agent::non_empty_string, circuit_breaker::CircuitBreakerBuilder,
|
||||
supergraph_fetcher::SupergraphFetcherError,
|
||||
};
|
||||
|
||||
pub struct SupergraphFetcherBuilder {
|
||||
pub(crate) endpoints: Vec<String>,
|
||||
pub(crate) key: Option<String>,
|
||||
pub(crate) user_agent: Option<String>,
|
||||
pub(crate) connect_timeout: Duration,
|
||||
pub(crate) request_timeout: Duration,
|
||||
pub(crate) accept_invalid_certs: bool,
|
||||
pub(crate) retry_policy: ExponentialBackoff,
|
||||
pub(crate) circuit_breaker: Option<CircuitBreakerBuilder>,
|
||||
}
|
||||
|
||||
impl Default for SupergraphFetcherBuilder {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
endpoints: vec![],
|
||||
key: None,
|
||||
user_agent: None,
|
||||
connect_timeout: Duration::from_secs(5),
|
||||
request_timeout: Duration::from_secs(60),
|
||||
accept_invalid_certs: false,
|
||||
retry_policy: ExponentialBackoff::builder().build_with_max_retries(3),
|
||||
circuit_breaker: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl SupergraphFetcherBuilder {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// The CDN endpoint from Hive Console target.
|
||||
pub fn add_endpoint(mut self, endpoint: String) -> Self {
|
||||
if let Some(mut endpoint) = non_empty_string(Some(endpoint)) {
|
||||
if !endpoint.ends_with("/supergraph") {
|
||||
if endpoint.ends_with("/") {
|
||||
endpoint.push_str("supergraph");
|
||||
} else {
|
||||
endpoint.push_str("/supergraph");
|
||||
}
|
||||
}
|
||||
self.endpoints.push(endpoint);
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
/// The CDN Access Token with from the Hive Console target.
|
||||
pub fn key(mut self, key: String) -> Self {
|
||||
self.key = Some(key);
|
||||
self
|
||||
}
|
||||
|
||||
/// User-Agent header to be sent with each request
|
||||
pub fn user_agent(mut self, user_agent: String) -> Self {
|
||||
self.user_agent = Some(user_agent);
|
||||
self
|
||||
}
|
||||
|
||||
/// Connection timeout for the Hive Console CDN requests.
|
||||
/// Default: 5 seconds
|
||||
pub fn connect_timeout(mut self, timeout: Duration) -> Self {
|
||||
self.connect_timeout = timeout;
|
||||
self
|
||||
}
|
||||
|
||||
/// Request timeout for the Hive Console CDN requests.
|
||||
/// Default: 60 seconds
|
||||
pub fn request_timeout(mut self, timeout: Duration) -> Self {
|
||||
self.request_timeout = timeout;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn accept_invalid_certs(mut self, accept: bool) -> Self {
|
||||
self.accept_invalid_certs = accept;
|
||||
self
|
||||
}
|
||||
|
||||
/// Policy for retrying failed requests.
|
||||
///
|
||||
/// By default, an exponential backoff retry policy is used, with 10 attempts.
|
||||
pub fn retry_policy(mut self, retry_policy: ExponentialBackoff) -> Self {
|
||||
self.retry_policy = retry_policy;
|
||||
self
|
||||
}
|
||||
|
||||
/// Maximum number of retries for failed requests.
|
||||
///
|
||||
/// By default, an exponential backoff retry policy is used, with 10 attempts.
|
||||
pub fn max_retries(mut self, max_retries: u32) -> Self {
|
||||
self.retry_policy = ExponentialBackoff::builder().build_with_max_retries(max_retries);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn circuit_breaker(&mut self, builder: CircuitBreakerBuilder) -> &mut Self {
|
||||
self.circuit_breaker = Some(builder);
|
||||
self
|
||||
}
|
||||
|
||||
pub(crate) fn validate_endpoints(&self) -> Result<(), SupergraphFetcherError> {
|
||||
if self.endpoints.is_empty() {
|
||||
return Err(SupergraphFetcherError::MissingConfigurationOption(
|
||||
"endpoint".to_string(),
|
||||
));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn prepare_headers(&self) -> Result<HeaderMap, SupergraphFetcherError> {
|
||||
let key = match &self.key {
|
||||
Some(key) => key,
|
||||
None => {
|
||||
return Err(SupergraphFetcherError::MissingConfigurationOption(
|
||||
"key".to_string(),
|
||||
))
|
||||
}
|
||||
};
|
||||
let mut headers = HeaderMap::new();
|
||||
let mut cdn_key_header =
|
||||
HeaderValue::from_str(key).map_err(SupergraphFetcherError::InvalidKey)?;
|
||||
cdn_key_header.set_sensitive(true);
|
||||
headers.insert("X-Hive-CDN-Key", cdn_key_header);
|
||||
|
||||
Ok(headers)
|
||||
}
|
||||
}
|
||||
51
packages/libraries/sdk-rs/src/supergraph_fetcher/mod.rs
Normal file
51
packages/libraries/sdk-rs/src/supergraph_fetcher/mod.rs
Normal file
|
|
@ -0,0 +1,51 @@
|
|||
use tokio::sync::RwLock;
|
||||
use tokio::sync::TryLockError;
|
||||
|
||||
use crate::circuit_breaker::CircuitBreakerError;
|
||||
use crate::supergraph_fetcher::async_::SupergraphFetcherAsyncState;
|
||||
use reqwest::header::HeaderValue;
|
||||
use reqwest::header::InvalidHeaderValue;
|
||||
|
||||
pub mod async_;
|
||||
pub mod builder;
|
||||
pub mod sync;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct SupergraphFetcher<State> {
|
||||
state: State,
|
||||
etag: RwLock<Option<HeaderValue>>,
|
||||
}
|
||||
|
||||
// Doesn't matter which one we implement this for, both have the same builder
|
||||
impl SupergraphFetcher<SupergraphFetcherAsyncState> {
|
||||
pub fn builder() -> builder::SupergraphFetcherBuilder {
|
||||
builder::SupergraphFetcherBuilder::default()
|
||||
}
|
||||
}
|
||||
|
||||
pub enum LockErrorType {
|
||||
Read,
|
||||
Write,
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum SupergraphFetcherError {
|
||||
#[error("Creating HTTP Client failed: {0}")]
|
||||
HTTPClientCreation(reqwest::Error),
|
||||
#[error("Network error: {0}")]
|
||||
Network(reqwest_middleware::Error),
|
||||
#[error("Parsing response failed: {0}")]
|
||||
ResponseParse(reqwest::Error),
|
||||
#[error("Reading the etag record failed: {0:?}")]
|
||||
ETagRead(TryLockError),
|
||||
#[error("Updating the etag record failed: {0:?}")]
|
||||
ETagWrite(TryLockError),
|
||||
#[error("Invalid CDN key: {0}")]
|
||||
InvalidKey(InvalidHeaderValue),
|
||||
#[error("Missing configuration option: {0}")]
|
||||
MissingConfigurationOption(String),
|
||||
#[error("Request rejected by circuit breaker")]
|
||||
RejectedByCircuitBreaker,
|
||||
#[error("Creating circuit breaker failed: {0}")]
|
||||
CircuitBreakerCreation(CircuitBreakerError),
|
||||
}
|
||||
191
packages/libraries/sdk-rs/src/supergraph_fetcher/sync.rs
Normal file
191
packages/libraries/sdk-rs/src/supergraph_fetcher/sync.rs
Normal file
|
|
@ -0,0 +1,191 @@
|
|||
use std::time::SystemTime;
|
||||
|
||||
use recloser::Recloser;
|
||||
use reqwest::header::{HeaderValue, IF_NONE_MATCH};
|
||||
use reqwest_retry::{RetryDecision, RetryPolicy};
|
||||
use retry_policies::policies::ExponentialBackoff;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
use crate::supergraph_fetcher::{
|
||||
builder::SupergraphFetcherBuilder, SupergraphFetcher, SupergraphFetcherError,
|
||||
};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct SupergraphFetcherSyncState {
|
||||
endpoints_with_circuit_breakers: Vec<(String, Recloser)>,
|
||||
reqwest_client: reqwest::blocking::Client,
|
||||
retry_policy: ExponentialBackoff,
|
||||
}
|
||||
|
||||
impl SupergraphFetcher<SupergraphFetcherSyncState> {
|
||||
pub fn fetch_supergraph(&self) -> Result<Option<String>, SupergraphFetcherError> {
|
||||
let mut last_error: Option<SupergraphFetcherError> = None;
|
||||
let mut last_resp = None;
|
||||
for (endpoint, circuit_breaker) in &self.state.endpoints_with_circuit_breakers {
|
||||
let resp = {
|
||||
circuit_breaker
|
||||
.call(|| {
|
||||
let request_start_time = SystemTime::now();
|
||||
// Implementing retry logic for sync client
|
||||
let mut n_past_retries = 0;
|
||||
loop {
|
||||
let mut req = self.state.reqwest_client.get(endpoint);
|
||||
let etag = self.get_latest_etag()?;
|
||||
if let Some(etag) = etag {
|
||||
req = req.header(IF_NONE_MATCH, etag);
|
||||
}
|
||||
let mut response = req.send().map_err(|err| {
|
||||
SupergraphFetcherError::Network(reqwest_middleware::Error::Reqwest(
|
||||
err,
|
||||
))
|
||||
});
|
||||
|
||||
// Server errors (5xx) are considered retryable
|
||||
if let Ok(ok_res) = response {
|
||||
response = if ok_res.status().is_server_error() {
|
||||
Err(SupergraphFetcherError::Network(
|
||||
reqwest_middleware::Error::Middleware(anyhow::anyhow!(
|
||||
"Server error: {}",
|
||||
ok_res.status()
|
||||
)),
|
||||
))
|
||||
} else {
|
||||
Ok(ok_res)
|
||||
}
|
||||
}
|
||||
|
||||
match response {
|
||||
Ok(resp) => break Ok(resp),
|
||||
Err(e) => {
|
||||
match self
|
||||
.state
|
||||
.retry_policy
|
||||
.should_retry(request_start_time, n_past_retries)
|
||||
{
|
||||
RetryDecision::DoNotRetry => {
|
||||
return Err(e);
|
||||
}
|
||||
RetryDecision::Retry { execute_after } => {
|
||||
n_past_retries += 1;
|
||||
match execute_after.elapsed() {
|
||||
Ok(duration) => {
|
||||
std::thread::sleep(duration);
|
||||
}
|
||||
Err(err) => {
|
||||
tracing::error!(
|
||||
"Error determining sleep duration for retry: {}",
|
||||
err
|
||||
);
|
||||
// If elapsed time cannot be determined, do not wait
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
// Map recloser errors to SupergraphFetcherError
|
||||
.map_err(|e| match e {
|
||||
recloser::Error::Inner(e) => e,
|
||||
recloser::Error::Rejected => {
|
||||
SupergraphFetcherError::RejectedByCircuitBreaker
|
||||
}
|
||||
})
|
||||
};
|
||||
match resp {
|
||||
Err(e) => {
|
||||
last_error = Some(e);
|
||||
continue;
|
||||
}
|
||||
Ok(resp) => {
|
||||
last_resp = Some(resp);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(last_resp) = last_resp {
|
||||
if last_resp.status().as_u16() == 304 {
|
||||
return Ok(None);
|
||||
}
|
||||
self.update_latest_etag(last_resp.headers().get("etag"))?;
|
||||
let text = last_resp
|
||||
.text()
|
||||
.map_err(SupergraphFetcherError::ResponseParse)?;
|
||||
Ok(Some(text))
|
||||
} else if let Some(error) = last_error {
|
||||
Err(error)
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
fn get_latest_etag(&self) -> Result<Option<HeaderValue>, SupergraphFetcherError> {
|
||||
let guard = self
|
||||
.etag
|
||||
.try_read()
|
||||
.map_err(SupergraphFetcherError::ETagRead)?;
|
||||
|
||||
Ok(guard.clone())
|
||||
}
|
||||
fn update_latest_etag(&self, etag: Option<&HeaderValue>) -> Result<(), SupergraphFetcherError> {
|
||||
let mut guard = self
|
||||
.etag
|
||||
.try_write()
|
||||
.map_err(SupergraphFetcherError::ETagWrite)?;
|
||||
|
||||
if let Some(etag_value) = etag {
|
||||
*guard = Some(etag_value.clone());
|
||||
} else {
|
||||
*guard = None;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl SupergraphFetcherBuilder {
|
||||
/// Builds a synchronous SupergraphFetcher
|
||||
pub fn build_sync(
|
||||
self,
|
||||
) -> Result<SupergraphFetcher<SupergraphFetcherSyncState>, SupergraphFetcherError> {
|
||||
self.validate_endpoints()?;
|
||||
let headers = self.prepare_headers()?;
|
||||
|
||||
let mut reqwest_client = reqwest::blocking::Client::builder()
|
||||
.danger_accept_invalid_certs(self.accept_invalid_certs)
|
||||
.connect_timeout(self.connect_timeout)
|
||||
.timeout(self.request_timeout)
|
||||
.default_headers(headers);
|
||||
|
||||
if let Some(user_agent) = &self.user_agent {
|
||||
reqwest_client = reqwest_client.user_agent(user_agent);
|
||||
}
|
||||
|
||||
let reqwest_client = reqwest_client
|
||||
.build()
|
||||
.map_err(SupergraphFetcherError::HTTPClientCreation)?;
|
||||
let fetcher = SupergraphFetcher {
|
||||
state: SupergraphFetcherSyncState {
|
||||
reqwest_client,
|
||||
retry_policy: self.retry_policy,
|
||||
endpoints_with_circuit_breakers: self
|
||||
.endpoints
|
||||
.into_iter()
|
||||
.map(|endpoint| {
|
||||
let circuit_breaker = self
|
||||
.circuit_breaker
|
||||
.clone()
|
||||
.unwrap_or_default()
|
||||
.build_sync()
|
||||
.map_err(SupergraphFetcherError::CircuitBreakerCreation);
|
||||
circuit_breaker.map(|cb| (endpoint, cb))
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?,
|
||||
},
|
||||
etag: RwLock::new(None),
|
||||
};
|
||||
Ok(fetcher)
|
||||
}
|
||||
}
|
||||
Loading…
Reference in a new issue