From 2d7d75e6ca32e58c1ed53db6f20ca80aa85842f1 Mon Sep 17 00:00:00 2001 From: kerthcet Date: Fri, 19 Dec 2025 01:28:22 +0000 Subject: [PATCH] Refine the SDK api Signed-off-by: kerthcet --- Cargo.lock | 5 -- src/client/client.rs | 33 ++++++----- src/config.rs | 131 +++++++++++++++++++------------------------ src/router/random.rs | 6 +- src/router/router.rs | 6 +- src/router/weight.rs | 6 +- 6 files changed, 84 insertions(+), 103 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f28f7e7..1b9ec19 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3,10 +3,6 @@ version = 4 [[package]] -<<<<<<< Updated upstream -name = "AMRS" -version = "0.1.0" -======= name = "arms" version = "0.1.0" dependencies = [ @@ -2136,4 +2132,3 @@ dependencies = [ "quote", "syn", ] ->>>>>>> Stashed changes diff --git a/src/client/client.rs b/src/client/client.rs index b53ed45..b4a1894 100644 --- a/src/client/client.rs +++ b/src/client/client.rs @@ -28,12 +28,28 @@ pub struct Client { } impl Client { - pub fn new(config: &Config) -> Self { + pub fn new(config: Config) -> Self { + let mut cfg = config; + cfg.finalize().expect("Invalid configuration"); + + let providers = cfg + .models + .iter() + .map(|m| { + let provider = m + .provider + .as_ref() + .expect("Model provider must be specified"); + + (m.id.clone(), provider::build_provider(provider, m)) + }) + .collect(); + Self { - config: config.clone(), + config: cfg.clone(), router_tracker: None, - providers: HashMap::new(), - router: router::build_router(&config.routing_mode, &config.models), + providers: providers, + router: router::build_router(cfg.routing_mode, cfg.models), } } @@ -43,15 +59,6 @@ impl Client { } } - pub fn build(&mut self) { - self.config.models.iter().for_each(|m| { - self.providers.insert( - m.id.clone(), - provider::build_provider(&m.provider.as_ref().unwrap(), m), - ); - }); - } - pub async fn create_response( &self, request: provider::ResponseRequest, diff --git a/src/config.rs b/src/config.rs index 3198bf7..2c38bd2 100644 --- a/src/config.rs +++ b/src/config.rs @@ -41,7 +41,7 @@ pub struct ModelConfig { } impl ModelConfig { - pub fn new(id: ModelId) -> Self { + fn new(id: ModelId) -> Self { Self { base_url: None, provider: None, @@ -53,47 +53,43 @@ impl ModelConfig { } } - pub fn base_url(mut self, url: &str) -> Self { + pub fn with_base_url(mut self, url: &str) -> Self { self.base_url = Some(url.to_string()); self } - pub fn provider(mut self, provider: &str) -> Self { + pub fn with_provider(mut self, provider: &str) -> Self { self.provider = Some(provider.to_string()); self } - pub fn temperature(mut self, temperature: f32) -> Self { + pub fn with_temperature(mut self, temperature: f32) -> Self { self.temperature = Some(temperature); self } - pub fn max_output_tokens(mut self, max_output_tokens: usize) -> Self { + pub fn with_max_output_tokens(mut self, max_output_tokens: usize) -> Self { self.max_output_tokens = Some(max_output_tokens); self } - pub fn weight(mut self, weight: i32) -> Self { + pub fn with_weight(mut self, weight: i32) -> Self { self.weight = weight; self } - - pub fn build(self) -> Result { - Ok(self) - } } // ------------------ Main Config ------------------ #[derive(Debug, Clone)] pub struct Config { // global configs for models, will be overridden by model-specific configs - pub base_url: Option, - pub provider: ProviderName, // "AMRS" by default - pub temperature: f32, // 0.8 by default - pub max_output_tokens: usize, // 1024 by default + pub(crate) base_url: Option, + pub(crate) provider: ProviderName, // "AMRS" by default + pub(crate) temperature: f32, // 0.8 by default + pub(crate) max_output_tokens: usize, // 1024 by default - pub routing_mode: RoutingMode, // Random by default - pub models: Vec, + pub(crate) routing_mode: RoutingMode, // Random by default + pub(crate) models: Vec, } impl Default for Config { @@ -110,48 +106,43 @@ impl Default for Config { } impl Config { - pub fn new() -> Self { - let cfg = Config::default(); - cfg - } - - pub fn base_url(mut self, url: &str) -> Self { + pub fn with_base_url(mut self, url: &str) -> Self { self.base_url = Some(url.to_string()); self } - pub fn provider(mut self, provider: &str) -> Self { + pub fn with_provider(mut self, provider: &str) -> Self { self.provider = provider.to_string(); self } - pub fn temperature(mut self, temperature: f32) -> Self { + pub fn with_temperature(mut self, temperature: f32) -> Self { self.temperature = temperature; self } - pub fn max_output_tokens(mut self, max_output_tokens: usize) -> Self { + pub fn with_max_output_tokens(mut self, max_output_tokens: usize) -> Self { self.max_output_tokens = max_output_tokens; self } - pub fn routing_mode(mut self, mode: RoutingMode) -> Self { + pub fn with_routing_mode(mut self, mode: RoutingMode) -> Self { self.routing_mode = mode; self } - pub fn add_model(mut self, model: ModelConfig) -> Self { + pub fn with_model(mut self, model: ModelConfig) -> Self { self.models.push(model); self } - pub fn build(mut self) -> Result { - self.set_defaults(); - self.validate()?; - Ok(self) + pub fn finalize(&mut self) -> Result<(), String> { + self.post_defaults(); + self.validate_model_config()?; + Ok(()) } - fn set_defaults(&mut self) { + fn post_defaults(&mut self) { for model in &mut self.models { let model_url_exist = model.base_url.is_some(); @@ -177,11 +168,6 @@ impl Config { } } - fn validate(&self) -> Result<(), String> { - self.validate_model_config()?; - Ok(()) - } - fn validate_model_config(&self) -> Result<(), String> { if self.models.is_empty() { return Err("At least one model must be configured.".to_string()); @@ -244,61 +230,58 @@ mod tests { from_filename(".env.test").ok(); // case 1: - let valid_simplest_models_cfg = Config::new() - .provider("OPENAI") - .add_model(ModelConfig::new("gpt-4".to_string()).build().unwrap()); - let res = valid_simplest_models_cfg.build(); + let mut valid_simplest_models_cfg = Config::default() + .with_provider("OPENAI") + .with_model(ModelConfig::new("gpt-4".to_string())); + let res = valid_simplest_models_cfg.finalize(); assert!(res.clone().is_ok()); assert!( - res.clone().unwrap().models[0].base_url + valid_simplest_models_cfg.models[0].base_url == Some("https://api.openai.com/v1".to_string()) ); - assert!(res.clone().unwrap().models[0].provider == Some("OPENAI".to_string())); - assert!(res.clone().unwrap().models[0].temperature == Some(0.8)); - assert!(res.clone().unwrap().models[0].max_output_tokens == Some(1024)); - assert!(res.clone().unwrap().models[0].weight == -1); + assert!(valid_simplest_models_cfg.models[0].provider == Some("OPENAI".to_string())); + assert!(valid_simplest_models_cfg.models[0].temperature == Some(0.8)); + assert!(valid_simplest_models_cfg.models[0].max_output_tokens == Some(1024)); + assert!(valid_simplest_models_cfg.models[0].weight == -1); // case 2: - let valid_cfg = Config::new() - .provider("OPENAI") - .add_model( - ModelConfig::new("gpt-3.5-turbo".to_string()) - .build() - .unwrap(), - ) - .add_model(ModelConfig::new("gpt-4".to_string()).build().unwrap()); - assert!(valid_cfg.build().is_ok()); + let mut valid_cfg = Config::default() + .with_provider("OPENAI") + .with_model(ModelConfig::new("gpt-3.5-turbo".to_string())) + .with_model(ModelConfig::new("gpt-4".to_string())); + assert!(valid_cfg.finalize().is_ok()); // case 3: - let invalid_cfg_with_no_api_key = Config::new() - .provider("unknown_provider") - .add_model(ModelConfig::new("some-model".to_string()).build().unwrap()); - assert!(invalid_cfg_with_no_api_key.build().is_err()); + let mut invalid_cfg_with_no_api_key = Config::default() + .with_provider("unknown_provider") + .with_model(ModelConfig::new("some-model".to_string())); + assert!(invalid_cfg_with_no_api_key.finalize().is_err()); // case 4: - let valid_cfg_with_customized_provider = Config::new() - .base_url("http://example.ai") - .max_output_tokens(2048) - .add_model( - ModelConfig::new("custom-model".to_string()) - .provider("FOO") - .build() - .unwrap(), - ); - let res = valid_cfg_with_customized_provider.build(); + let mut valid_cfg_with_customized_provider = Config::default() + .with_base_url("http://example.ai") + .with_max_output_tokens(2048) + .with_model(ModelConfig::new("custom-model".to_string()).with_provider("FOO")); + let res = valid_cfg_with_customized_provider.finalize(); assert!(res.is_ok()); assert_eq!( - res.clone().unwrap().models[0].base_url.as_ref().unwrap(), + valid_cfg_with_customized_provider.models[0] + .base_url + .as_ref() + .unwrap(), "http://example.ai" ); assert_eq!( - res.clone().unwrap().models[0].provider, + valid_cfg_with_customized_provider.models[0].provider, Some("FOO".to_string()) ); - assert_eq!(res.unwrap().models[0].max_output_tokens, Some(2048)); + assert_eq!( + valid_cfg_with_customized_provider.models[0].max_output_tokens, + Some(2048) + ); // case 5: - let invalid_empty_models_cfg = Config::new().provider("OPENAI"); - assert!(invalid_empty_models_cfg.build().is_err()); + let mut invalid_empty_models_cfg = Config::default().with_provider("OPENAI"); + assert!(invalid_empty_models_cfg.finalize().is_err()); } } diff --git a/src/router/random.rs b/src/router/random.rs index 9a8d602..0d0966b 100644 --- a/src/router/random.rs +++ b/src/router/random.rs @@ -9,10 +9,8 @@ pub struct RandomRouter { } impl RandomRouter { - pub fn new(model_ids: &[ModelId]) -> Self { - Self { - model_ids: model_ids.to_vec(), - } + pub fn new(model_ids: Vec) -> Self { + Self { model_ids } } } diff --git a/src/router/router.rs b/src/router/router.rs index 8aef69c..17c3708 100644 --- a/src/router/router.rs +++ b/src/router/router.rs @@ -6,11 +6,11 @@ use crate::provider::provider::ResponseRequest; use crate::router::random::RandomRouter; use crate::router::weight::WeightedRouter; -pub fn build_router(mode: &RoutingMode, models: &[ModelConfig]) -> Box { +pub fn build_router(mode: RoutingMode, models: Vec) -> Box { let model_ids: Vec = models.iter().map(|m| m.id.clone()).collect(); match mode { - RoutingMode::Random => Box::new(RandomRouter::new(&model_ids)), - RoutingMode::Weighted => Box::new(WeightedRouter::new(&model_ids)), + RoutingMode::Random => Box::new(RandomRouter::new(model_ids)), + RoutingMode::Weighted => Box::new(WeightedRouter::new(model_ids)), } } diff --git a/src/router/weight.rs b/src/router/weight.rs index a61fc01..a064b79 100644 --- a/src/router/weight.rs +++ b/src/router/weight.rs @@ -6,10 +6,8 @@ pub struct WeightedRouter { } impl WeightedRouter { - pub fn new(model_ids: &[ModelId]) -> Self { - Self { - model_ids: model_ids.to_vec(), - } + pub fn new(model_ids: Vec) -> Self { + Self { model_ids } } }