diff --git a/Cargo.lock b/Cargo.lock index f190c74..d9e3096 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -400,6 +400,16 @@ dependencies = [ "winapi", ] +[[package]] +name = "country-emoji" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d93ed3c15fd433a3e8f7e52e70968113d4cd572d84a9454d1899f64c72872f02" +dependencies = [ + "lazy_static", + "regex", +] + [[package]] name = "cpufeatures" version = "0.2.16" @@ -1519,6 +1529,18 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "json_dotpath" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dbdcfef3cf5591f0cef62da413ae795e3d1f5a00936ccec0b2071499a32efd1a" +dependencies = [ + "serde", + "serde_derive", + "serde_json", + "thiserror 1.0.69", +] + [[package]] name = "lazy_static" version = "1.5.0" @@ -1905,6 +1927,7 @@ dependencies = [ "clap_complete", "colored", "console", + "country-emoji", "ctrlc", "derive-new", "dirs", @@ -1921,6 +1944,7 @@ dependencies = [ "indoc", "inquire", "is-terminal", + "json_dotpath", "names", "num_cpus", "open", @@ -1931,6 +1955,7 @@ dependencies = [ "serde", "serde_json", "serde_with", + "struct-field-names-as-array", "structstruck", "strum", "synchronized-writer", @@ -2474,6 +2499,26 @@ version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" +[[package]] +name = "struct-field-names-as-array" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22ba4bae771f9cc992c4f403636c54d2ef13acde6367583e99d06bb336674dd9" +dependencies = [ + "struct-field-names-as-array-derive", +] + +[[package]] +name = "struct-field-names-as-array-derive" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2dbf8b57f3ce20e4bb171a11822b283bdfab6c4bb0fe64fa729f045f23a0938" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.93", +] + [[package]] name = "structstruck" version = "0.4.1" diff --git a/Cargo.toml b/Cargo.toml index 59b2840..27267f3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -83,6 +83,9 @@ strum = { version = "0.26.3", features = ["derive"] } structstruck = "0.4.1" derive-new = "0.7.0" regex = "1.11.1" +struct-field-names-as-array = "0.3.0" +json_dotpath = "1.1.0" +country-emoji = "0.2.0" [profile.release] lto = "fat" diff --git a/src/commands/scale.rs b/src/commands/scale.rs index 1842d23..c22bf91 100644 --- a/src/commands/scale.rs +++ b/src/commands/scale.rs @@ -1,26 +1,329 @@ -use clap::{Arg, Command}; +use crate::{ + consts::TICK_STRING, + controllers::environment::get_matched_environment, + errors::RailwayError, + util::prompt::{ + prompt_select_with_cancel, prompt_text, + prompt_u64_with_placeholder_and_validation_and_cancel, + }, +}; +use anyhow::bail; +use clap::{Arg, Command, Parser}; +use country_emoji::flag; use futures::executor::block_on; -use serde_json::json; -use std::collections::HashMap; +use is_terminal::IsTerminal; +use json_dotpath::DotPaths as _; +use serde_json::{json, Map, Value}; +use std::{cmp::Ordering, collections::HashMap, fmt::Display, time::Duration}; +use struct_field_names_as_array::FieldNamesAsArray; use super::*; /// Dynamic flags workaround /// Unfortunately, we aren't able to use the Parser derive macro when working with dynamic flags, /// meaning we have to implement most of the traits for the Args struct manually. +struct DynamicArgs(HashMap); + +#[derive(Parser, FieldNamesAsArray)] pub struct Args { - // This field will collect any of the dynamically generated flags - pub dynamic: HashMap, + #[clap(flatten)] + dynamic: DynamicArgs, + + /// The service to scale (defaults to linked service) + #[clap(long, short)] + service: Option, + + /// The environment the service is in (defaults to linked environment) + #[clap(long, short)] + environment: Option, } pub async fn command(args: Args, _json: bool) -> Result<()> { - let mut configs = Configs::new()?; + let configs = Configs::new()?; let client = GQLClient::new_authorized(&configs)?; let linked_project = configs.get_linked_project().await?; + let project = post_graphql::( + &client, + configs.get_backboard(), + queries::project::Variables { + id: linked_project.project.clone(), + }, + ) + .await? + .project; + let environment = args + .environment + .clone() + .unwrap_or(linked_project.environment.clone()); + let (existing, latest_id) = get_existing_config(&args, &linked_project, project, environment)?; + let new_config = convert_hashmap_into_map( + if args.dynamic.0.is_empty() && std::io::stdout().is_terminal() { + prompt_for_regions(&configs, &client, &existing).await? + } else if args.dynamic.0.is_empty() { + bail!("Please specify regions via the flags when not running in a terminal") + } else { + args.dynamic.0 + }, + ); + if new_config.is_empty() { + println!("No changes made"); + return Ok(()); + } + let region_data = merge_config(existing, new_config); + handle_2fa(&configs, &client).await?; + update_regions_and_redeploy(configs, client, linked_project, latest_id, region_data).await?; - Ok(()) } +async fn prompt_for_regions( + configs: &Configs, + client: &reqwest::Client, + existing: &Value, +) -> Result> { + let mut updated: HashMap = HashMap::new(); + let mut regions = post_graphql::( + client, + configs.get_backboard(), + queries::regions::Variables, + ) + .await + .expect("couldn't get regions"); + loop { + let get_replicas_amount = |name: String| { + let before = if let Some(num) = existing.get(name.clone()) { + num.get("numReplicas").unwrap().as_u64().unwrap() // fine to unwrap, API only returns ones that have a replica + } else { + 0 + }; + let after = if let Some(new_value) = updated.get(&name) { + *new_value + } else { + before + }; + (before, after) + }; + regions.regions.sort_by(|a, b| { + get_replicas_amount(b.name.clone()) + .1 + .cmp(&get_replicas_amount(a.name.clone()).1) + }); + let regions = regions + .regions + .iter() + .map(|f| { + PromptRegion( + f.clone(), + format!( + "{} {}{}{}", + flag(&f.country).unwrap_or_default(), + f.location, + if f.railway_metal.unwrap_or_default() { + " (METAL)".bold().purple().to_string() + } else { + String::new() + }, + { + let (before, after) = get_replicas_amount(f.name.clone()); + let amount = format!( + " ({} replica{})", + after, + if after == 1 { "" } else { "s" } + ); + match after.cmp(&before) { + Ordering::Equal if after == 0 => String::new().normal(), + Ordering::Equal => amount.yellow(), + Ordering::Greater => amount.green(), + Ordering::Less => amount.red(), + } + .to_string() + } + ), + ) + }) + .collect::>(); + let p = prompt_select_with_cancel("Select a region ", regions)?; + if let Some(region) = p { + let amount_before = if let Some(updated) = updated.get(®ion.0.name) { + *updated + } else if let Some(previous) = existing.as_object().unwrap().get(®ion.0.name) { + previous.get("numReplicas").unwrap().as_u64().unwrap() + } else { + 0 + }; + let prompted = prompt_u64_with_placeholder_and_validation_and_cancel( + format!( + "Enter the amount of replicas for {} ", + region.0.name.clone() + ) + .as_str(), + amount_before.to_string().as_str(), + )?; + if let Some(prompted) = prompted { + let parse: u64 = prompted.parse()?; + updated.insert(region.0.name.clone(), parse); + } else { + // esc pressed when entering number, go back to selecting regions + continue; + } + } else { + // they pressed esc to cancel + break; + } + } + Ok(updated.clone()) +} + +async fn update_regions_and_redeploy( + configs: Configs, + client: reqwest::Client, + linked_project: LinkedProject, + latest_id: Option, + region_data: Value, +) -> Result<(), anyhow::Error> { + let spinner = indicatif::ProgressBar::new_spinner() + .with_style( + indicatif::ProgressStyle::default_spinner() + .tick_chars(TICK_STRING) + .template("{spinner:.green} {msg}")?, + ) + .with_message("Updating regions..."); + spinner.enable_steady_tick(Duration::from_millis(100)); + post_graphql::( + &client, + configs.get_backboard(), + mutations::update_regions::Variables { + environment_id: linked_project.environment, + service_id: linked_project.service.unwrap(), + multi_region_config: region_data, + }, + ) + .await?; + spinner.finish_with_message("Regions updated"); + if let Some(latest) = latest_id { + let spinner = indicatif::ProgressBar::new_spinner() + .with_style( + indicatif::ProgressStyle::default_spinner() + .tick_chars(TICK_STRING) + .template("{spinner:.green} {msg}")?, + ) + .with_message("Redeploying..."); + spinner.enable_steady_tick(Duration::from_millis(100)); + post_graphql::( + &client, + configs.get_backboard(), + mutations::deployment_redeploy::Variables { id: latest }, + ) + .await?; + spinner.finish_with_message("Redeployed"); + }; + Ok(()) +} + +fn merge_config(existing: Value, new_config: Map) -> Value { + let mut map = match existing { + Value::Object(object) => object, + _ => unreachable!(), // will always be a map + }; + map.extend(new_config); + Value::Object(map) +} + +async fn handle_2fa(configs: &Configs, client: &reqwest::Client) -> Result<(), anyhow::Error> { + let is_two_factor_enabled = { + let vars = queries::two_factor_info::Variables {}; + + let info = post_graphql::(client, configs.get_backboard(), vars) + .await? + .two_factor_info; + + info.is_verified + }; + if is_two_factor_enabled { + let token = prompt_text("Enter your 2FA code")?; + let vars = mutations::validate_two_factor::Variables { token }; + + let valid = + post_graphql::(client, configs.get_backboard(), vars) + .await? + .two_factor_info_validate; + + if !valid { + return Err(RailwayError::InvalidTwoFactorCode.into()); + } + }; + Ok(()) +} + +fn convert_hashmap_into_map(map: HashMap) -> Map { + let new_config = map.iter().fold(Map::new(), |mut map, (key, val)| { + map.insert( + key.clone(), + if *val == 0 { + Value::Null // this is how the dashboard does it + } else { + json!({ "numReplicas": val }) + }, + ); + map + }); + new_config +} + +fn get_existing_config( + args: &Args, + linked_project: &LinkedProject, + project: queries::project::ProjectProject, + environment: String, +) -> Result<(Value, Option), anyhow::Error> { + let environment_id = get_matched_environment(&project, environment)?.id; + let service_input: &String = args.service.as_ref().unwrap_or(linked_project.service.as_ref().expect("No service linked. Please either specify a service with the --service flag or link one with `railway service`")); + let mut id: Option = None; + let service_meta = if let Some(service) = project.services.edges.iter().find(|p| { + (p.node.id == *service_input) + || (p.node.name.to_lowercase() == service_input.to_lowercase()) + }) { + // check that service exists in that environment + if let Some(instance) = service + .node + .service_instances + .edges + .iter() + .find(|p| p.node.environment_id == environment_id) + { + if let Some(latest) = &instance.node.latest_deployment { + id = Some(latest.id.clone()); + if let Some(meta) = &latest.meta { + let deploy = meta + .dot_get::("serviceManifest.deploy")? + .expect("Very old deployment, please redeploy"); + if let Some(c) = deploy.dot_get::("multiRegionConfig")? { + Some(c) + } else if let Some(region) = deploy.dot_get::("region")? { + // old deployments only have numReplicas and a region field... + let mut map = Map::new(); + let replicas = deploy.dot_get::("numReplicas")?.unwrap_or(json!(1)); + map.insert(region.to_string(), json!({ "numReplicas": replicas })); + Some(json!({ + "multiRegionConfig": map + })) + } else { + None + } + } else { + None + } + } else { + None + } + } else { + bail!("Service not found in the environment") + } + } else { + None + }; + Ok((service_meta.unwrap_or(Value::Object(Map::new())), id)) +} + /// This function generates flags that are appended to the command at runtime. pub fn get_dynamic_args(cmd: Command) -> Command { if !std::env::args().any(|f| f.eq_ignore_ascii_case("scale")) { @@ -62,21 +365,21 @@ pub fn get_dynamic_args(cmd: Command) -> Command { }) } -impl clap::FromArgMatches for Args { +impl clap::FromArgMatches for DynamicArgs { fn from_arg_matches(matches: &clap::ArgMatches) -> Result { let mut dynamic = HashMap::new(); // Iterate through all provided argument keys. // Adjust the static key names if you add any to your Args struct. for key in matches.ids() { - if key == "json" { + if key == "json" || Args::FIELD_NAMES_AS_ARRAY.contains(&key.as_str()) { continue; } - // If the flag value can be interpreted as a u16, insert it. - if let Some(val) = matches.get_one::(key.as_str()) { + // If the flag value can be interpreted as a u64, insert it. + if let Some(val) = matches.get_one::(key.as_str()) { dynamic.insert(key.to_string(), *val); } } - Ok(Args { dynamic }) + Ok(DynamicArgs(dynamic)) } fn update_from_arg_matches(&mut self, matches: &clap::ArgMatches) -> Result<(), clap::Error> { @@ -85,41 +388,36 @@ impl clap::FromArgMatches for Args { } } -impl clap::Args for Args { +impl clap::Args for DynamicArgs { fn group_id() -> Option { - Some(clap::Id::from("Args")) + // Do not create an argument group for dynamic flags + None } - fn augment_args<'b>(__clap_app: clap::Command) -> clap::Command { - { - let __clap_app = __clap_app.group(clap::ArgGroup::new("Args").multiple(true).args({ - let members: [clap::Id; 0usize] = []; - members - })); - __clap_app - .about("Control the number of instances running in each region") - .long_about(None) - } + fn augment_args(cmd: clap::Command) -> clap::Command { + // Leave the command unchanged; dynamic flags will be handled via FromArgMatches + cmd } - fn augment_args_for_update<'b>(__clap_app: clap::Command) -> clap::Command { - { - let __clap_app = __clap_app.group(clap::ArgGroup::new("Args").multiple(true).args({ - let members: [clap::Id; 0usize] = []; - members - })); - __clap_app - .about("Control the number of instances running in each region") - .long_about(None) - } + fn augment_args_for_update(cmd: clap::Command) -> clap::Command { + cmd } } -impl clap::CommandFactory for Args { +impl clap::CommandFactory for DynamicArgs { fn command<'b>() -> clap::Command { let __clap_app = clap::Command::new("railwayapp"); - ::augment_args(__clap_app) + ::augment_args(__clap_app) } fn command_for_update<'b>() -> clap::Command { let __clap_app = clap::Command::new("railwayapp"); - ::augment_args_for_update(__clap_app) + ::augment_args_for_update(__clap_app) + } +} + +/// Formatting done manually +pub struct PromptRegion(pub queries::regions::RegionsRegions, pub String); + +impl Display for PromptRegion { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.1) } } diff --git a/src/gql/mutations/mod.rs b/src/gql/mutations/mod.rs index 5d1e929..1723473 100644 --- a/src/gql/mutations/mod.rs +++ b/src/gql/mutations/mod.rs @@ -167,8 +167,7 @@ pub struct CustomDomainCreate; #[graphql( schema_path = "src/gql/schema.json", query_path = "src/gql/mutations/strings/UpdateRegions.graphql", - response_derives = "Debug, Serialize, Clone", - skip_serializing_none + response_derives = "Debug, Serialize, Clone" )] pub struct UpdateRegions; diff --git a/src/gql/queries/mod.rs b/src/gql/queries/mod.rs index 028dbcc..ff9f61a 100644 --- a/src/gql/queries/mod.rs +++ b/src/gql/queries/mod.rs @@ -3,6 +3,8 @@ use serde::{Deserialize, Serialize}; type DateTime = chrono::DateTime; type EnvironmentVariables = std::collections::BTreeMap>; +//type DeploymentMeta = std::collections::BTreeMap; +type DeploymentMeta = serde_json::Value; #[derive(Serialize, Deserialize, Debug, Clone)] #[serde(rename_all = "camelCase")] diff --git a/src/gql/queries/strings/Project.graphql b/src/gql/queries/strings/Project.graphql index 29c36f2..a28bdd3 100644 --- a/src/gql/queries/strings/Project.graphql +++ b/src/gql/queries/strings/Project.graphql @@ -29,6 +29,7 @@ query Project($id: String!) { latestDeployment { canRedeploy id + meta } source { repo diff --git a/src/gql/queries/strings/Regions.graphql b/src/gql/queries/strings/Regions.graphql index 68ed003..dec5c12 100644 --- a/src/gql/queries/strings/Regions.graphql +++ b/src/gql/queries/strings/Regions.graphql @@ -1,5 +1,8 @@ query Regions { regions { name + country + railwayMetal + location } } \ No newline at end of file diff --git a/src/util/prompt.rs b/src/util/prompt.rs index ef33aa7..6d9e05f 100644 --- a/src/util/prompt.rs +++ b/src/util/prompt.rs @@ -1,4 +1,5 @@ use colored::*; +use inquire::validator::{Validation, ValueRequiredValidator}; use std::fmt::Display; use crate::commands::{queries::project::ProjectProjectServicesEdgesNode, Configs}; @@ -28,6 +29,27 @@ pub fn prompt_text(message: &str) -> Result { .context("Failed to prompt for options") } +pub fn prompt_u64_with_placeholder_and_validation_and_cancel( + message: &str, + placeholder: &str, +) -> Result> { + let validator = |input: &str| { + if input.parse::().is_ok() { + Ok(Validation::Valid) + } else { + Ok(Validation::Invalid("Not a valid number".into())) + } + }; + let select = inquire::Text::new(message); + select + .with_render_config(Configs::get_render_config()) + .with_placeholder(placeholder) + .with_validator(ValueRequiredValidator::new("Input most not be empty")) + .with_validator(validator) + .prompt_skippable() + .context("Failed to prompt for options") +} + pub fn prompt_text_with_placeholder_if_blank( message: &str, placeholder: &str, @@ -105,6 +127,13 @@ pub fn prompt_select(message: &str, options: Vec) -> Result { .context("Failed to prompt for select") } +pub fn prompt_select_with_cancel(message: &str, options: Vec) -> Result> { + inquire::Select::new(message, options) + .with_render_config(Configs::get_render_config()) + .prompt_skippable() + .context("Failed to prompt for select") +} + pub fn fake_select(message: &str, selected: &str) { println!("{} {} {}", ">".green(), message, selected.cyan().bold()); }