fix(apollo-router-hive-plugin): update apollo-router to v2 (#6549)

Co-authored-by: Yann Simon <yann.simon.fr@gmail.com>
This commit is contained in:
Dotan Simha 2025-02-24 12:26:38 +02:00 committed by GitHub
parent eb903c0775
commit 158b63b4f2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 955 additions and 604 deletions

View file

@ -0,0 +1,5 @@
---
'hive-apollo-router-plugin': major
---
Updated core dependnecies (body, http) to match apollo-router v2

View file

@ -0,0 +1,5 @@
---
'hive-apollo-router-plugin': patch
---
Updated thiserror, jsonschema, lru, rand to latest and adjust the code

1423
configs/cargo/Cargo.lock generated

File diff suppressed because it is too large Load diff

View file

@ -17,38 +17,40 @@ name = "hive_apollo_router_plugin"
path = "src/lib.rs"
[dependencies]
apollo-router = { version = "^1.13.0" }
thiserror = "1.0.62"
apollo-router = { version = "^2.0.0" }
axum-core = "0.5"
thiserror = "2.0.11"
reqwest = { version = "0.12.0", default-features = false, features = [
"rustls-tls",
"blocking",
"json",
"rustls-tls",
"blocking",
"json",
] }
reqwest-retry = "0.7.0"
reqwest-middleware = "0.4.0"
sha2 = { version = "0.10.8", features = ["std"] }
anyhow = "1"
tracing = "0.1"
hyper = { version = "0.14.28", features = ["server", "client", "stream"] }
hyper = { version = "1", features = ["server", "client"] }
async-trait = "0.1.77"
futures = { version = "0.3.30", features = ["thread-pool"] }
schemars = { version = "0.8", features = ["url"] }
serde = "1"
serde_json = "1"
tokio = { version = "1.36.0", features = ["full"] }
tower = { version = "0.4.13", features = ["full"] }
http = "0.2"
tower = { version = "0.5", features = ["full"] }
http = "1"
http-body-util = "0.1"
graphql-parser = { version = "0.5.0", package = "graphql-parser-hive-fork" }
graphql-tools = { version = "0.4.0", features = [
"graphql_parser_fork",
"graphql_parser_fork",
], default-features = false }
lru = "^0.12.1"
lru = "^0.13.0"
md5 = "0.7.0"
rand = "0.8.5"
rand = "0.9.0"
[dev-dependencies]
httpmock = "0.7.0"
jsonschema = { version = "0.26.1", default-features = false, features = [
"resolve-file",
jsonschema = { version = "0.29.0", default-features = false, features = [
"resolve-file",
] }
lazy_static = "1.5.0"

View file

@ -258,7 +258,7 @@ impl<'a> OperationVisitor<'a, SchemaCoordinatesContext> for SchemaCoordinatesVis
Value::List(_) => {
// handled by enter_list_value
}
Value::Object(a) => {
Value::Object(_a) => {
// handled by enter_object_field
}
Value::Variable(_) => {
@ -278,7 +278,7 @@ impl<'a> OperationVisitor<'a, SchemaCoordinatesContext> for SchemaCoordinatesVis
&mut self,
info: &mut OperationVisitorContext<'a>,
ctx: &mut SchemaCoordinatesContext,
object_field: &(String, graphql_tools::static_graphql::query::Value),
_object_field: &(String, graphql_tools::static_graphql::query::Value),
) {
if let Some(input_type) = info.current_input_type() {
match input_type {

View file

@ -7,12 +7,18 @@ use apollo_router::services::router;
use apollo_router::services::router::Body;
use apollo_router::Context;
use core::ops::Drop;
use std::env;
use futures::FutureExt;
use http::StatusCode;
use http_body_util::combinators::UnsyncBoxBody;
use http_body_util::BodyExt;
use http_body_util::Full;
use hyper::body::Bytes;
use lru::LruCache;
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::env;
use std::num::NonZeroUsize;
use std::ops::ControlFlow;
use std::sync::Arc;
@ -20,8 +26,6 @@ use std::time::Duration;
use tokio::sync::Mutex;
use tower::{BoxError, ServiceBuilder, ServiceExt};
use tracing::{debug, info, warn};
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
use reqwest_retry::{RetryTransientMiddleware, policies::ExponentialBackoff};
pub struct PersistedDocumentsPlugin {
persisted_documents_manager: Arc<PersistedDocumentsManager>,
@ -103,14 +107,15 @@ impl Plugin for PersistedDocumentsPlugin {
if enabled {
ServiceBuilder::new()
.oneshot_checkpoint_async(move |req: router::Request| {
.checkpoint_async(move |req: router::Request| {
let mgr = mgr_ref.clone();
async move {
let (parts, body) = req.router_request.into_parts();
let bytes: hyper::body::Bytes = hyper::body::to_bytes(body)
let bytes: hyper::body::Bytes = body
.collect()
.await
.map_err(PersistedDocumentsError::FailedToReadBody)?;
.map_err(PersistedDocumentsError::FailedToReadBody)?
.to_bytes();
let payload = PersistedDocumentsManager::extract_document_id(&bytes);
@ -124,7 +129,10 @@ impl Plugin for PersistedDocumentsPlugin {
if payload.original_req.query.is_some() {
if allow_arbitrary_documents {
let roll_req: router::Request = (
http::Request::<Body>::from_parts(parts, bytes.into()),
http::Request::<Body>::from_parts(
parts,
body_from_bytes(bytes),
),
req.context,
)
.into();
@ -163,14 +171,14 @@ impl Plugin for PersistedDocumentsPlugin {
{
warn!("failed to extend router context with persisted document hash key");
}
payload.original_req.query = Some(document);
let mut bytes: Vec<u8> = Vec::new();
serde_json::to_writer(&mut bytes, &payload).unwrap();
let roll_req: router::Request = (
http::Request::<Body>::from_parts(parts, bytes.into()),
http::Request::<Body>::from_parts(parts, body_from_bytes(bytes)),
req.context,
)
.into();
@ -187,6 +195,7 @@ impl Plugin for PersistedDocumentsPlugin {
}
.boxed()
})
.buffered()
.service(service)
.boxed()
} else {
@ -195,6 +204,12 @@ impl Plugin for PersistedDocumentsPlugin {
}
}
fn body_from_bytes<T: Into<Bytes>>(chunk: T) -> UnsyncBoxBody<Bytes, axum_core::Error> {
Full::new(chunk.into())
.map_err(|never| match never {})
.boxed_unsync()
}
impl Drop for PersistedDocumentsPlugin {
fn drop(&mut self) {
debug!("PersistedDocumentsPlugin has been dropped!");
@ -211,7 +226,7 @@ struct PersistedDocumentsManager {
#[derive(Debug, thiserror::Error)]
pub enum PersistedDocumentsError {
#[error("Failed to read body: {0}")]
FailedToReadBody(hyper::Error),
FailedToReadBody(axum_core::Error),
#[error("Failed to parse body: {0}")]
FailedToParseBody(serde_json::Error),
#[error("Persisted document not found.")]
@ -238,7 +253,9 @@ impl PersistedDocumentsError {
PersistedDocumentsError::DocumentNotFound => "PERSISTED_DOCUMENT_NOT_FOUND".into(),
PersistedDocumentsError::KeyNotFound => "PERSISTED_DOCUMENT_KEY_NOT_FOUND".into(),
PersistedDocumentsError::FailedToFetchFromCDN(_) => "FAILED_TO_FETCH_FROM_CDN".into(),
PersistedDocumentsError::FailedToReadCDNResponse(_) => "FAILED_TO_READ_CDN_RESPONSE".into(),
PersistedDocumentsError::FailedToReadCDNResponse(_) => {
"FAILED_TO_READ_CDN_RESPONSE".into()
}
PersistedDocumentsError::PersistedDocumentRequired => {
"PERSISTED_DOCUMENT_REQUIRED".into()
}
@ -262,13 +279,17 @@ impl PersistedDocumentsError {
impl PersistedDocumentsManager {
fn new(config: &Config) -> Self {
let retry_policy = ExponentialBackoff::builder().build_with_max_retries(config.retry_count.unwrap_or(3));
let retry_policy =
ExponentialBackoff::builder().build_with_max_retries(config.retry_count.unwrap_or(3));
let reqwest_agent = reqwest::Client::builder()
.danger_accept_invalid_certs(config.accept_invalid_certs.unwrap_or(false))
.connect_timeout(Duration::from_secs(config.connect_timeout.unwrap_or(5)))
.timeout(Duration::from_secs(config.request_timeout.unwrap_or(15))).build().expect("Failed to create reqwest client");
let agent = ClientBuilder::new(reqwest_agent).with(RetryTransientMiddleware::new_with_policy(retry_policy))
.danger_accept_invalid_certs(config.accept_invalid_certs.unwrap_or(false))
.connect_timeout(Duration::from_secs(config.connect_timeout.unwrap_or(5)))
.timeout(Duration::from_secs(config.request_timeout.unwrap_or(15)))
.build()
.expect("Failed to create reqwest client");
let agent = ClientBuilder::new(reqwest_agent)
.with(RetryTransientMiddleware::new_with_policy(retry_policy))
.build();
let cache_size = config.cache_size.unwrap_or(1000);
@ -285,7 +306,7 @@ impl PersistedDocumentsManager {
/// Extracts the document id from the request body.
/// In case of a parsing error, it returns the error.
/// This will also try to parse other GraphQL-related (see `original_req`) fields in order to
/// This will also try to parse other GraphQL-related (see `original_req`) fields in order to
/// pass it to the next layer.
fn extract_document_id(
body: &hyper::body::Bytes,
@ -310,7 +331,11 @@ impl PersistedDocumentsManager {
document_id
);
let cdn_document_id = str::replace(document_id, "~", "/");
let cdn_artifact_url = format!("{}/apps/{}", self.config.endpoint.as_ref().unwrap(), cdn_document_id);
let cdn_artifact_url = format!(
"{}/apps/{}",
self.config.endpoint.as_ref().unwrap(),
cdn_document_id
);
info!(
"Fetching document {} from CDN: {}",
document_id, cdn_artifact_url
@ -325,7 +350,10 @@ impl PersistedDocumentsManager {
match cdn_response {
Ok(response) => {
if response.status().is_success() {
let document = response.text().await.map_err(PersistedDocumentsError::FailedToReadCDNResponse)?;
let document = response
.text()
.await
.map_err(PersistedDocumentsError::FailedToReadCDNResponse)?;
debug!(
"Document fetched from CDN: {}, storing in local cache",
document
@ -341,7 +369,10 @@ impl PersistedDocumentsManager {
warn!(
"Document fetch from CDN failed: HTTP {}, Body: {:?}",
response.status(),
response.text().await.unwrap_or_else(|_| "Unavailable".to_string())
response
.text()
.await
.unwrap_or_else(|_| "Unavailable".to_string())
);
return Err(PersistedDocumentsError::DocumentNotFound);
@ -382,6 +413,7 @@ mod hive_persisted_documents_tests {
use futures::executor::block_on;
use http::Method;
use httpmock::{Method::GET, Mock, MockServer};
use hyper::body::Bytes;
use serde_json::json;
use super::*;
@ -395,7 +427,7 @@ mod hive_persisted_documents_tests {
router::Request::fake_builder()
.method(Method::POST)
.body(Body::from(serde_json::to_string(&r).unwrap()))
.body(serde_json::to_string(&r).unwrap())
.header("content-type", "application/json")
.build()
.unwrap()
@ -432,7 +464,6 @@ mod hive_persisted_documents_tests {
.unwrap()
}
struct PersistedDocumentsCDNMock {
server: MockServer,
}
@ -465,13 +496,13 @@ mod hive_persisted_documents_tests {
async fn get_body(router_req: router::Request) -> String {
let (_parts, body) = router_req.router_request.into_parts();
let body = hyper::body::to_bytes(body).await.unwrap();
let body = body.collect().await.unwrap().to_bytes();
String::from_utf8(body.to_vec()).unwrap()
}
/// Creates a mocked router service that reflects the incoming body
/// back to the client.
/// We are using this mocked router in order to make sure that the Persisted Documents layer
/// We are using this mocked router in order to make sure that the Persisted Documents layer
/// is able to resolve, fetch and pass the document to the next layer.
fn create_reflecting_mocked_router() -> MockRouterService {
let mut mocked_execution: MockRouterService = MockRouterService::new();
@ -678,7 +709,7 @@ mod hive_persisted_documents_tests {
let service_stack = PersistedDocumentsPlugin::new(Config {
enabled: Some(true),
endpoint: Some(cdn_mock.endpoint()),
key: Some("123".into()),
key: Some("123".into()),
allow_arbitrary_documents: Some(false),
..Default::default()
})
@ -709,14 +740,14 @@ mod hive_persisted_documents_tests {
let service_stack = PersistedDocumentsPlugin::new(Config {
enabled: Some(true),
endpoint: Some(cdn_mock.endpoint()),
key: Some("123".into()),
key: Some("123".into()),
allow_arbitrary_documents: Some(false),
..Default::default()
})
.router_service(upstream.boxed());
let request = create_persisted_request(
"my-app~cacb95c69ba4684aec972777a38cd106740c6453~04bfa72dfb83b297dd8a5b6fed9bafac2b395a0f",
"my-app~cacb95c69ba4684aec972777a38cd106740c6453~04bfa72dfb83b297dd8a5b6fed9bafac2b395a0f",
Some(json!({"var": "value"}))
);
let mut response = service_stack.oneshot(request).await.unwrap();
@ -737,7 +768,8 @@ mod hive_persisted_documents_tests {
let p = PersistedDocumentsPlugin::new(Config {
enabled: Some(true),
endpoint: Some(cdn_mock.endpoint()),
key: Some("123".into()), allow_arbitrary_documents: Some(false),
key: Some("123".into()),
allow_arbitrary_documents: Some(false),
..Default::default()
});
let s1 = p.router_service(create_dummy_mocked_router().boxed());

View file

@ -136,8 +136,8 @@ impl UsagePlugin {
.into_iter()
.collect();
let mut rng = rand::thread_rng();
let sampled = rng.gen::<f64>() < config.sample_rate;
let mut rng = rand::rng();
let sampled = rng.random::<f64>() < config.sample_rate;
let mut dropped = !sampled;
if !dropped {