diff --git a/.env.test b/.env.test index 6766c19..4fa7612 100644 --- a/.env.test +++ b/.env.test @@ -1,3 +1,2 @@ AMRS_API_KEY=your_amrs_api_key_here OPENAI_API_KEY=your_openai_api_key_here -FOO_API_KEY=your_foo_api_key_here diff --git a/Cargo.lock b/Cargo.lock index 1b9ec19..b4cca6b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8,6 +8,7 @@ version = "0.1.0" dependencies = [ "async-openai", "async-trait", + "derive_builder", "dotenvy", "lazy_static", "rand 0.9.2", diff --git a/Cargo.toml b/Cargo.toml index 061a476..66ecc70 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,6 +6,7 @@ edition = "2024" [dependencies] async-openai = { version = "0.31.1", features = ["_api", "response-types", "responses",] } async-trait = "0.1.89" +derive_builder = "0.20.2" dotenvy = "0.15.7" lazy_static = "1.5.0" rand = "0.9.2" diff --git a/src/client/client.rs b/src/client/client.rs index b4a1894..82f7a10 100644 --- a/src/client/client.rs +++ b/src/client/client.rs @@ -1,27 +1,10 @@ use std::collections::HashMap; -use crate::config::Config; -use crate::config::ModelId; +use crate::config::{Config, ModelConfig, ModelId, RoutingMode}; use crate::provider::provider; use crate::router::router; -// ------------------ Chat Role ------------------ -#[derive(Debug, Clone)] -pub enum ChatRole { - User, - Assistant, - System, -} - -// ------------------ Message ------------------ -#[derive(Debug, Clone)] -pub struct TextMessage { - pub role: ChatRole, - pub content: String, -} - pub struct Client { - config: Config, router_tracker: Option, router: Box, providers: HashMap>, @@ -29,27 +12,19 @@ pub struct Client { impl Client { pub fn new(config: Config) -> Self { - let mut cfg = config; - cfg.finalize().expect("Invalid configuration"); + let mut cfg = config.clone(); + cfg.populate(); let providers = cfg .models .iter() - .map(|m| { - let provider = m - .provider - .as_ref() - .expect("Model provider must be specified"); - - (m.id.clone(), provider::build_provider(provider, m)) - }) + .map(|m| 
(m.id.clone(), provider::construct_provider(m))) .collect(); Self { - config: cfg.clone(), router_tracker: None, providers: providers, - router: router::build_router(cfg.routing_mode, cfg.models), + router: router::construct_router(cfg.routing_mode, cfg.models), } } @@ -64,7 +39,110 @@ impl Client { request: provider::ResponseRequest, ) -> Result { let model_id = self.router.sample(&request); - let provider = self.providers.get(model_id).unwrap(); + let provider = self.providers.get(&model_id).unwrap(); provider.create_response(request).await } } + +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn test_client_new() { + struct TestCase { + name: &'static str, + config: Config, + expected_router_name: &'static str, + enabled_tracker: bool, + } + + let cases = vec![ + TestCase { + name: "basic config", + config: Config::builder() + .models(vec![ + ModelConfig::builder() + .id("model_c".to_string()) + .build() + .unwrap(), + ]) + .build() + .unwrap(), + expected_router_name: "RandomRouter", + enabled_tracker: false, + }, + TestCase { + name: "weighted router", + config: Config::builder() + .routing_mode(RoutingMode::Weighted) + .models(vec![ + crate::config::ModelConfig::builder() + .id("model_a".to_string()) + .provider(Some("openai".to_string())) + .base_url(Some("https://api.openai.com/v1".to_string())) + .weight(1) + .build() + .unwrap(), + crate::config::ModelConfig::builder() + .id("model_b".to_string()) + .provider(Some("openai".to_string())) + .base_url(Some("https://api.openai.com/v1".to_string())) + .weight(3) + .build() + .unwrap(), + ]) + .build() + .unwrap(), + expected_router_name: "WeightedRouter", + enabled_tracker: false, + }, + TestCase { + name: "router tracker enabled", + config: Config::builder() + .models(vec![ + ModelConfig::builder() + .id("model_a".to_string()) + .provider(Some("openai".to_string())) + .base_url(Some("https://api.openai.com/v1".to_string())) + .build() + .unwrap(), + ModelConfig::builder() + .id("model_b".to_string()) + 
.provider(Some("openai".to_string())) + .base_url(Some("https://api.openai.com/v1".to_string())) + .build() + .unwrap(), + ]) + .build() + .unwrap(), + expected_router_name: "RandomRouter", + enabled_tracker: true, + }, + ]; + + for case in cases { + let mut client = Client::new(case.config.clone()); + if case.enabled_tracker { + client.enable_router_tracker(); + } + assert_eq!( + client.router.name(), + case.expected_router_name, + "Test case '{}' failed", + case.name + ); + assert_eq!( + client.router_tracker.is_some(), + case.enabled_tracker, + "Test case '{}' failed on router tracker state", + case.name + ); + assert_eq!( + client.providers.len(), + case.config.models.len(), + "Test case '{}' failed on providers count", + case.name + ); + } + } +} diff --git a/src/config.rs b/src/config.rs index 2c38bd2..d09fb61 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,11 +1,12 @@ use std::collections::HashMap; use std::env; +use derive_builder::Builder; use lazy_static::lazy_static; // ------------------ Provider ------------------ pub type ProviderName = String; -const AMRS_PROVIDER: &str = "AMRS"; +const OPENAI_PROVIDER: &str = "OPENAI"; lazy_static! 
{ pub static ref PROVIDER_BASE_URLS: HashMap<&'static str, &'static str> = { @@ -28,121 +29,66 @@ pub enum RoutingMode { // ------------------ Model Config ------------------ pub type ModelId = String; -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Builder)] +#[builder(build_fn(validate = "Self::validate"), pattern = "mutable")] pub struct ModelConfig { // model-specific configs, will override global configs if provided + #[builder(default = "None")] pub base_url: Option, + #[builder(default = "None")] pub provider: Option, + #[builder(default = "None")] pub temperature: Option, + #[builder(default = "None")] pub max_output_tokens: Option, pub id: ModelId, - pub weight: i32, // -1 if unused + #[builder(default=-1)] + pub weight: i32, } -impl ModelConfig { - fn new(id: ModelId) -> Self { - Self { - base_url: None, - provider: None, - temperature: None, - max_output_tokens: None, - - id: id, - weight: -1, +impl ModelConfigBuilder { + fn validate(&self) -> Result<(), String> { + if self.id.is_none() { + return Err("Model id must be provided.".to_string()); } + Ok(()) } +} - pub fn with_base_url(mut self, url: &str) -> Self { - self.base_url = Some(url.to_string()); - self - } - - pub fn with_provider(mut self, provider: &str) -> Self { - self.provider = Some(provider.to_string()); - self - } - - pub fn with_temperature(mut self, temperature: f32) -> Self { - self.temperature = Some(temperature); - self - } - - pub fn with_max_output_tokens(mut self, max_output_tokens: usize) -> Self { - self.max_output_tokens = Some(max_output_tokens); - self - } - - pub fn with_weight(mut self, weight: i32) -> Self { - self.weight = weight; - self +impl ModelConfig { + pub fn builder() -> ModelConfigBuilder { + ModelConfigBuilder::default() } } // ------------------ Main Config ------------------ -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Builder)] +#[builder(build_fn(validate = "Self::validate"), pattern = "mutable")] pub struct Config { // global configs for models, will be 
overridden by model-specific configs - pub(crate) base_url: Option, - pub(crate) provider: ProviderName, // "AMRS" by default - pub(crate) temperature: f32, // 0.8 by default - pub(crate) max_output_tokens: usize, // 1024 by default - - pub(crate) routing_mode: RoutingMode, // Random by default + #[builder(default = "https://api.openai.com/v1".to_string())] + pub(crate) base_url: String, + #[builder(default = "ProviderName::from(OPENAI_PROVIDER)")] + pub(crate) provider: ProviderName, + #[builder(default = "0.8")] + pub(crate) temperature: f32, + #[builder(default = "1024")] + pub(crate) max_output_tokens: usize, + + #[builder(default = "RoutingMode::Random")] + pub(crate) routing_mode: RoutingMode, + #[builder(default = "vec![]")] pub(crate) models: Vec, } -impl Default for Config { - fn default() -> Self { - Self { - base_url: None, - provider: AMRS_PROVIDER.to_string(), - temperature: 0.8, - max_output_tokens: 1024, - routing_mode: RoutingMode::Random, - models: vec![], - } - } -} - impl Config { - pub fn with_base_url(mut self, url: &str) -> Self { - self.base_url = Some(url.to_string()); - self - } - - pub fn with_provider(mut self, provider: &str) -> Self { - self.provider = provider.to_string(); - self - } - - pub fn with_temperature(mut self, temperature: f32) -> Self { - self.temperature = temperature; - self - } - - pub fn with_max_output_tokens(mut self, max_output_tokens: usize) -> Self { - self.max_output_tokens = max_output_tokens; - self - } - - pub fn with_routing_mode(mut self, mode: RoutingMode) -> Self { - self.routing_mode = mode; - self - } - - pub fn with_model(mut self, model: ModelConfig) -> Self { - self.models.push(model); - self - } - - pub fn finalize(&mut self) -> Result<(), String> { - self.post_defaults(); - self.validate_model_config()?; - Ok(()) + pub fn builder() -> ConfigBuilder { + ConfigBuilder::default() } - fn post_defaults(&mut self) { + // populate will fill in the missing model-specific configs with global configs. 
+ pub fn populate(&mut self) -> &mut Self { for model in &mut self.models { let model_url_exist = model.base_url.is_some(); @@ -156,8 +102,8 @@ impl Config { model.base_url = Some(PROVIDER_BASE_URLS[model.provider.as_ref().unwrap().as_str()].to_string()); } - if !model_url_exist && self.base_url.is_some() { - model.base_url = self.base_url.clone(); + if !model_url_exist { + model.base_url = Some(self.base_url.clone()); } if model.temperature.is_none() { model.temperature = Some(self.temperature); @@ -166,19 +112,28 @@ impl Config { model.max_output_tokens = Some(self.max_output_tokens); } } + self + } +} + +impl ConfigBuilder { + pub fn model(&mut self, model: ModelConfig) -> &mut Self { + let mut models = self.models.clone().unwrap_or_default(); + models.push(model); + self.models = Some(models); + self } - fn validate_model_config(&self) -> Result<(), String> { - if self.models.is_empty() { + fn validate(&self) -> Result<(), String> { + if self.models.is_none() || self.models.as_ref().unwrap().is_empty() { return Err("At least one model must be configured.".to_string()); } - for model in &self.models { - if model.base_url.is_none() && self.base_url.is_none() { - return Err(format!("Model '{}' base_url is not provided.", model.id)); - } - - if self.routing_mode == RoutingMode::Weighted && model.weight <= 0 { + for model in self.models.as_ref().unwrap() { + if self.routing_mode.is_some() + && self.routing_mode.as_ref().unwrap() == &RoutingMode::Weighted + && model.weight <= 0 + { return Err(format!( "Model '{}' weight must be non-negative in Weighted routing mode.", model.id @@ -203,6 +158,7 @@ impl Config { } } + // check the existence of API key in environment variables if let Some(provider) = &model.provider { let env_var = format!("{}_API_KEY", provider.to_uppercase()); if env::var(&env_var).is_err() { @@ -212,6 +168,25 @@ impl Config { env_var )); } + } else { + // Builder defaults are applied only after validate() runs, so self.provider may still be None here; fall back to the default provider (OPENAI) explicitly.
+ let env_var = format!( + "{}_API_KEY", + self.provider + .as_ref() + .unwrap_or(&ProviderName::from(OPENAI_PROVIDER)) + .to_uppercase() + ); + if env::var(&env_var).is_err() { + return Err(format!( + "API key for provider '{}' not found in environment variable '{}'", + self.provider + .as_ref() + .unwrap_or(&ProviderName::from(OPENAI_PROVIDER)) + .to_uppercase(), + env_var + )); + } } } @@ -230,58 +205,130 @@ mod tests { from_filename(".env.test").ok(); // case 1: - let mut valid_simplest_models_cfg = Config::default() - .with_provider("OPENAI") - .with_model(ModelConfig::new("gpt-4".to_string())); - let res = valid_simplest_models_cfg.finalize(); - assert!(res.clone().is_ok()); + let valid_simplest_models_cfg = Config::builder() + .model( + ModelConfig::builder() + .id("gpt-4".to_string()) + .build() + .unwrap(), + ) + .build(); + assert!(valid_simplest_models_cfg.is_ok()); + assert!(valid_simplest_models_cfg.as_ref().unwrap().provider == OPENAI_PROVIDER); assert!( - valid_simplest_models_cfg.models[0].base_url - == Some("https://api.openai.com/v1".to_string()) + valid_simplest_models_cfg.as_ref().unwrap().base_url == "https://api.openai.com/v1" + ); + assert!(valid_simplest_models_cfg.as_ref().unwrap().temperature == 0.8); + assert!( + valid_simplest_models_cfg + .as_ref() + .unwrap() + .max_output_tokens + == 1024 ); - assert!(valid_simplest_models_cfg.models[0].provider == Some("OPENAI".to_string())); - assert!(valid_simplest_models_cfg.models[0].temperature == Some(0.8)); - assert!(valid_simplest_models_cfg.models[0].max_output_tokens == Some(1024)); - assert!(valid_simplest_models_cfg.models[0].weight == -1); + assert!(valid_simplest_models_cfg.as_ref().unwrap().routing_mode == RoutingMode::Random); + assert!(valid_simplest_models_cfg.as_ref().unwrap().models.len() == 1); + assert!(valid_simplest_models_cfg.as_ref().unwrap().models[0].base_url == None); + assert!(valid_simplest_models_cfg.as_ref().unwrap().models[0].provider == None); + 
assert!(valid_simplest_models_cfg.as_ref().unwrap().models[0].temperature == None); + assert!(valid_simplest_models_cfg.as_ref().unwrap().models[0].max_output_tokens == None); + assert!(valid_simplest_models_cfg.as_ref().unwrap().models[0].weight == -1); // case 2: - let mut valid_cfg = Config::default() - .with_provider("OPENAI") - .with_model(ModelConfig::new("gpt-3.5-turbo".to_string())) - .with_model(ModelConfig::new("gpt-4".to_string())); - assert!(valid_cfg.finalize().is_ok()); + let valid_cfg = Config::builder() + .models(vec![ + ModelConfig::builder() + .id("gpt-3.5-turbo".to_string()) + .build() + .unwrap(), + ModelConfig::builder() + .id("gpt-4".to_string()) + .build() + .unwrap(), + ]) + .build(); + assert!(valid_cfg.is_ok()); + assert!(valid_cfg.as_ref().unwrap().models.len() == 2); // case 3: - let mut invalid_cfg_with_no_api_key = Config::default() - .with_provider("unknown_provider") - .with_model(ModelConfig::new("some-model".to_string())); - assert!(invalid_cfg_with_no_api_key.finalize().is_err()); + let invalid_cfg_with_no_api_key = Config::builder() + .model( + ModelConfig::builder() + .id("some-model".to_string()) + .build() + .unwrap(), + ) + .provider("unknown_provider".to_string()) + .build(); + assert!(invalid_cfg_with_no_api_key.is_err()); // case 4: - let mut valid_cfg_with_customized_provider = Config::default() - .with_base_url("http://example.ai") - .with_max_output_tokens(2048) - .with_model(ModelConfig::new("custom-model".to_string()).with_provider("FOO")); - let res = valid_cfg_with_customized_provider.finalize(); - assert!(res.is_ok()); - assert_eq!( - valid_cfg_with_customized_provider.models[0] - .base_url - .as_ref() - .unwrap(), - "http://example.ai" - ); - assert_eq!( - valid_cfg_with_customized_provider.models[0].provider, - Some("FOO".to_string()) - ); - assert_eq!( - valid_cfg_with_customized_provider.models[0].max_output_tokens, - Some(2048) - ); + // AMRS_API_KEY is set in .env.test already. 
+ let valid_cfg_with_customized_provider = Config::builder() + .base_url("http://example.ai".to_string()) + .max_output_tokens(2048) + .model( + ModelConfig::builder() + .id("custom-model".to_string()) + .provider(Some("AMRS".to_string())) + .build() + .unwrap(), + ) + .build(); + assert!(valid_cfg_with_customized_provider.is_ok()); // case 5: - let mut invalid_empty_models_cfg = Config::default().with_provider("OPENAI"); - assert!(invalid_empty_models_cfg.finalize().is_err()); + let invalid_empty_models_cfg = Config::builder().build(); + assert!(invalid_empty_models_cfg.is_err()); + + // case 6: + print!("validating invalid empty model id config"); + let invalid_empty_model_id_cfg = ModelConfig::builder().build(); + assert!(invalid_empty_model_id_cfg.is_err()); + } + + #[test] + fn test_populate_config() { + let mut valid_cfg = Config::builder() + .temperature(0.5) + .max_output_tokens(1500) + .model( + ModelConfig::builder() + .id("model-1".to_string()) + .build() + .unwrap(), + ) + .build(); + valid_cfg.as_mut().unwrap().populate(); + + assert!(valid_cfg.is_ok()); + assert!(valid_cfg.as_ref().unwrap().models.len() == 1); + assert!(valid_cfg.as_ref().unwrap().models[0].temperature == Some(0.5)); + assert!(valid_cfg.as_ref().unwrap().models[0].max_output_tokens == Some(1500)); + assert!(valid_cfg.as_ref().unwrap().models[0].provider == Some("OPENAI".to_string())); + assert!( + valid_cfg.as_ref().unwrap().models[0].base_url + == Some("https://api.openai.com/v1".to_string()) + ); + assert!(valid_cfg.as_ref().unwrap().models[0].weight == -1); + + let mut valid_specified_cfg = Config::builder() + .provider("DEEPINFRA".to_string()) + .base_url("http://custom-api.ai".to_string()) + .model( + ModelConfig::builder() + .id("model-2".to_string()) + .build() + .unwrap(), + ) + .build(); + valid_specified_cfg.as_mut().unwrap().populate(); + + assert!(valid_specified_cfg.is_ok()); + assert!(valid_specified_cfg.as_ref().unwrap().provider == "DEEPINFRA".to_string()); + assert!( 
+ valid_specified_cfg.as_ref().unwrap().models[0].base_url + == Some("http://custom-api.ai".to_string()) + ); } } diff --git a/src/lib.rs b/src/lib.rs index 4524963..6d47760 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,7 +3,6 @@ mod router { pub mod router; mod weight; } -mod config; mod client { pub mod client; } @@ -12,5 +11,5 @@ mod provider { pub mod provider; } +pub mod config; pub use crate::client::client::Client; -pub use crate::config::Config; diff --git a/src/provider/openai.rs b/src/provider/openai.rs index 5ef82fc..601013a 100644 --- a/src/provider/openai.rs +++ b/src/provider/openai.rs @@ -50,6 +50,10 @@ impl OpenAIProvider { #[async_trait] impl Provider for OpenAIProvider { + fn name(&self) -> &'static str { + "OpenAIProvider" + } + async fn create_response(&self, request: ResponseRequest) -> Result { let client = self.client.as_ref().unwrap(); client.responses().create(request).await diff --git a/src/provider/provider.rs b/src/provider/provider.rs index 76fa461..0f4eb44 100644 --- a/src/provider/provider.rs +++ b/src/provider/provider.rs @@ -2,21 +2,82 @@ use async_openai::error::OpenAIError; use async_openai::types::responses::{CreateResponse as OpenAIRequest, Response as OpenAIResponse}; use async_trait::async_trait; -use crate::config::{ModelConfig, ProviderName}; +use crate::config::ModelConfig; use crate::provider::openai::OpenAIProvider; pub type ResponseRequest = OpenAIRequest; pub type ResponseResult = OpenAIResponse; pub type APIError = OpenAIError; -pub fn build_provider(provider: &ProviderName, config: &ModelConfig) -> Box { - match provider.as_str() { - "openai" => Box::new(OpenAIProvider::new(config).build()), +pub fn construct_provider(config: &ModelConfig) -> Box { + let provider = config.provider.as_ref().unwrap(); + match provider.to_uppercase().as_ref() { + "OPENAI" => Box::new(OpenAIProvider::new(config).build()), _ => panic!("Unsupported provider: {}", provider), } } #[async_trait] pub trait Provider: Send + Sync { + fn 
name(&self) -> &'static str; async fn create_response(&self, request: ResponseRequest) -> Result; } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_provider_construction() { + struct TestCase { + name: &'static str, + config: ModelConfig, + expect_provider_type: &'static str, + } + + let cases = vec![ + TestCase { + name: "OpenAI Provider", + config: ModelConfig::builder() + .id("test-model".to_string()) + .provider(Some("openai".to_string())) + .base_url(Some("https://api.openai.com/v1".to_string())) + .build() + .unwrap(), + expect_provider_type: "OpenAIProvider", + }, + TestCase { + name: "Unsupported Provider", + config: ModelConfig::builder() + .id("test-model".to_string()) + .provider(Some("unsupported".to_string())) + .base_url(Some("https://api.openai.com/v1".to_string())) + .build() + .unwrap(), + expect_provider_type: "", + }, + ]; + + for case in cases { + if case.expect_provider_type.is_empty() { + let result = std::panic::catch_unwind(|| { + construct_provider(&case.config); + }); + assert!( + result.is_err(), + "Test case '{}' did not panic as expected", + case.name + ); + } else { + let provider = construct_provider(&case.config); + assert!( + provider.name() == case.expect_provider_type, + "Test case '{}': expected provider type '{}', got '{}'", + case.name, + case.expect_provider_type, + provider.name() + ); + } + } + } +} diff --git a/src/router/random.rs b/src/router/random.rs index 0d0966b..01d7d7a 100644 --- a/src/router/random.rs +++ b/src/router/random.rs @@ -15,9 +15,33 @@ impl RandomRouter { } impl Router for RandomRouter { - fn sample(&self, _input: &ResponseRequest) -> &ModelId { + fn name(&self) -> &'static str { + "RandomRouter" + } + + fn sample(&self, _input: &ResponseRequest) -> ModelId { let mut rng = rand::rng(); let idx = rng.random_range(0..self.model_ids.len()); - &self.model_ids[idx] + self.model_ids[idx].clone() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn 
test_random_router_sampling() { + let model_ids = vec!["model_a".to_string(), "model_b".to_string()]; + let router = RandomRouter::new(model_ids.clone()); + let mut counts = std::collections::HashMap::new(); + for _ in 0..1000 { + let sampled_id = router.sample(&ResponseRequest::default()); + *counts.entry(sampled_id.clone()).or_insert(0) += 1; + } + assert!(counts.len() == model_ids.len()); + for count in counts.values() { + assert!(*count > 0); + } } } diff --git a/src/router/router.rs b/src/router/router.rs index 17c3708..9aca04b 100644 --- a/src/router/router.rs +++ b/src/router/router.rs @@ -6,7 +6,7 @@ use crate::provider::provider::ResponseRequest; use crate::router::random::RandomRouter; use crate::router::weight::WeightedRouter; -pub fn build_router(mode: RoutingMode, models: Vec) -> Box { +pub fn construct_router(mode: RoutingMode, models: Vec) -> Box { let model_ids: Vec = models.iter().map(|m| m.id.clone()).collect(); match mode { RoutingMode::Random => Box::new(RandomRouter::new(model_ids)), @@ -15,7 +15,8 @@ pub fn build_router(mode: RoutingMode, models: Vec) -> Box &ModelId; + fn name(&self) -> &'static str; + fn sample(&self, input: &ResponseRequest) -> ModelId; } pub struct RouterTracker { @@ -33,3 +34,31 @@ impl RouterTracker { } } } + +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn test_router_construction() { + let model_configs = vec![ + ModelConfig::builder() + .id("model_a".to_string()) + .provider(Some("openai".to_string())) + .base_url(Some("https://api.openai.com/v1".to_string())) + .build() + .unwrap(), + ModelConfig::builder() + .id("model_b".to_string()) + .provider(Some("openai".to_string())) + .base_url(Some("https://api.openai.com/v1".to_string())) + .build() + .unwrap(), + ]; + + let random_router = construct_router(RoutingMode::Random, model_configs.clone()); + assert_eq!(random_router.name(), "RandomRouter"); + + let weighted_router = construct_router(RoutingMode::Weighted, model_configs.clone()); + 
assert_eq!(weighted_router.name(), "WeightedRouter"); + } +} diff --git a/src/router/weight.rs b/src/router/weight.rs index a064b79..5f1ed1a 100644 --- a/src/router/weight.rs +++ b/src/router/weight.rs @@ -12,8 +12,12 @@ impl WeightedRouter { } impl Router for WeightedRouter { - fn sample(&self, _input: &ResponseRequest) -> &ModelId { + fn name(&self) -> &'static str { + "WeightedRouter" + } + + fn sample(&self, _input: &ResponseRequest) -> ModelId { // TODO: Implement weighted sampling logic - return &self.model_ids[0]; + return self.model_ids[0].clone(); } }