diff --git a/Cargo.lock b/Cargo.lock index b111a1e6..81c30e37 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -228,6 +228,7 @@ dependencies = [ "futures-util", "jiff", "reqwest", + "serde", "tokio", "tracing", "tracing-subscriber", diff --git a/crates/cli/Cargo.toml b/crates/cli/Cargo.toml index 90dac43d..c8b0e3a8 100644 --- a/crates/cli/Cargo.toml +++ b/crates/cli/Cargo.toml @@ -78,3 +78,9 @@ workspace = true [dependencies.tracing-subscriber] workspace = true features = ["env-filter"] + +# The API uses 'serde' for transforming high-level types to and from the +# underlying wire format. The CLI needs to deserialize them. +[dependencies.serde] +workspace = true +features = ["derive"] diff --git a/crates/cli/src/client.rs b/crates/cli/src/client.rs index 2c2937d6..fc70956b 100644 --- a/crates/cli/src/client.rs +++ b/crates/cli/src/client.rs @@ -1,6 +1,9 @@ +use std::error::Error; use std::time::Duration; use reqwest::{IntoUrl, Method, RequestBuilder}; +use serde::Serialize; +use serde::de::DeserializeOwned; use tracing::debug; use url::Url; @@ -33,24 +36,115 @@ impl CascadeApiClient { client.request(method, path) } + #[allow(dead_code)] pub fn get(&self, s: &str) -> RequestBuilder { self.request(Method::GET, s) } + #[allow(dead_code)] pub fn post(&self, s: &str) -> RequestBuilder { self.request(Method::POST, s) } + + #[allow(dead_code)] + pub async fn get_json_with(&self, s: &str, payload: &P) -> Result + where + T: DeserializeOwned, + P: Serialize, + { + send_format_decode(self.request(Method::GET, s).json(payload)).await + } + + pub async fn post_json_with(&self, s: &str, payload: &P) -> Result + where + T: DeserializeOwned, + P: Serialize, + { + send_format_decode(self.request(Method::POST, s).json(payload)).await + } + + pub async fn get_json(&self, s: &str) -> Result + where + T: DeserializeOwned, + { + send_format_decode(self.request(Method::GET, s)).await + } + + pub async fn post_json(&self, s: &str) -> Result + where + T: DeserializeOwned, + { + send_format_decode(self.request(Method::POST, s)).await + } } +pub async fn send_format_decode(req: RequestBuilder) -> Result +where + T: DeserializeOwned, +{ + req.send() + .await + .map_err(format_http_error)? // Format connection errors + .error_for_status() + .map_err(format_http_error)? // Format status code errors + .json() + .await + .map_err(format_http_error) // Format decoding errors +} + +/// Format HTTP errors with message based on error type, and chain error +/// descriptions together instead of simply printing the Debug representation +/// (which is confusing for users). pub fn format_http_error(err: reqwest::Error) -> String { - if err.is_decode() { - // Use the debug representation of decoding errors otherwise the cause - // of the decoding failure, e.g. the underlying Serde error, gets lost - // and makes determining why the response couldn't be decoded a game - // of divide and conquer removing response fields one by one until the - // offending field is determined. - format!("HTTP request failed: {err:?}") + let mut message = String::new(); + + // Returning a shortened timed out message to not have a redundant text + // like: "... HTTP connection timed out: operation timed out" + if err.is_timeout() { + // "Returns true if the error is related to a timeout." [1] + return String::from("HTTP connection timed out"); + } + + // [1]: https://docs.rs/reqwest/latest/reqwest/struct.Error.html + if err.is_connect() { + // "Returns true if the error is related to connect" [1] + message.push_str("HTTP connection failed"); + } else if err.is_status() { + // "Returns true if the error is from Response::error_for_status" [1] + message.push_str("HTTP request failed with status code "); + if let Some(status) = err.status() { + message.push_str(status.as_str()); + } else { + // This should not happen, as we get into this branch from + // Response::error_for_status. + message.push_str(""); + } + } else if err.is_decode() { + // "Returns true if the error is related to decoding the response’s body" [1] + // Originally, we used the debug representation to be able to see all + // fields related to the error and make finding the offending field + // easier. This was confusing for users. Now we print the "source()" + // of the error below, which contains the relevant information. + message.push_str("HTTP response decoding failed"); } else { - format!("HTTP request failed: {err}") + // Covers unknown errors, non-OK HTTP status codes, errors "related to + // the request" [1], errors "related to the request or response body" + // [1], errors "from a type Builder" [1], errors "from + // a RedirectPolicy." [1], errors "related to a protocol upgrade + // request" [1] + message.push_str("HTTP request failed"); + } + + // Chain error sources together to capture all relevant error parts. E.g.: + // "client error (Connect): tcp connect error: Connection refused (os error 111)" + // instead of just "client error (Connect)"; + // and "client error (SendRequest): connection closed before message completed" + // instead of just "client error (SendRequest)" + let mut we = err.source(); + while let Some(e) = we { + message.push_str(&format!(": {e}")); + we = e.source(); } + + message } diff --git a/crates/cli/src/commands/debug.rs b/crates/cli/src/commands/debug.rs index fdd94e28..cbaa8889 100644 --- a/crates/cli/src/commands/debug.rs +++ b/crates/cli/src/commands/debug.rs @@ -1,8 +1,6 @@ -use futures_util::TryFutureExt; - use crate::{ api::{self, ChangeLogging, ChangeLoggingResult, TraceTarget}, - client::{CascadeApiClient, format_http_error}, + client::CascadeApiClient, println, }; @@ -42,15 +40,14 @@ impl Debug { let level = level.map(Into::into); let trace_targets = trace_targets.map(|t| t.into_iter().map(TraceTarget).collect()); let (): ChangeLoggingResult = client - .post("debug/change-logging") - .json(&ChangeLogging { - level, - trace_targets, - }) - .send() - .and_then(|r| r.json()) - .await - .map_err(format_http_error)?; + .post_json_with( + "debug/change-logging", + &ChangeLogging { + level, + trace_targets, + }, + ) + .await?; println!("Updated logging behavior"); Ok(()) diff --git a/crates/cli/src/commands/hsm.rs b/crates/cli/src/commands/hsm.rs index 023e6ace..69bfaedd 100644 --- a/crates/cli/src/commands/hsm.rs +++ b/crates/cli/src/commands/hsm.rs @@ -15,7 +15,6 @@ use std::{ }; use clap::Subcommand; -use futures_util::TryFutureExt; use jiff::{Span, SpanRelativeTo}; use crate::{ @@ -23,7 +22,7 @@ use crate::{ HsmServerAdd, HsmServerAddError, HsmServerAddResult, HsmServerGetResult, HsmServerListResult, KmipServerState, PolicyInfo, PolicyInfoError, PolicyListResult, }, - client::{CascadeApiClient, format_http_error}, + client::CascadeApiClient, println, }; @@ -71,29 +70,28 @@ impl Hsm { let ca_cert = read_binary_file(ca_cert_path.as_ref()).map_err(|e| e.to_string())?; let res: Result = client - .post("kmip") - .json(&HsmServerAdd { - server_id, - ip_host_or_fqdn, - port, - username, - password, - client_cert, - client_key, - insecure, - server_cert, - ca_cert, - connect_timeout, - read_timeout, - write_timeout, - max_response_bytes, - key_label_prefix, - key_label_max_bytes, - }) - .send() - .and_then(|r| r.json()) - .await - .map_err(format_http_error)?; + .post_json_with( + "kmip", + &HsmServerAdd { + server_id, + ip_host_or_fqdn, + port, + username, + password, + client_cert, + client_key, + insecure, + server_cert, + ca_cert, + connect_timeout, + read_timeout, + write_timeout, + max_response_bytes, + key_label_prefix, + key_label_max_bytes, + }, + ) + .await?; match res { Ok(HsmServerAddResult { vendor_id }) => { @@ -104,12 +102,7 @@ impl Hsm { } HsmCommand::ListServers => { - let res: HsmServerListResult = client - .get("kmip") - .send() - .and_then(|r| r.json()) - .await - .map_err(format_http_error)?; + let res: HsmServerListResult = client.get_json("kmip").await?; for server in res.servers { println!("{server}"); @@ -117,12 +110,8 @@ impl Hsm { } HsmCommand::GetServer { server_id } => { - let res: Result = client - .get(&format!("kmip/{server_id}")) - .send() - .and_then(|r| r.json()) - .await - .map_err(format_http_error)?; + let res: Result = + client.get_json(&format!("kmip/{server_id}")).await?; match res { Ok(res) => { @@ -158,19 +147,10 @@ async fn get_policy_names_using_hsm( server_id: &String, ) -> Result, String> { let mut policies_using_hsm = vec![]; - let res: PolicyListResult = client - .get("policy/") - .send() - .and_then(|r| r.json()) - .await - .map_err(format_http_error)?; + let res: PolicyListResult = client.get_json("policy/").await?; for policy_name in res.policies { - let res: Result = client - .get(&format!("policy/{policy_name}")) - .send() - .and_then(|r| r.json()) - .await - .map_err(format_http_error)?; + let res: Result = + client.get_json(&format!("policy/{policy_name}")).await?; let p = match res { Ok(p) => p, diff --git a/crates/cli/src/commands/keyset.rs b/crates/cli/src/commands/keyset.rs index 9a6d0196..1ce0b448 100644 --- a/crates/cli/src/commands/keyset.rs +++ b/crates/cli/src/commands/keyset.rs @@ -1,8 +1,6 @@ -use futures_util::TryFutureExt; - use crate::api::ZoneName; use crate::api::keyset as api; -use crate::client::{CascadeApiClient, format_http_error}; +use crate::client::CascadeApiClient; use crate::println; #[derive(Clone, Debug, clap::Args)] @@ -131,15 +129,15 @@ async fn roll_command( variant: api::KeyRollVariant, ) -> Result<(), String> { let res: Result<(), String> = client - .post(&format!("key/{zone}/roll")) - .json(&api::KeyRoll { - variant, - cmd: cmd.into(), - }) - .send() - .and_then(|r| r.json()) - .await - .map_err(format_http_error)?; + .post_json_with( + &format!("key/{zone}/roll"), + &api::KeyRoll { + variant, + cmd: cmd.into(), + }, + ) + .await?; + match res { Ok(_) => { println!("Manual key roll for {} successful", zone); @@ -157,16 +155,16 @@ async fn remove_key_command( continue_flag: bool, ) -> Result<(), String> { let res: Result<(), String> = client - .post(&format!("key/{zone}/remove")) - .json(&api::KeyRemove { - key: key.clone(), - force, - continue_flag, - }) - .send() - .and_then(|r| r.json()) - .await - .map_err(format_http_error)?; + .post_json_with( + &format!("key/{zone}/remove"), + &api::KeyRemove { + key: key.clone(), + force, + continue_flag, + }, + ) + .await?; + match res { Ok(_) => { println!("Removed key {} from zone {}", key, zone); diff --git a/crates/cli/src/commands/mod.rs b/crates/cli/src/commands/mod.rs index a1bd2ca9..5e19da16 100644 --- a/crates/cli/src/commands/mod.rs +++ b/crates/cli/src/commands/mod.rs @@ -8,7 +8,7 @@ pub mod status; pub mod template; pub mod zone; -use crate::client::{CascadeApiClient, format_http_error}; +use crate::client::CascadeApiClient; use crate::println; #[allow(clippy::large_enum_variant)] @@ -65,11 +65,7 @@ impl Command { match self { Self::Debug(cmd) => cmd.execute(client).await, Self::Health => { - client - .get("health") - .send() - .await - .map_err(format_http_error)?; + client.get_json::<()>("health").await?; println!("Ok"); Ok(()) } diff --git a/crates/cli/src/commands/policy.rs b/crates/cli/src/commands/policy.rs index dac20454..f88d3692 100644 --- a/crates/cli/src/commands/policy.rs +++ b/crates/cli/src/commands/policy.rs @@ -1,5 +1,3 @@ -use futures_util::TryFutureExt; - use crate::{ ansi, api::{ @@ -7,7 +5,7 @@ use crate::{ PolicyListResult, PolicyReloadError, ReviewPolicyInfo, SignerDenialPolicyInfo, SignerSerialPolicyInfo, }, - client::{CascadeApiClient, format_http_error}, + client::CascadeApiClient, println, }; @@ -36,24 +34,15 @@ impl Policy { pub async fn execute(self, client: CascadeApiClient) -> Result<(), String> { match self.command { PolicyCommand::List => { - let res: PolicyListResult = client - .get("policy/") - .send() - .and_then(|r| r.json()) - .await - .map_err(format_http_error)?; + let res: PolicyListResult = client.get_json("policy/").await?; for policy in res.policies { println!("{policy}"); } } PolicyCommand::Show { name } => { - let res: Result = client - .get(&format!("policy/{name}")) - .send() - .and_then(|r| r.json()) - .await - .map_err(format_http_error)?; + let res: Result = + client.get_json(&format!("policy/{name}")).await?; let p = match res { Ok(p) => p, @@ -65,12 +54,8 @@ impl Policy { print_policy(&p); } PolicyCommand::Reload => { - let res: Result = client - .post("policy/reload") - .send() - .and_then(|r| r.json()) - .await - .map_err(format_http_error)?; + let res: Result = + client.post_json("policy/reload").await?; let res = match res { Ok(res) => res, diff --git a/crates/cli/src/commands/status.rs b/crates/cli/src/commands/status.rs index bf809e7a..a6618b33 100644 --- a/crates/cli/src/commands/status.rs +++ b/crates/cli/src/commands/status.rs @@ -1,8 +1,6 @@ -use futures_util::TryFutureExt; - use crate::ansi; use crate::api::{KeyMsg, KeyStatusResult, KeysPerZone, ServerStatusResult, SigningStageReport}; -use crate::client::{CascadeApiClient, format_http_error}; +use crate::client::CascadeApiClient; use crate::{eprintln, println}; #[derive(Clone, Debug, clap::Args)] @@ -27,12 +25,7 @@ impl Status { pub async fn execute(self, client: CascadeApiClient) -> Result<(), String> { match self.command { Some(StatusCommand::Keys) => { - let response: KeyStatusResult = client - .get("/status/keys") - .send() - .and_then(|r| r.json()) - .await - .map_err(format_http_error)?; + let response: KeyStatusResult = client.get_json("/status/keys").await?; println!("First to expire (max 5):"); if response.expirations.is_empty() { @@ -64,12 +57,7 @@ impl Status { } } None => { - let response: ServerStatusResult = client - .get("/status") - .send() - .and_then(|r| r.json()) - .await - .map_err(format_http_error)?; + let response: ServerStatusResult = client.get_json("/status").await?; if !response.hard_halted_zones.is_empty() { eprintln!("The following zones are hard halted due to a serious problem:"); diff --git a/crates/cli/src/commands/zone.rs b/crates/cli/src/commands/zone.rs index d496531f..aa291ee4 100644 --- a/crates/cli/src/commands/zone.rs +++ b/crates/cli/src/commands/zone.rs @@ -2,11 +2,10 @@ use std::ops::ControlFlow; use std::time::{Duration, SystemTime}; use camino::Utf8PathBuf; -use futures_util::TryFutureExt; use crate::ansi; use crate::api::*; -use crate::client::{CascadeApiClient, format_http_error}; +use crate::client::CascadeApiClient; use crate::println; #[derive(Clone, Debug, clap::Args)] @@ -199,17 +198,16 @@ impl Zone { } let res: Result = client - .post("zone/add") - .json(&ZoneAdd { - name, - source, - policy, - key_imports, - }) - .send() - .and_then(|r| r.json()) - .await - .map_err(format_http_error)?; + .post_json_with( + "zone/add", + &ZoneAdd { + name, + source, + policy, + key_imports, + }, + ) + .await?; match res { Ok(res) => { @@ -220,12 +218,8 @@ impl Zone { } } ZoneCommand::Remove { name } => { - let res: Result = client - .post(&format!("zone/{name}/remove")) - .send() - .and_then(|r| r.json()) - .await - .map_err(format_http_error)?; + let res: Result = + client.post_json(&format!("zone/{name}/remove")).await?; match res { Ok(res) => { @@ -236,12 +230,7 @@ impl Zone { } } ZoneCommand::List => { - let response: ZonesListResult = client - .get("zone/") - .send() - .and_then(|r| r.json()) - .await - .map_err(format_http_error)?; + let response: ZonesListResult = client.get_json("zone/").await?; for zone_name in response.zones { println!("{}", zone_name); @@ -250,12 +239,7 @@ impl Zone { } ZoneCommand::Reload { zone } => { let url = format!("zone/{zone}/reload"); - let res: Result = client - .post(&url) - .send() - .and_then(|r| r.json()) - .await - .map_err(format_http_error)?; + let res: Result = client.post_json(&url).await?; match res { Ok(res) => { @@ -283,12 +267,7 @@ impl Zone { }; let url = format!("/zone/{name}/{stage}/{serial}/approve"); - let result: ZoneReviewResult = client - .post(&url) - .send() - .and_then(|r| r.json()) - .await - .map_err(|e| format!("HTTP request failed: {e:?}"))?; + let result: ZoneReviewResult = client.post_json(&url).await?; match result { Ok(ZoneReviewOutput {}) => { @@ -321,12 +300,7 @@ impl Zone { }; let url = format!("/zone/{name}/{stage}/{serial}/reject"); - let result: ZoneReviewResult = client - .post(&url) - .send() - .and_then(|r| r.json()) - .await - .map_err(|e| format!("HTTP request failed: {e:?}"))?; + let result: ZoneReviewResult = client.post_json(&url).await?; match result { Ok(ZoneReviewOutput {}) => { @@ -343,12 +317,7 @@ impl Zone { } ZoneCommand::Status { zone, detailed } => { let url = format!("zone/{}/status", zone); - let response: Result = client - .get(&url) - .send() - .and_then(|r| r.json()) - .await - .map_err(|e| format!("HTTP request failed: {e:?}"))?; + let response: Result = client.get_json(&url).await?; match response { Ok(status) => Self::print_zone_status(client, status, detailed).await, @@ -359,12 +328,7 @@ impl Zone { } ZoneCommand::History { zone } => { let url = format!("zone/{}/history", zone); - let response: Result = client - .get(&url) - .send() - .and_then(|r| r.json()) - .await - .map_err(|e| format!("HTTP request failed: {e:?}"))?; + let response: Result = client.get_json(&url).await?; match response { Ok(response) => { @@ -476,12 +440,7 @@ impl Zone { ) -> Result<(), String> { // Fetch the policy for the zone. let url = format!("policy/{}", zone.policy); - let response: Result = client - .get(&url) - .send() - .and_then(|r| r.json()) - .await - .map_err(|e| format!("HTTP request failed: {e:?}"))?; + let response: Result = client.get_json(&url).await?; let policy = response.map_err(|_| { format!( diff --git a/crates/cli/src/main.rs b/crates/cli/src/main.rs index c1d17f5a..449adc01 100644 --- a/crates/cli/src/main.rs +++ b/crates/cli/src/main.rs @@ -21,7 +21,7 @@ async fn main() -> ExitCode { match args.execute().await { Ok(_) => ExitCode::SUCCESS, Err(err) => { - error!("Error: {err}"); + error!("{err}"); ExitCode::FAILURE } }