diff --git a/src/client/client.rs b/src/client/client.rs index 82f7a10..a7b9123 100644 --- a/src/client/client.rs +++ b/src/client/client.rs @@ -1,13 +1,12 @@ use std::collections::HashMap; -use crate::config::{Config, ModelConfig, ModelId, RoutingMode}; +use crate::config::{Config, ModelId}; use crate::provider::provider; use crate::router::router; pub struct Client { - router_tracker: Option, - router: Box, providers: HashMap>, + router: Box, } impl Client { @@ -22,20 +21,13 @@ impl Client { .collect(); Self { - router_tracker: None, providers: providers, router: router::construct_router(cfg.routing_mode, cfg.models), } } - pub fn enable_router_tracker(&mut self) { - if self.router_tracker.is_none() { - self.router_tracker = Some(router::RouterTracker::new()); - } - } - pub async fn create_response( - &self, + &mut self, request: provider::ResponseRequest, ) -> Result { let model_id = self.router.sample(&request); @@ -47,13 +39,14 @@ impl Client { #[cfg(test)] mod tests { use super::*; + use crate::config::{Config, ModelConfig, RoutingMode}; + #[test] fn test_client_new() { struct TestCase { name: &'static str, config: Config, expected_router_name: &'static str, - enabled_tracker: bool, } let cases = vec![ @@ -69,12 +62,11 @@ mod tests { .build() .unwrap(), expected_router_name: "RandomRouter", - enabled_tracker: false, }, TestCase { - name: "weighted router", + name: "weighted round-robin router", config: Config::builder() - .routing_mode(RoutingMode::Weighted) + .routing_mode(RoutingMode::WRR) .models(vec![ crate::config::ModelConfig::builder() .id("model_a".to_string()) @@ -93,8 +85,7 @@ mod tests { ]) .build() .unwrap(), - expected_router_name: "WeightedRouter", - enabled_tracker: false, + expected_router_name: "WeightedRoundRobinRouter", }, TestCase { name: "router tracker enabled", @@ -116,27 +107,17 @@ mod tests { .build() .unwrap(), expected_router_name: "RandomRouter", - enabled_tracker: true, }, ]; for case in cases { - let mut client = Client::new(case.config.clone()); - if case.enabled_tracker { - client.enable_router_tracker(); - } + let client = Client::new(case.config.clone()); assert_eq!( client.router.name(), case.expected_router_name, "Test case '{}' failed", case.name ); - assert_eq!( - client.router_tracker.is_some(), - case.enabled_tracker, - "Test case '{}' failed on router tracker state", - case.name - ); assert_eq!( client.providers.len(), case.config.models.len(), diff --git a/src/config.rs b/src/config.rs index d09fb61..521e3e4 100644 --- a/src/config.rs +++ b/src/config.rs @@ -23,7 +23,7 @@ lazy_static! { #[derive(Debug, Clone, PartialEq)] pub enum RoutingMode { Random, - Weighted, + WRR, // WeightedRoundRobin, } // ------------------ Model Config ------------------ @@ -131,7 +131,7 @@ impl ConfigBuilder { for model in self.models.as_ref().unwrap() { if self.routing_mode.is_some() - && self.routing_mode.as_ref().unwrap() == &RoutingMode::Weighted + && self.routing_mode.as_ref().unwrap() == &RoutingMode::WRR && model.weight <= 0 { return Err(format!( diff --git a/src/lib.rs b/src/lib.rs index 6d47760..3f0f656 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,7 +1,8 @@ mod router { mod random; pub mod router; - mod weight; + pub mod stats; + mod wrr; } mod client { pub mod client; diff --git a/src/router/random.rs b/src/router/random.rs index 01d7d7a..c774d50 100644 --- a/src/router/random.rs +++ b/src/router/random.rs @@ -2,15 +2,15 @@ use rand::Rng; use crate::config::ModelId; use crate::provider::provider::ResponseRequest; -use crate::router::router::Router; +use crate::router::router::{ModelInfo, Router}; pub struct RandomRouter { - pub model_ids: Vec, + pub model_infos: Vec, } impl RandomRouter { - pub fn new(model_ids: Vec) -> Self { - Self { model_ids } + pub fn new(model_infos: Vec) -> Self { + Self { model_infos } } } @@ -19,10 +19,10 @@ impl Router for RandomRouter { "RandomRouter" } - fn sample(&self, _input: &ResponseRequest) -> ModelId { + fn sample(&mut self, _input: &ResponseRequest) -> ModelId { let mut rng = rand::rng(); - let idx = rng.random_range(0..self.model_ids.len()); - self.model_ids[idx].clone() + let idx = rng.random_range(0..self.model_infos.len()); + self.model_infos[idx].id.clone() } } @@ -32,14 +32,28 @@ mod tests { #[test] fn test_random_router_sampling() { - let model_ids = vec!["model_a".to_string(), "model_b".to_string()]; - let router = RandomRouter::new(model_ids.clone()); + let model_infos = vec![ + ModelInfo { + id: "model_x".to_string(), + weight: 1, + }, + ModelInfo { + id: "model_y".to_string(), + weight: 2, + }, + ModelInfo { + id: "model_z".to_string(), + weight: 3, + }, + ]; + let mut router = RandomRouter::new(model_infos.clone()); let mut counts = std::collections::HashMap::new(); + for _ in 0..1000 { let sampled_id = router.sample(&ResponseRequest::default()); *counts.entry(sampled_id.clone()).or_insert(0) += 1; } - assert!(counts.len() == model_ids.len()); + assert!(counts.len() == model_infos.len()); for count in counts.values() { assert!(*count > 0); } diff --git a/src/router/router.rs b/src/router/router.rs index 9aca04b..ecec90b 100644 --- a/src/router/router.rs +++ b/src/router/router.rs @@ -4,35 +4,31 @@ use std::sync::atomic::AtomicUsize; use crate::config::{ModelConfig, ModelId, RoutingMode}; use crate::provider::provider::ResponseRequest; use crate::router::random::RandomRouter; -use crate::router::weight::WeightedRouter; +use crate::router::wrr::WeightedRoundRobinRouter; + +#[derive(Debug, Clone)] +pub struct ModelInfo { + pub id: ModelId, + pub weight: i32, +} pub fn construct_router(mode: RoutingMode, models: Vec) -> Box { - let model_ids: Vec = models.iter().map(|m| m.id.clone()).collect(); + let model_infos: Vec = models + .iter() + .map(|m| ModelInfo { + id: m.id.clone(), + weight: m.weight.clone(), + }) + .collect(); match mode { - RoutingMode::Random => Box::new(RandomRouter::new(model_ids)), - RoutingMode::Weighted => Box::new(WeightedRouter::new(model_ids)), + RoutingMode::Random => Box::new(RandomRouter::new(model_infos)), + RoutingMode::WRR => Box::new(WeightedRoundRobinRouter::new(model_infos)), } } pub trait Router { fn name(&self) -> &'static str; - fn sample(&self, input: &ResponseRequest) -> ModelId; -} - -pub struct RouterTracker { - total_requests: HashMap, - avg_latencies: HashMap, - total_tokens: HashMap, -} - -impl RouterTracker { - pub fn new() -> Self { - RouterTracker { - total_requests: HashMap::new(), - avg_latencies: HashMap::new(), - total_tokens: HashMap::new(), - } - } + fn sample(&mut self, input: &ResponseRequest) -> ModelId; } #[cfg(test)] @@ -58,7 +54,7 @@ mod tests { let random_router = construct_router(RoutingMode::Random, model_configs.clone()); assert_eq!(random_router.name(), "RandomRouter"); - let weighted_router = construct_router(RoutingMode::Weighted, model_configs.clone()); - assert_eq!(weighted_router.name(), "WeightedRouter"); + let weighted_router = construct_router(RoutingMode::WRR, model_configs.clone()); + assert_eq!(weighted_router.name(), "WeightedRoundRobinRouter"); } } diff --git a/src/router/stats.rs b/src/router/stats.rs new file mode 100644 index 0000000..7b4251a --- /dev/null +++ b/src/router/stats.rs @@ -0,0 +1,24 @@ +use std::collections::HashMap; +use std::sync::atomic::{AtomicUsize, Ordering}; + +use crate::config::ModelId; + +pub struct RouterStats { + requests_per_model: HashMap, +} + +impl RouterStats { + pub fn default() -> Self { + RouterStats { + requests_per_model: HashMap::new(), + } + } + + pub fn increment_request(&mut self, model_id: &ModelId) -> usize { + let counter = self + .requests_per_model + .entry(model_id.clone()) + .or_insert_with(|| AtomicUsize::new(0)); + counter.fetch_add(1, Ordering::Relaxed) + } +} diff --git a/src/router/weight.rs b/src/router/weight.rs deleted file mode 100644 index 5f1ed1a..0000000 --- a/src/router/weight.rs +++ /dev/null @@ -1,23 +0,0 @@ -use super::router::Router; -use crate::{config::ModelId, provider::provider::ResponseRequest}; - -pub struct WeightedRouter { - pub model_ids: Vec, -} - -impl WeightedRouter { - pub fn new(model_ids: Vec) -> Self { - Self { model_ids } - } -} - -impl Router for WeightedRouter { - fn name(&self) -> &'static str { - "WeightedRouter" - } - - fn sample(&self, _input: &ResponseRequest) -> ModelId { - // TODO: Implement weighted sampling logic - return self.model_ids[0].clone(); - } -} diff --git a/src/router/wrr.rs b/src/router/wrr.rs new file mode 100644 index 0000000..b1ea9c7 --- /dev/null +++ b/src/router/wrr.rs @@ -0,0 +1,96 @@ +use crate::router::router::{ModelInfo, Router}; +use crate::{config::ModelId, provider::provider::ResponseRequest}; + +pub struct WeightedRoundRobinRouter { + total_weight: i32, + model_infos: Vec, + // current_weight is ordered by model_infos index. + current_weights: Vec, +} + +impl WeightedRoundRobinRouter { + pub fn new(model_infos: Vec) -> Self { + let total_weight = model_infos.iter().map(|m| m.weight).sum(); + let length = model_infos.len(); + + Self { + model_infos: model_infos, + total_weight: total_weight, + current_weights: vec![0; length], + } + } +} + +impl Router for WeightedRoundRobinRouter { + fn name(&self) -> &'static str { + "WeightedRoundRobinRouter" + } + + // Use Smooth Weighted Round Robin Algorithm. + fn sample(&mut self, _input: &ResponseRequest) -> ModelId { + // return early if only one model. + if self.model_infos.len() == 1 { + return self.model_infos[0].id.clone(); + } + + self.current_weights + .iter_mut() + .enumerate() + .for_each(|(i, weight)| { + *weight += self.model_infos[i].weight; + }); + + let mut max_index = 0; + for i in 1..self.current_weights.len() { + if self.current_weights[i] > self.current_weights[max_index] { + max_index = i; + } + } + + self.current_weights[max_index] -= self.total_weight; + self.model_infos[max_index].id.clone() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::collections::HashMap; + + #[test] + fn test_weighted_round_robin_sampling() { + let model_infos = vec![ + ModelInfo { + id: "model_x".to_string(), + weight: 1, + }, + ModelInfo { + id: "model_y".to_string(), + weight: 3, + }, + ModelInfo { + id: "model_z".to_string(), + weight: 6, + }, + ]; + let mut wrr = WeightedRoundRobinRouter::new(model_infos.clone()); + let mut counts = HashMap::new(); + for _ in 0..1000 { + let sampled_id = wrr.sample(&ResponseRequest::default()); + *counts.entry(sampled_id.clone()).or_insert(0) += 1; + } + assert!(counts.len() == model_infos.len()); + // Check approximate distribution. + let total_counts: usize = counts.values().sum(); + assert!(total_counts == 1000); + let model_x_counts = *counts.get("model_x").unwrap_or(&0); + let model_y_counts = *counts.get("model_y").unwrap_or(&0); + let model_z_counts = *counts.get("model_z").unwrap_or(&0); + let model_x_ratio = model_x_counts as f64 / total_counts as f64; + let model_y_ratio = model_y_counts as f64 / total_counts as f64; + let model_z_ratio = model_z_counts as f64 / total_counts as f64; + assert!((model_x_ratio - 0.1).abs() < 0.05); + assert!((model_y_ratio - 0.3).abs() < 0.05); + assert!((model_z_ratio - 0.6).abs() < 0.05); + } +}