From a27ddab5350a4ed8b1570beee80addeea24c06e7 Mon Sep 17 00:00:00 2001 From: kerthcet Date: Mon, 22 Dec 2025 12:30:05 +0800 Subject: [PATCH 1/5] use builder mode in Config Signed-off-by: kerthcet --- .env.test | 1 - Cargo.lock | 1 + Cargo.toml | 1 + src/client/client.rs | 47 +++--- src/config.rs | 334 ++++++++++++++++++++++----------------- src/lib.rs | 3 +- src/provider/provider.rs | 28 +++- src/router/random.rs | 23 ++- src/router/router.rs | 2 +- src/router/weight.rs | 4 +- 10 files changed, 261 insertions(+), 183 deletions(-) diff --git a/.env.test b/.env.test index 6766c19..4fa7612 100644 --- a/.env.test +++ b/.env.test @@ -1,3 +1,2 @@ AMRS_API_KEY=your_amrs_api_key_here OPENAI_API_KEY=your_openai_api_key_here -FOO_API_KEY=your_foo_api_key_here diff --git a/Cargo.lock b/Cargo.lock index 1b9ec19..b4cca6b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8,6 +8,7 @@ version = "0.1.0" dependencies = [ "async-openai", "async-trait", + "derive_builder", "dotenvy", "lazy_static", "rand 0.9.2", diff --git a/Cargo.toml b/Cargo.toml index 061a476..66ecc70 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,6 +6,7 @@ edition = "2024" [dependencies] async-openai = { version = "0.31.1", features = ["_api", "response-types", "responses",] } async-trait = "0.1.89" +derive_builder = "0.20.2" dotenvy = "0.15.7" lazy_static = "1.5.0" rand = "0.9.2" diff --git a/src/client/client.rs b/src/client/client.rs index b4a1894..b8199d7 100644 --- a/src/client/client.rs +++ b/src/client/client.rs @@ -5,48 +5,24 @@ use crate::config::ModelId; use crate::provider::provider; use crate::router::router; -// ------------------ Chat Role ------------------ -#[derive(Debug, Clone)] -pub enum ChatRole { - User, - Assistant, - System, -} - -// ------------------ Message ------------------ -#[derive(Debug, Clone)] -pub struct TextMessage { - pub role: ChatRole, - pub content: String, -} - pub struct Client { - config: Config, router_tracker: Option, router: Box, providers: HashMap>, } impl Client { - pub fn new(config: Config) -> Self { - let mut cfg = config; - cfg.finalize().expect("Invalid configuration"); + pub fn new(config: &Config) -> Self { + let mut cfg = config.clone(); + cfg.populate(); let providers = cfg .models .iter() - .map(|m| { - let provider = m - .provider - .as_ref() - .expect("Model provider must be specified"); - - (m.id.clone(), provider::build_provider(provider, m)) - }) + .map(|m| (m.id.clone(), provider::build_provider(m))) .collect(); Self { - config: cfg.clone(), router_tracker: None, providers: providers, router: router::build_router(cfg.routing_mode, cfg.models), @@ -64,7 +40,20 @@ impl Client { request: provider::ResponseRequest, ) -> Result { let model_id = self.router.sample(&request); - let provider = self.providers.get(model_id).unwrap(); + let provider = self.providers.get(&model_id).unwrap(); provider.create_response(request).await } } + +// // how to write test for this? + +// #[cfg(test)] +// mod tests { +// use super::*; +// #[test] +// fn test_client_creation() { +// let config = Config::default(); +// let client = Client::new(config); +// assert!(client.providers.len() > 0); +// } +// } diff --git a/src/config.rs b/src/config.rs index 2c38bd2..276cd5c 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,11 +1,12 @@ use std::collections::HashMap; use std::env; +use derive_builder::Builder; use lazy_static::lazy_static; // ------------------ Provider ------------------ pub type ProviderName = String; -const AMRS_PROVIDER: &str = "AMRS"; +const OPENAI_PROVIDER: &str = "OPENAI"; lazy_static! { pub static ref PROVIDER_BASE_URLS: HashMap<&'static str, &'static str> = { @@ -28,121 +29,65 @@ pub enum RoutingMode { // ------------------ Model Config ------------------ pub type ModelId = String; -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Builder)] +#[builder(build_fn(validate = "Self::validate"), pattern = "mutable")] pub struct ModelConfig { // model-specific configs, will override global configs if provided + #[builder(default = "None")] pub base_url: Option, + #[builder(default = "None")] pub provider: Option, + #[builder(default = "None")] pub temperature: Option, + #[builder(default = "None")] pub max_output_tokens: Option, pub id: ModelId, - pub weight: i32, // -1 if unused + #[builder(default=-1)] + pub weight: i32, } -impl ModelConfig { - fn new(id: ModelId) -> Self { - Self { - base_url: None, - provider: None, - temperature: None, - max_output_tokens: None, - - id: id, - weight: -1, +impl ModelConfigBuilder { + fn validate(&self) -> Result<(), String> { + if self.id.is_none() { + return Err("Model id must be provided.".to_string()); } + Ok(()) } +} - pub fn with_base_url(mut self, url: &str) -> Self { - self.base_url = Some(url.to_string()); - self - } - - pub fn with_provider(mut self, provider: &str) -> Self { - self.provider = Some(provider.to_string()); - self - } - - pub fn with_temperature(mut self, temperature: f32) -> Self { - self.temperature = Some(temperature); - self - } - - pub fn with_max_output_tokens(mut self, max_output_tokens: usize) -> Self { - self.max_output_tokens = Some(max_output_tokens); - self - } - - pub fn with_weight(mut self, weight: i32) -> Self { - self.weight = weight; - self +impl ModelConfig { + pub fn builder() -> ModelConfigBuilder { + ModelConfigBuilder::default() } } // ------------------ Main Config ------------------ -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Builder)] +#[builder(build_fn(validate = "Self::validate"), pattern = "mutable")] pub struct Config { // global configs for models, will be overridden by model-specific configs - pub(crate) base_url: Option, - pub(crate) provider: ProviderName, // "AMRS" by default - pub(crate) temperature: f32, // 0.8 by default - pub(crate) max_output_tokens: usize, // 1024 by default - - pub(crate) routing_mode: RoutingMode, // Random by default + #[builder(default = "https://api.openai.com/v1".to_string())] + pub(crate) base_url: String, + #[builder(default = "ProviderName::from(OPENAI_PROVIDER)")] + pub(crate) provider: ProviderName, + #[builder(default = "0.8")] + pub(crate) temperature: f32, + #[builder(default = "1024")] + pub(crate) max_output_tokens: usize, + + #[builder(default = "RoutingMode::Random")] + pub(crate) routing_mode: RoutingMode, + #[builder(default = "vec![]")] pub(crate) models: Vec, } -impl Default for Config { - fn default() -> Self { - Self { - base_url: None, - provider: AMRS_PROVIDER.to_string(), - temperature: 0.8, - max_output_tokens: 1024, - routing_mode: RoutingMode::Random, - models: vec![], - } - } -} - impl Config { - pub fn with_base_url(mut self, url: &str) -> Self { - self.base_url = Some(url.to_string()); - self - } - - pub fn with_provider(mut self, provider: &str) -> Self { - self.provider = provider.to_string(); - self - } - - pub fn with_temperature(mut self, temperature: f32) -> Self { - self.temperature = temperature; - self - } - - pub fn with_max_output_tokens(mut self, max_output_tokens: usize) -> Self { - self.max_output_tokens = max_output_tokens; - self - } - - pub fn with_routing_mode(mut self, mode: RoutingMode) -> Self { - self.routing_mode = mode; - self - } - - pub fn with_model(mut self, model: ModelConfig) -> Self { - self.models.push(model); - self - } - - pub fn finalize(&mut self) -> Result<(), String> { - self.post_defaults(); - self.validate_model_config()?; - Ok(()) + pub fn builder() -> ConfigBuilder { + ConfigBuilder::default() } - fn post_defaults(&mut self) { + pub fn populate(&mut self) -> &mut Self { for model in &mut self.models { let model_url_exist = model.base_url.is_some(); @@ -156,8 +101,8 @@ impl Config { model.base_url = Some(PROVIDER_BASE_URLS[model.provider.as_ref().unwrap().as_str()].to_string()); } - if !model_url_exist && self.base_url.is_some() { - model.base_url = self.base_url.clone(); + if !model_url_exist { + model.base_url = Some(self.base_url.clone()); } if model.temperature.is_none() { model.temperature = Some(self.temperature); @@ -166,19 +111,28 @@ impl Config { model.max_output_tokens = Some(self.max_output_tokens); } } + self + } +} + +impl ConfigBuilder { + pub fn model(&mut self, model: ModelConfig) -> &mut Self { + let mut models = self.models.clone().unwrap_or_default(); + models.push(model); + self.models = Some(models); + self } - fn validate_model_config(&self) -> Result<(), String> { - if self.models.is_empty() { + fn validate(&self) -> Result<(), String> { + if self.models.is_none() || self.models.as_ref().unwrap().is_empty() { return Err("At least one model must be configured.".to_string()); } - for model in &self.models { - if model.base_url.is_none() && self.base_url.is_none() { - return Err(format!("Model '{}' base_url is not provided.", model.id)); - } - - if self.routing_mode == RoutingMode::Weighted && model.weight <= 0 { + for model in self.models.as_ref().unwrap() { + if self.routing_mode.is_some() + && self.routing_mode.as_ref().unwrap() == &RoutingMode::Weighted + && model.weight <= 0 + { return Err(format!( "Model '{}' weight must be non-negative in Weighted routing mode.", model.id @@ -203,6 +157,7 @@ impl Config { } } + // check the existence of API key in environment variables if let Some(provider) = &model.provider { let env_var = format!("{}_API_KEY", provider.to_uppercase()); if env::var(&env_var).is_err() { @@ -212,6 +167,25 @@ impl Config { env_var )); } + } else { + // default is called after validation. + let env_var = format!( + "{}_API_KEY", + self.provider + .as_ref() + .unwrap_or(&ProviderName::from(OPENAI_PROVIDER)) + .to_uppercase() + ); + if env::var(&env_var).is_err() { + return Err(format!( + "API key for provider '{}' not found in environment variable '{}'", + self.provider + .as_ref() + .unwrap_or(&ProviderName::from(OPENAI_PROVIDER)) + .to_uppercase(), + env_var + )); + } } } @@ -230,58 +204,130 @@ mod tests { from_filename(".env.test").ok(); // case 1: - let mut valid_simplest_models_cfg = Config::default() - .with_provider("OPENAI") - .with_model(ModelConfig::new("gpt-4".to_string())); - let res = valid_simplest_models_cfg.finalize(); - assert!(res.clone().is_ok()); + let valid_simplest_models_cfg = Config::builder() + .model( + ModelConfig::builder() + .id("gpt-4".to_string()) + .build() + .unwrap(), + ) + .build(); + assert!(valid_simplest_models_cfg.is_ok()); + assert!(valid_simplest_models_cfg.as_ref().unwrap().provider == OPENAI_PROVIDER); assert!( - valid_simplest_models_cfg.models[0].base_url - == Some("https://api.openai.com/v1".to_string()) + valid_simplest_models_cfg.as_ref().unwrap().base_url == "https://api.openai.com/v1" + ); + assert!(valid_simplest_models_cfg.as_ref().unwrap().temperature == 0.8); + assert!( + valid_simplest_models_cfg + .as_ref() + .unwrap() + .max_output_tokens + == 1024 ); - assert!(valid_simplest_models_cfg.models[0].provider == Some("OPENAI".to_string())); - assert!(valid_simplest_models_cfg.models[0].temperature == Some(0.8)); - assert!(valid_simplest_models_cfg.models[0].max_output_tokens == Some(1024)); - assert!(valid_simplest_models_cfg.models[0].weight == -1); + assert!(valid_simplest_models_cfg.as_ref().unwrap().routing_mode == RoutingMode::Random); + assert!(valid_simplest_models_cfg.as_ref().unwrap().models.len() == 1); + assert!(valid_simplest_models_cfg.as_ref().unwrap().models[0].base_url == None); + assert!(valid_simplest_models_cfg.as_ref().unwrap().models[0].provider == None); + assert!(valid_simplest_models_cfg.as_ref().unwrap().models[0].temperature == None); + assert!(valid_simplest_models_cfg.as_ref().unwrap().models[0].max_output_tokens == None); + assert!(valid_simplest_models_cfg.as_ref().unwrap().models[0].weight == -1); // case 2: - let mut valid_cfg = Config::default() - .with_provider("OPENAI") - .with_model(ModelConfig::new("gpt-3.5-turbo".to_string())) - .with_model(ModelConfig::new("gpt-4".to_string())); - assert!(valid_cfg.finalize().is_ok()); + let valid_cfg = Config::builder() + .models(vec![ + ModelConfig::builder() + .id("gpt-3.5-turbo".to_string()) + .build() + .unwrap(), + ModelConfig::builder() + .id("gpt-4".to_string()) + .build() + .unwrap(), + ]) + .build(); + assert!(valid_cfg.is_ok()); + assert!(valid_cfg.as_ref().unwrap().models.len() == 2); // case 3: - let mut invalid_cfg_with_no_api_key = Config::default() - .with_provider("unknown_provider") - .with_model(ModelConfig::new("some-model".to_string())); - assert!(invalid_cfg_with_no_api_key.finalize().is_err()); + let invalid_cfg_with_no_api_key = Config::builder() + .model( + ModelConfig::builder() + .id("some-model".to_string()) + .build() + .unwrap(), + ) + .provider("unknown_provider".to_string()) + .build(); + assert!(invalid_cfg_with_no_api_key.is_err()); // case 4: - let mut valid_cfg_with_customized_provider = Config::default() - .with_base_url("http://example.ai") - .with_max_output_tokens(2048) - .with_model(ModelConfig::new("custom-model".to_string()).with_provider("FOO")); - let res = valid_cfg_with_customized_provider.finalize(); - assert!(res.is_ok()); - assert_eq!( - valid_cfg_with_customized_provider.models[0] - .base_url - .as_ref() - .unwrap(), - "http://example.ai" - ); - assert_eq!( - valid_cfg_with_customized_provider.models[0].provider, - Some("FOO".to_string()) - ); - assert_eq!( - valid_cfg_with_customized_provider.models[0].max_output_tokens, - Some(2048) - ); + // AMRS_API_KEY is set in .env.test already. + let valid_cfg_with_customized_provider = Config::builder() + .base_url("http://example.ai".to_string()) + .max_output_tokens(2048) + .model( + ModelConfig::builder() + .id("custom-model".to_string()) + .provider(Some("AMRS".to_string())) + .build() + .unwrap(), + ) + .build(); + assert!(valid_cfg_with_customized_provider.is_ok()); // case 5: - let mut invalid_empty_models_cfg = Config::default().with_provider("OPENAI"); - assert!(invalid_empty_models_cfg.finalize().is_err()); + let invalid_empty_models_cfg = Config::builder().build(); + assert!(invalid_empty_models_cfg.is_err()); + + // case 6: + print!("validating invalid empty model id config"); + let invalid_empty_model_id_cfg = ModelConfig::builder().build(); + assert!(invalid_empty_model_id_cfg.is_err()); + } + + #[test] + fn test_populate_config() { + let mut valid_cfg = Config::builder() + .temperature(0.5) + .max_output_tokens(1500) + .model( + ModelConfig::builder() + .id("model-1".to_string()) + .build() + .unwrap(), + ) + .build(); + valid_cfg.as_mut().unwrap().populate(); + + assert!(valid_cfg.is_ok()); + assert!(valid_cfg.as_ref().unwrap().models.len() == 1); + assert!(valid_cfg.as_ref().unwrap().models[0].temperature == Some(0.5)); + assert!(valid_cfg.as_ref().unwrap().models[0].max_output_tokens == Some(1500)); + assert!(valid_cfg.as_ref().unwrap().models[0].provider == Some("OPENAI".to_string())); + assert!( + valid_cfg.as_ref().unwrap().models[0].base_url + == Some("https://api.openai.com/v1".to_string()) + ); + assert!(valid_cfg.as_ref().unwrap().models[0].weight == -1); + + let mut valid_specified_cfg = Config::builder() + .provider("DEEPINFRA".to_string()) + .base_url("http://custom-api.ai".to_string()) + .model( + ModelConfig::builder() + .id("model-2".to_string()) + .build() + .unwrap(), + ) + .build(); + valid_specified_cfg.as_mut().unwrap().populate(); + + assert!(valid_specified_cfg.is_ok()); + assert!(valid_specified_cfg.as_ref().unwrap().provider == "DEEPINFRA".to_string()); + assert!( + valid_specified_cfg.as_ref().unwrap().models[0].base_url + == Some("http://custom-api.ai".to_string()) + ); } } diff --git a/src/lib.rs b/src/lib.rs index 4524963..6d47760 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,7 +3,6 @@ mod router { pub mod router; mod weight; } -mod config; mod client { pub mod client; } @@ -12,5 +11,5 @@ mod provider { pub mod provider; } +pub mod config; pub use crate::client::client::Client; -pub use crate::config::Config; diff --git a/src/provider/provider.rs b/src/provider/provider.rs index 76fa461..65fe3dc 100644 --- a/src/provider/provider.rs +++ b/src/provider/provider.rs @@ -2,14 +2,15 @@ use async_openai::error::OpenAIError; use async_openai::types::responses::{CreateResponse as OpenAIRequest, Response as OpenAIResponse}; use async_trait::async_trait; -use crate::config::{ModelConfig, ProviderName}; +use crate::config::ModelConfig; use crate::provider::openai::OpenAIProvider; pub type ResponseRequest = OpenAIRequest; pub type ResponseResult = OpenAIResponse; pub type APIError = OpenAIError; -pub fn build_provider(provider: &ProviderName, config: &ModelConfig) -> Box { +pub fn build_provider(config: &ModelConfig) -> Box { + let provider = config.provider.as_ref().unwrap(); match provider.as_str() { "openai" => Box::new(OpenAIProvider::new(config).build()), _ => panic!("Unsupported provider: {}", provider), @@ -20,3 +21,26 @@ pub fn build_provider(provider: &ProviderName, config: &ModelConfig) -> Box Result; } + +// // test +// #[cfg(test)] +// mod tests { +// use super::*; +// fn test_build_provider() { +// struct TestCase { +// name: &'static str, +// config: ModelConfig, +// expect_provider_type: &'static str, +// error: bool, +// } + +// let cases = vec![ +// TestCase { +// name: "OpenAI Provider", +// config: ModelConfig::new("test-model").with_provider("openai"), +// expect_provider_type: "OpenAIProvider", +// error: false, +// }, +// // Add more test cases as needed +// ]; +// } diff --git a/src/router/random.rs b/src/router/random.rs index 0d0966b..5e553ec 100644 --- a/src/router/random.rs +++ b/src/router/random.rs @@ -15,9 +15,28 @@ impl RandomRouter { } impl Router for RandomRouter { - fn sample(&self, _input: &ResponseRequest) -> &ModelId { + fn sample(&self, _input: &ResponseRequest) -> ModelId { let mut rng = rand::rng(); let idx = rng.random_range(0..self.model_ids.len()); - &self.model_ids[idx] + self.model_ids[idx].clone() + } +} + +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn test_random_router_sampling() { + let model_ids = vec!["model_a".to_string(), "model_b".to_string()]; + let router = RandomRouter::new(model_ids.clone()); + let mut counts = std::collections::HashMap::new(); + for _ in 0..1000 { + let sampled_id = router.sample(&ResponseRequest::default()); + *counts.entry(sampled_id.clone()).or_insert(0) += 1; + } + assert!(counts.len() == model_ids.len()); + for count in counts.values() { + assert!(*count > 0); + } } } diff --git a/src/router/router.rs b/src/router/router.rs index 17c3708..f9171a0 100644 --- a/src/router/router.rs +++ b/src/router/router.rs @@ -15,7 +15,7 @@ pub fn build_router(mode: RoutingMode, models: Vec) -> Box &ModelId; + fn sample(&self, input: &ResponseRequest) -> ModelId; } pub struct RouterTracker { diff --git a/src/router/weight.rs b/src/router/weight.rs index a064b79..9cc6956 100644 --- a/src/router/weight.rs +++ b/src/router/weight.rs @@ -12,8 +12,8 @@ impl WeightedRouter { } impl Router for WeightedRouter { - fn sample(&self, _input: &ResponseRequest) -> &ModelId { + fn sample(&self, _input: &ResponseRequest) -> ModelId { // TODO: Implement weighted sampling logic - return &self.model_ids[0]; + return self.model_ids[0].clone(); } } From f1923caad3fa559f25f9e11b7989ff51beb4b73a Mon Sep 17 00:00:00 2001 From: kerthcet Date: Mon, 22 Dec 2025 12:35:10 +0800 Subject: [PATCH 2/5] rename Signed-off-by: kerthcet --- src/client/client.rs | 4 ++-- src/provider/provider.rs | 2 +- src/router/router.rs | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/client/client.rs b/src/client/client.rs index b8199d7..0672dfe 100644 --- a/src/client/client.rs +++ b/src/client/client.rs @@ -19,13 +19,13 @@ impl Client { let providers = cfg .models .iter() - .map(|m| (m.id.clone(), provider::build_provider(m))) + .map(|m| (m.id.clone(), provider::construct_provider(m))) .collect(); Self { router_tracker: None, providers: providers, - router: router::build_router(cfg.routing_mode, cfg.models), + router: router::construct_router(cfg.routing_mode, cfg.models), } } diff --git a/src/provider/provider.rs b/src/provider/provider.rs index 65fe3dc..ec9febf 100644 --- a/src/provider/provider.rs +++ b/src/provider/provider.rs @@ -9,7 +9,7 @@ pub type ResponseRequest = OpenAIRequest; pub type ResponseResult = OpenAIResponse; pub type APIError = OpenAIError; -pub fn build_provider(config: &ModelConfig) -> Box { +pub fn construct_provider(config: &ModelConfig) -> Box { let provider = config.provider.as_ref().unwrap(); match provider.as_str() { "openai" => Box::new(OpenAIProvider::new(config).build()), diff --git a/src/router/router.rs b/src/router/router.rs index f9171a0..3bb3dc4 100644 --- a/src/router/router.rs +++ b/src/router/router.rs @@ -6,7 +6,7 @@ use crate::provider::provider::ResponseRequest; use crate::router::random::RandomRouter; use crate::router::weight::WeightedRouter; -pub fn build_router(mode: RoutingMode, models: Vec) -> Box { +pub fn construct_router(mode: RoutingMode, models: Vec) -> Box { let model_ids: Vec = models.iter().map(|m| m.id.clone()).collect(); match mode { RoutingMode::Random => Box::new(RandomRouter::new(model_ids)), From 847c41080c30246bba45b9557c1d1bddb19f69bf Mon Sep 17 00:00:00 2001 From: kerthcet Date: Mon, 22 Dec 2025 14:14:55 +0800 Subject: [PATCH 3/5] add tests Signed-off-by: kerthcet --- src/client/client.rs | 104 +++++++++++++++++++++++++++++++++------ src/config.rs | 1 + src/provider/openai.rs | 4 ++ src/provider/provider.rs | 81 +++++++++++++++++++++--------- src/router/random.rs | 5 ++ src/router/router.rs | 29 +++++++++++ src/router/weight.rs | 4 ++ 7 files changed, 192 insertions(+), 36 deletions(-) diff --git a/src/client/client.rs b/src/client/client.rs index 0672dfe..1e403ea 100644 --- a/src/client/client.rs +++ b/src/client/client.rs @@ -1,7 +1,6 @@ use std::collections::HashMap; -use crate::config::Config; -use crate::config::ModelId; +use crate::config::{Config, ModelConfig, ModelId, RoutingMode}; use crate::provider::provider; use crate::router::router; @@ -45,15 +44,92 @@ impl Client { } } -// // how to write test for this? - -// #[cfg(test)] -// mod tests { -// use super::*; -// #[test] -// fn test_client_creation() { -// let config = Config::default(); -// let client = Client::new(config); -// assert!(client.providers.len() > 0); -// } -// } +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn test_client_new() { + struct TestCase { + name: &'static str, + config: Config, + expected_router_name: &'static str, + enabled_tracker: bool, + } + + let cases = vec![ + TestCase { + name: "Random Router", + config: Config::builder() + .routing_mode(RoutingMode::Random) + .models(vec![ + crate::config::ModelConfig::builder() + .id("model_a".to_string()) + .provider(Some("openai".to_string())) + .base_url(Some("https://api.openai.com/v1".to_string())) + .build() + .unwrap(), + crate::config::ModelConfig::builder() + .id("model_b".to_string()) + .provider(Some("openai".to_string())) + .base_url(Some("https://api.openai.com/v1".to_string())) + .build() + .unwrap(), + ]) + .build() + .unwrap(), + expected_router_name: "RandomRouter", + enabled_tracker: false, + }, + TestCase { + name: "router tracker enabled", + config: Config::builder() + .routing_mode(RoutingMode::Weighted) + .models(vec![ + ModelConfig::builder() + .id("model_a".to_string()) + .provider(Some("openai".to_string())) + .base_url(Some("https://api.openai.com/v1".to_string())) + .weight(1) + .build() + .unwrap(), + ModelConfig::builder() + .id("model_b".to_string()) + .provider(Some("openai".to_string())) + .base_url(Some("https://api.openai.com/v1".to_string())) + .weight(3) + .build() + .unwrap(), + ]) + .build() + .unwrap(), + expected_router_name: "WeightedRouter", + enabled_tracker: true, + }, + ]; + + for case in cases { + let mut client = Client::new(&case.config); + if case.enabled_tracker { + client.enable_router_tracker(); + } + assert_eq!( + client.router.name(), + case.expected_router_name, + "Test case '{}' failed", + case.name + ); + assert_eq!( + client.router_tracker.is_some(), + case.enabled_tracker, + "Test case '{}' failed on router tracker state", + case.name + ); + assert_eq!( + client.providers.len(), + case.config.models.len(), + "Test case '{}' failed on providers count", + case.name + ); + } + } +} diff --git a/src/config.rs b/src/config.rs index 276cd5c..d09fb61 100644 --- a/src/config.rs +++ b/src/config.rs @@ -87,6 +87,7 @@ impl Config { ConfigBuilder::default() } + // populate will fill in the missing model-specific configs with global configs. pub fn populate(&mut self) -> &mut Self { for model in &mut self.models { let model_url_exist = model.base_url.is_some(); diff --git a/src/provider/openai.rs b/src/provider/openai.rs index 5ef82fc..601013a 100644 --- a/src/provider/openai.rs +++ b/src/provider/openai.rs @@ -50,6 +50,10 @@ impl OpenAIProvider { #[async_trait] impl Provider for OpenAIProvider { + fn name(&self) -> &'static str { + "OpenAIProvider" + } + async fn create_response(&self, request: ResponseRequest) -> Result { let client = self.client.as_ref().unwrap(); client.responses().create(request).await diff --git a/src/provider/provider.rs b/src/provider/provider.rs index ec9febf..37df4a8 100644 --- a/src/provider/provider.rs +++ b/src/provider/provider.rs @@ -19,28 +19,65 @@ pub fn construct_provider(config: &ModelConfig) -> Box { #[async_trait] pub trait Provider: Send + Sync { + fn name(&self) -> &'static str; async fn create_response(&self, request: ResponseRequest) -> Result; } -// // test -// #[cfg(test)] -// mod tests { -// use super::*; -// fn test_build_provider() { -// struct TestCase { -// name: &'static str, -// config: ModelConfig, -// expect_provider_type: &'static str, -// error: bool, -// } - -// let cases = vec![ -// TestCase { -// name: "OpenAI Provider", -// config: ModelConfig::new("test-model").with_provider("openai"), -// expect_provider_type: "OpenAIProvider", -// error: false, -// }, -// // Add more test cases as needed -// ]; -// } +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_provider_construction() { + struct TestCase { + name: &'static str, + config: ModelConfig, + expect_provider_type: &'static str, + } + + let cases = vec![ + TestCase { + name: "OpenAI Provider", + config: ModelConfig::builder() + .id("test-model".to_string()) + .provider(Some("openai".to_string())) + .base_url(Some("https://api.openai.com/v1".to_string())) + .build() + .unwrap(), + expect_provider_type: "OpenAIProvider", + }, + TestCase { + name: "Unsupported Provider", + config: ModelConfig::builder() + .id("test-model".to_string()) + .provider(Some("unsupported".to_string())) + .base_url(Some("https://api.openai.com/v1".to_string())) + .build() + .unwrap(), + expect_provider_type: "", + }, + ]; + + for case in cases { + if case.expect_provider_type.is_empty() { + let result = std::panic::catch_unwind(|| { + construct_provider(&case.config); + }); + assert!( + result.is_err(), + "Test case '{}' did not panic as expected", + case.name + ); + } else { + let provider = construct_provider(&case.config); + assert!( + provider.name() == case.expect_provider_type, + "Test case '{}': expected provider type '{}', got '{}'", + case.name, + case.expect_provider_type, + provider.name() + ); + } + } + } +} diff --git a/src/router/random.rs b/src/router/random.rs index 5e553ec..01d7d7a 100644 --- a/src/router/random.rs +++ b/src/router/random.rs @@ -15,6 +15,10 @@ impl RandomRouter { } impl Router for RandomRouter { + fn name(&self) -> &'static str { + "RandomRouter" + } + fn sample(&self, _input: &ResponseRequest) -> ModelId { let mut rng = rand::rng(); let idx = rng.random_range(0..self.model_ids.len()); @@ -25,6 +29,7 @@ impl Router for RandomRouter { #[cfg(test)] mod tests { use super::*; + #[test] fn test_random_router_sampling() { let model_ids = vec!["model_a".to_string(), "model_b".to_string()]; diff --git a/src/router/router.rs b/src/router/router.rs index 3bb3dc4..9aca04b 100644 --- a/src/router/router.rs +++ b/src/router/router.rs @@ -15,6 +15,7 @@ pub fn construct_router(mode: RoutingMode, models: Vec) -> Box &'static str; fn sample(&self, input: &ResponseRequest) -> ModelId; } @@ -33,3 +34,31 @@ impl RouterTracker { } } } + +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn test_router_construction() { + let model_configs = vec![ + ModelConfig::builder() + .id("model_a".to_string()) + .provider(Some("openai".to_string())) + .base_url(Some("https://api.openai.com/v1".to_string())) + .build() + .unwrap(), + ModelConfig::builder() + .id("model_b".to_string()) + .provider(Some("openai".to_string())) + .base_url(Some("https://api.openai.com/v1".to_string())) + .build() + .unwrap(), + ]; + + let random_router = construct_router(RoutingMode::Random, model_configs.clone()); + assert_eq!(random_router.name(), "RandomRouter"); + + let weighted_router = construct_router(RoutingMode::Weighted, model_configs.clone()); + assert_eq!(weighted_router.name(), "WeightedRouter"); + } +} diff --git a/src/router/weight.rs b/src/router/weight.rs index 9cc6956..5f1ed1a 100644 --- a/src/router/weight.rs +++ b/src/router/weight.rs @@ -12,6 +12,10 @@ impl WeightedRouter { } impl Router for WeightedRouter { + fn name(&self) -> &'static str { + "WeightedRouter" + } + fn sample(&self, _input: &ResponseRequest) -> ModelId { // TODO: Implement weighted sampling logic return self.model_ids[0].clone(); From 023aa9dca4450c5368dc8fc870b07902e5f29356 Mon Sep 17 00:00:00 2001 From: kerthcet Date: Mon, 22 Dec 2025 14:21:26 +0800 Subject: [PATCH 4/5] add tests Signed-off-by: kerthcet --- src/client/client.rs | 27 ++++++++++++++++++++------- src/provider/provider.rs | 4 ++-- 2 files changed, 22 insertions(+), 9 deletions(-) diff --git a/src/client/client.rs b/src/client/client.rs index 1e403ea..29632b2 100644 --- a/src/client/client.rs +++ b/src/client/client.rs @@ -58,51 +58,64 @@ mod tests { let cases = vec![ TestCase { - name: "Random Router", + name: "basic config", config: Config::builder() - .routing_mode(RoutingMode::Random) + .models(vec![ + ModelConfig::builder() + .id("model_c".to_string()) + .build() + .unwrap(), + ]) + .build() + .unwrap(), + expected_router_name: "RandomRouter", + enabled_tracker: false, + }, + TestCase { + name: "weighted router", + config: Config::builder() + .routing_mode(RoutingMode::Weighted) .models(vec![ crate::config::ModelConfig::builder() .id("model_a".to_string()) .provider(Some("openai".to_string())) .base_url(Some("https://api.openai.com/v1".to_string())) + .weight(1) .build() .unwrap(), crate::config::ModelConfig::builder() .id("model_b".to_string()) .provider(Some("openai".to_string())) .base_url(Some("https://api.openai.com/v1".to_string())) + .weight(3) .build() .unwrap(), ]) .build() .unwrap(), - expected_router_name: "RandomRouter", + expected_router_name: "WeightedRouter", enabled_tracker: false, }, TestCase { name: "router tracker enabled", config: Config::builder() - .routing_mode(RoutingMode::Weighted) .models(vec![ ModelConfig::builder() .id("model_a".to_string()) .provider(Some("openai".to_string())) .base_url(Some("https://api.openai.com/v1".to_string())) - .weight(1) .build() .unwrap(), ModelConfig::builder() .id("model_b".to_string()) .provider(Some("openai".to_string())) .base_url(Some("https://api.openai.com/v1".to_string())) - .weight(3) .build() .unwrap(), ]) .build() .unwrap(), - expected_router_name: "WeightedRouter", + expected_router_name: "RandomRouter", enabled_tracker: true, }, ]; diff --git a/src/provider/provider.rs b/src/provider/provider.rs index 37df4a8..0f4eb44 100644 --- a/src/provider/provider.rs +++ b/src/provider/provider.rs @@ -11,8 +11,8 @@ pub type APIError = OpenAIError; pub fn construct_provider(config: &ModelConfig) -> Box { let provider = config.provider.as_ref().unwrap(); - match provider.as_str() { - "openai" => Box::new(OpenAIProvider::new(config).build()), + match provider.to_uppercase().as_ref() { + "OPENAI" => Box::new(OpenAIProvider::new(config).build()), _ => panic!("Unsupported provider: {}", provider), } } From 0531048fa0077e301142a502f2e6471e6c9cca5c Mon Sep 17 00:00:00 2001 From: kerthcet Date: Mon, 22 Dec 2025 14:23:06 +0800 Subject: [PATCH 5/5] Use value rather than ref for config Signed-off-by: kerthcet --- src/client/client.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/client/client.rs b/src/client/client.rs index 29632b2..82f7a10 100644 --- a/src/client/client.rs +++ b/src/client/client.rs @@ -11,7 +11,7 @@ pub struct Client { } impl Client { - pub fn new(config: &Config) -> Self { + pub fn new(config: Config) -> Self { let mut cfg = config.clone(); cfg.populate(); @@ -121,7 +121,7 @@ mod tests { ]; for case in cases { - let mut client = Client::new(&case.config); + let mut client = Client::new(case.config.clone()); if case.enabled_tracker { client.enable_router_tracker(); }