cli/src/client.rs
Brandon 991d8f0b15
Fixing issue #830 (#837)
* Fixing issue #830

* fix: fetch template metadata without auth headers

Railway template lookup endpoints reject authenticated requests, so use a public GraphQL client for template detail and template search requests. Add a regression test so add --database keeps working for logged-in users.

* chore: drop unrelated auth changes

Keep the postgres auth fix focused by removing the stray npm lockfile and the unrelated API token handling change from this PR.
2026-04-15 21:40:07 -04:00

316 lines
10 KiB
Rust

use std::time::Duration;
use graphql_client::GraphQLQuery;
use reqwest::{
Client,
header::{HeaderMap, HeaderValue},
};
use crate::{
commands::Environment,
config::Configs,
consts::{self, RAILWAY_API_TOKEN_ENV, RAILWAY_TOKEN_ENV},
errors::RailwayError,
oauth,
};
use anyhow::Result;
use graphql_client::Response as GraphQLResponse;
/// Namespace for the constructors that build the `reqwest` clients used for
/// GraphQL requests.
pub struct GQLClient;

impl GQLClient {
    /// Builds a client that carries no credentials.
    ///
    /// Only the `x-source` header is attached, so requests made with this
    /// client never include `authorization` or `project-access-token`.
    ///
    /// # Errors
    ///
    /// Returns an error if the underlying HTTP client cannot be constructed.
    pub fn new_public() -> Result<Client, RailwayError> {
        Self::build_client(HeaderMap::new())
    }

    /// Builds a client that authenticates every request.
    ///
    /// The token from `Configs::get_railway_token()` takes precedence and is
    /// sent as `project-access-token`; otherwise the token from
    /// `configs.get_railway_auth_token()` is sent as a `Bearer` authorization
    /// header.
    ///
    /// # Errors
    ///
    /// Returns [`RailwayError::Unauthorized`] when neither token is available,
    /// and propagates an error when a token is not a valid header value or the
    /// underlying HTTP client cannot be constructed.
    pub fn new_authorized(configs: &Configs) -> Result<Client, RailwayError> {
        let mut headers = HeaderMap::new();
        if let Some(token) = &Configs::get_railway_token() {
            headers.insert("project-access-token", HeaderValue::from_str(token)?);
        } else if let Some(token) = configs.get_railway_auth_token() {
            headers.insert(
                "authorization",
                HeaderValue::from_str(&format!("Bearer {token}"))?,
            );
        } else {
            return Err(RailwayError::Unauthorized);
        }
        Self::build_client(headers)
    }

    /// Finishes client construction shared by both constructors: adds the
    /// `x-source` header, sets the user agent and a 30 second timeout, and
    /// accepts invalid TLS certificates only in the `Dev` environment.
    fn build_client(mut headers: HeaderMap) -> Result<Client, RailwayError> {
        headers.insert(
            "x-source",
            HeaderValue::from_static(consts::get_user_agent()),
        );
        let accept_invalid_certs = matches!(Configs::get_environment_id(), Environment::Dev);
        // Propagate build failures instead of panicking: both callers already
        // return `Result`, so there is no reason to abort the process here.
        let client = Client::builder()
            .danger_accept_invalid_certs(accept_invalid_certs)
            .user_agent(consts::get_user_agent())
            .default_headers(headers)
            .timeout(Duration::from_secs(30))
            .build()?;
        Ok(client)
    }
}
pub async fn post_graphql<Q: GraphQLQuery, U: reqwest::IntoUrl>(
client: &reqwest::Client,
url: U,
variables: Q::Variables,
) -> Result<Q::ResponseData, RailwayError> {
let body = Q::build_query(variables);
let response = client.post(url).json(&body).send().await?;
if response.status() == 429 {
return Err(RailwayError::Ratelimited);
}
let res: GraphQLResponse<Q::ResponseData> = response.json().await?;
if let Some(errors) = res.errors {
let error = &errors[0];
if error
.message
.to_lowercase()
.contains("project token not found")
{
Err(RailwayError::InvalidRailwayToken(
RAILWAY_TOKEN_ENV.to_string(),
))
} else if error.message.to_lowercase().contains("not authorized") {
// Handle unauthorized errors in a custom way
Err(auth_failure_error())
} else if error.message == "Two Factor Authentication Required" {
// Extract workspace name from extensions if available
let workspace_name = error
.extensions
.as_ref()
.and_then(|ext| ext.get("workspaceName"))
.and_then(|v| v.as_str())
.unwrap_or("this workspace")
.to_string();
let security_url = get_security_url();
Err(RailwayError::TwoFactorEnforcementRequired(
workspace_name,
security_url,
))
} else {
Err(RailwayError::GraphQLError(error.message.clone()))
}
} else if let Some(data) = res.data {
Ok(data)
} else {
Err(RailwayError::MissingResponseData)
}
}
/// Returns the account security page URL on the host that matches the
/// environment reported by `Configs::get_environment_id()`.
fn get_security_url() -> String {
    let environment = Configs::get_environment_id();
    let domain = match environment {
        Environment::Dev => "railway-develop.com",
        Environment::Staging => "railway-staging.com",
        Environment::Production => "railway.com",
    };
    format!("https://{}/account/security", domain)
}
/// Picks the error that best describes why a request was rejected as
/// unauthorized, based on which credential source is present.
///
/// Sources are checked in this order: the token from
/// `Configs::get_railway_token()`, the token from
/// `Configs::get_railway_api_token()`, then an auth token in the loaded
/// configuration. When none is present (or the configuration cannot be
/// loaded) the generic [`RailwayError::Unauthorized`] is returned.
pub(crate) fn auth_failure_error() -> RailwayError {
    if Configs::get_railway_token().is_some() {
        return RailwayError::UnauthorizedToken(RAILWAY_TOKEN_ENV.to_string());
    }
    if Configs::get_railway_api_token().is_some() {
        return RailwayError::UnauthorizedToken(RAILWAY_API_TOKEN_ENV.to_string());
    }

    // A configuration that fails to load is treated the same as having no
    // stored login.
    let has_stored_login = match Configs::new() {
        Ok(configs) => configs.get_railway_auth_token().is_some(),
        Err(_) => false,
    };

    if has_stored_login {
        RailwayError::UnauthorizedLogin
    } else {
        RailwayError::Unauthorized
    }
}
/// Refreshes the stored OAuth access token when it has expired.
///
/// Nothing happens when a token is supplied through the environment, when no
/// OAuth token is stored, or when the stored token has not expired yet.
/// Otherwise the refresh token is exchanged via `oauth::refresh_access_token`
/// and the returned tokens are saved back into `configs`.
///
/// # Errors
///
/// Fails with [`RailwayError::OAuthRefreshFailed`] when no refresh token is
/// available, and propagates any error from the refresh request or from
/// saving the new tokens.
pub async fn ensure_valid_token(configs: &mut Configs) -> Result<()> {
    // Tokens that come from environment variables are outside our control.
    let uses_env_token =
        Configs::get_railway_token().is_some() || Configs::get_railway_api_token().is_some();
    if uses_env_token {
        return Ok(());
    }

    let needs_refresh = configs.has_oauth_token() && configs.is_token_expired();
    if !needs_refresh {
        return Ok(());
    }

    let refresh_token = match configs.get_refresh_token() {
        Some(token) => token,
        None => {
            return Err(RailwayError::OAuthRefreshFailed(
                "No refresh token available".to_string(),
            )
            .into());
        }
    };
    let host = configs.get_host();
    let refreshed = oauth::refresh_access_token(host, refresh_token).await?;

    configs.save_oauth_tokens(
        &refreshed.access_token,
        refreshed.refresh_token.as_deref(),
        refreshed.expires_in,
    )?;
    Ok(())
}
/// Like post_graphql, but removes null values from the variables object before sending.
///
/// This is needed because graphql-client 0.14.0 has a bug where skip_serializing_none
/// doesn't work for root-level variables (only nested ones). This causes None values
/// to be serialized as null, which tells the Railway API to unset fields.
///
/// By stripping nulls from the JSON, we ensure the API receives undefined instead,
/// which preserves existing values (e.g., cron schedules on function updates).
pub async fn post_graphql_skip_none<Q: GraphQLQuery, U: reqwest::IntoUrl>(
client: &reqwest::Client,
url: U,
variables: Q::Variables,
) -> Result<Q::ResponseData, RailwayError> {
let body = Q::build_query(variables);
let mut body_json =
serde_json::to_value(&body).expect("Failed to serialize GraphQL query body");
if let Some(obj) = body_json.as_object_mut() {
if let Some(vars) = obj.get_mut("variables").and_then(|v| v.as_object_mut()) {
vars.retain(|_, v| !v.is_null());
}
}
let response = client.post(url).json(&body_json).send().await?;
if response.status() == 429 {
return Err(RailwayError::Ratelimited);
}
let res: GraphQLResponse<Q::ResponseData> = response.json().await?;
if let Some(errors) = res.errors {
let error = &errors[0];
if error
.message
.to_lowercase()
.contains("project token not found")
{
Err(RailwayError::InvalidRailwayToken(
RAILWAY_TOKEN_ENV.to_string(),
))
} else if error.message.to_lowercase().contains("not authorized") {
Err(auth_failure_error())
} else if error.message == "Two Factor Authentication Required" {
// Extract workspace name from extensions if available
let workspace_name = error
.extensions
.as_ref()
.and_then(|ext| ext.get("workspaceName"))
.and_then(|v| v.as_str())
.unwrap_or("this workspace")
.to_string();
let security_url = get_security_url();
Err(RailwayError::TwoFactorEnforcementRequired(
workspace_name,
security_url,
))
} else {
Err(RailwayError::GraphQLError(error.message.clone()))
}
} else if let Some(data) = res.data {
Ok(data)
} else {
Err(RailwayError::MissingResponseData)
}
}
#[cfg(test)]
mod tests {
    use std::{
        io::{BufRead, BufReader, Read, Write},
        net::TcpListener,
        thread,
    };

    use super::*;
    use crate::gql::queries;

    /// Starts a one-shot HTTP server on a free local port and returns its base URL.
    ///
    /// The server accepts a single connection, reads the request head and body
    /// into one string, passes it to `response_for_request`, and replies with
    /// the returned string as a `200 OK` JSON body.
    fn spawn_graphql_server(
        response_for_request: impl FnOnce(String) -> String + Send + 'static,
    ) -> String {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();
        thread::spawn(move || {
            let (mut stream, _) = listener.accept().unwrap();
            let mut reader = BufReader::new(stream.try_clone().unwrap());
            let mut request = String::new();
            let mut content_length = 0usize;
            loop {
                let mut line = String::new();
                let bytes_read = reader.read_line(&mut line).unwrap();
                // `read_line` returns 0 at EOF; without this check a client
                // that disconnects early would make the loop spin forever.
                assert!(
                    bytes_read != 0,
                    "client closed the connection before finishing the request headers"
                );
                request.push_str(&line);
                if line == "\r\n" {
                    break;
                }
                // Header names are case-insensitive and clients may send them
                // lowercase, so compare the name without regard to case.
                if let Some((name, value)) = line.split_once(':') {
                    if name.trim().eq_ignore_ascii_case("content-length") {
                        content_length = value.trim().parse().unwrap();
                    }
                }
            }
            let mut body = vec![0; content_length];
            reader.read_exact(&mut body).unwrap();
            request.push_str(std::str::from_utf8(&body).unwrap());
            let response_body = response_for_request(request);
            let response = format!(
                "HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{}",
                response_body.len(),
                response_body
            );
            stream.write_all(response.as_bytes()).unwrap();
        });
        format!("http://{addr}")
    }

    /// Regression test: the public client must be able to run a template
    /// lookup without sending an `authorization` header.
    #[tokio::test]
    async fn public_client_can_query_templates_without_auth_headers() {
        let server_url = spawn_graphql_server(|request| {
            assert!(
                !request.to_ascii_lowercase().contains("authorization:"),
                "public template lookup should not send auth headers"
            );
            serde_json::json!({
                "data": {
                    "template": {
                        "id": "template-id",
                        "name": "PostgreSQL",
                        "serializedConfig": null
                    }
                }
            })
            .to_string()
        });
        let client = GQLClient::new_public().unwrap();
        let response = post_graphql::<queries::TemplateDetail, _>(
            &client,
            server_url,
            queries::template_detail::Variables {
                code: "postgres".to_string(),
            },
        )
        .await
        .unwrap();
        assert_eq!(response.template.id, "template-id");
        assert_eq!(response.template.name, "PostgreSQL");
        assert_eq!(response.template.serialized_config, None);
    }
}