From 7013f49a630c15d0ce07010cbef03a075efa800b Mon Sep 17 00:00:00 2001 From: kerthcet Date: Mon, 22 Dec 2025 18:10:58 +0800 Subject: [PATCH 1/3] rename router_tracker to router_stats Signed-off-by: kerthcet --- src/client/client.rs | 27 ++++----------------------- src/router/router.rs | 12 ++++-------- 2 files changed, 8 insertions(+), 31 deletions(-) diff --git a/src/client/client.rs b/src/client/client.rs index 82f7a10..29c74cf 100644 --- a/src/client/client.rs +++ b/src/client/client.rs @@ -5,9 +5,9 @@ use crate::provider::provider; use crate::router::router; pub struct Client { - router_tracker: Option, - router: Box, providers: HashMap>, + router: Box, + router_stats: router::RouterStats, } impl Client { @@ -22,15 +22,9 @@ impl Client { .collect(); Self { - router_tracker: None, providers: providers, router: router::construct_router(cfg.routing_mode, cfg.models), - } - } - - pub fn enable_router_tracker(&mut self) { - if self.router_tracker.is_none() { - self.router_tracker = Some(router::RouterTracker::new()); + router_stats: router::RouterStats::default(), } } @@ -53,7 +47,6 @@ mod tests { name: &'static str, config: Config, expected_router_name: &'static str, - enabled_tracker: bool, } let cases = vec![ @@ -69,7 +62,6 @@ mod tests { .build() .unwrap(), expected_router_name: "RandomRouter", - enabled_tracker: false, }, TestCase { name: "weighted router", @@ -94,7 +86,6 @@ mod tests { .build() .unwrap(), expected_router_name: "WeightedRouter", - enabled_tracker: false, }, TestCase { name: "router tracker enabled", @@ -116,27 +107,17 @@ mod tests { .build() .unwrap(), expected_router_name: "RandomRouter", - enabled_tracker: true, }, ]; for case in cases { - let mut client = Client::new(case.config.clone()); - if case.enabled_tracker { - client.enable_router_tracker(); - } + let client = Client::new(case.config.clone()); assert_eq!( client.router.name(), case.expected_router_name, "Test case '{}' failed", case.name ); - assert_eq!( - client.router_tracker.is_some(), - case.enabled_tracker, - "Test case '{}' failed on router tracker state", - case.name - ); assert_eq!( client.providers.len(), case.config.models.len(), diff --git a/src/router/router.rs b/src/router/router.rs index 9aca04b..74f8848 100644 --- a/src/router/router.rs +++ b/src/router/router.rs @@ -19,18 +19,14 @@ pub trait Router { fn sample(&self, input: &ResponseRequest) -> ModelId; } -pub struct RouterTracker { +pub struct RouterStats { total_requests: HashMap, - avg_latencies: HashMap, - total_tokens: HashMap, } -impl RouterTracker { - pub fn new() -> Self { - RouterTracker { +impl RouterStats { + pub fn default() -> Self { + RouterStats { total_requests: HashMap::new(), - avg_latencies: HashMap::new(), - total_tokens: HashMap::new(), } } } From 47691024ebdc4c06f70c2a05ea3d82a1efb7737b Mon Sep 17 00:00:00 2001 From: kerthcet Date: Mon, 22 Dec 2025 23:57:35 +0800 Subject: [PATCH 2/3] add support for WRR algo for sampling Signed-off-by: kerthcet --- src/client/client.rs | 12 +++--- src/config.rs | 4 +- src/lib.rs | 1 + src/router/random.rs | 34 +++++++++++----- src/router/router.rs | 36 ++++++++--------- src/router/stats.rs | 24 ++++++++++++ src/router/weight.rs | 93 +++++++++++++++++++++++++++++++++++++++----- 7 files changed, 158 insertions(+), 46 deletions(-) create mode 100644 src/router/stats.rs diff --git a/src/client/client.rs b/src/client/client.rs index 29c74cf..196b833 100644 --- a/src/client/client.rs +++ b/src/client/client.rs @@ -1,13 +1,12 @@ use std::collections::HashMap; -use crate::config::{Config, ModelConfig, ModelId, RoutingMode}; +use crate::config::{Config, ModelId}; use crate::provider::provider; use crate::router::router; pub struct Client { providers: HashMap>, router: Box, - router_stats: router::RouterStats, } impl Client { @@ -24,12 +23,11 @@ impl Client { Self { providers: providers, router: router::construct_router(cfg.routing_mode, cfg.models), - router_stats: router::RouterStats::default(), } } pub async fn create_response( - &self, + &mut self, request: provider::ResponseRequest, ) -> Result { let model_id = self.router.sample(&request); @@ -41,6 +39,8 @@ impl Client { #[cfg(test)] mod tests { use super::*; + use crate::config::{Config, ModelConfig, RoutingMode}; + #[test] fn test_client_new() { struct TestCase { @@ -64,9 +64,9 @@ mod tests { expected_router_name: "RandomRouter", }, TestCase { - name: "weighted router", + name: "weighted round-robin router", config: Config::builder() - .routing_mode(RoutingMode::Weighted) + .routing_mode(RoutingMode::WRR) .models(vec![ crate::config::ModelConfig::builder() .id("model_a".to_string()) diff --git a/src/config.rs b/src/config.rs index d09fb61..521e3e4 100644 --- a/src/config.rs +++ b/src/config.rs @@ -23,7 +23,7 @@ lazy_static! { #[derive(Debug, Clone, PartialEq)] pub enum RoutingMode { Random, - Weighted, + WRR, // WeightedRoundRobin, } // ------------------ Model Config ------------------ @@ -131,7 +131,7 @@ impl ConfigBuilder { for model in self.models.as_ref().unwrap() { if self.routing_mode.is_some() - && self.routing_mode.as_ref().unwrap() == &RoutingMode::Weighted + && self.routing_mode.as_ref().unwrap() == &RoutingMode::WRR && model.weight <= 0 { return Err(format!( diff --git a/src/lib.rs b/src/lib.rs index 6d47760..3daf995 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,7 @@ mod router { mod random; pub mod router; + pub mod stats; mod weight; } mod client { diff --git a/src/router/random.rs b/src/router/random.rs index 01d7d7a..c774d50 100644 --- a/src/router/random.rs +++ b/src/router/random.rs @@ -2,15 +2,15 @@ use rand::Rng; use crate::config::ModelId; use crate::provider::provider::ResponseRequest; -use crate::router::router::Router; +use crate::router::router::{ModelInfo, Router}; pub struct RandomRouter { - pub model_ids: Vec, + pub model_infos: Vec, } impl RandomRouter { - pub fn new(model_ids: Vec) -> Self { - Self { model_ids } + pub fn new(model_infos: Vec) -> Self { + Self { model_infos } } } @@ -19,10 +19,10 @@ impl Router for RandomRouter { "RandomRouter" } - fn sample(&self, _input: &ResponseRequest) -> ModelId { + fn sample(&mut self, _input: &ResponseRequest) -> ModelId { let mut rng = rand::rng(); - let idx = rng.random_range(0..self.model_ids.len()); - self.model_ids[idx].clone() + let idx = rng.random_range(0..self.model_infos.len()); + self.model_infos[idx].id.clone() } } @@ -32,14 +32,28 @@ mod tests { #[test] fn test_random_router_sampling() { - let model_ids = vec!["model_a".to_string(), "model_b".to_string()]; - let router = RandomRouter::new(model_ids.clone()); + let model_infos = vec![ + ModelInfo { + id: "model_x".to_string(), + weight: 1, + }, + ModelInfo { + id: "model_y".to_string(), + weight: 2, + }, + ModelInfo { + id: "model_z".to_string(), + weight: 3, + }, + ]; + let mut router = RandomRouter::new(model_infos.clone()); let mut counts = std::collections::HashMap::new(); + for _ in 0..1000 { let sampled_id = router.sample(&ResponseRequest::default()); *counts.entry(sampled_id.clone()).or_insert(0) += 1; } - assert!(counts.len() == model_ids.len()); + assert!(counts.len() == model_infos.len()); for count in counts.values() { assert!(*count > 0); } diff --git a/src/router/router.rs b/src/router/router.rs index 74f8848..35eb624 100644 --- a/src/router/router.rs +++ b/src/router/router.rs @@ -4,31 +4,31 @@ use std::sync::atomic::AtomicUsize; use crate::config::{ModelConfig, ModelId, RoutingMode}; use crate::provider::provider::ResponseRequest; use crate::router::random::RandomRouter; -use crate::router::weight::WeightedRouter; +use crate::router::weight::WeightedRoundRobinRouter; + +#[derive(Debug, Clone)] +pub struct ModelInfo { + pub id: ModelId, + pub weight: i32, +} pub fn construct_router(mode: RoutingMode, models: Vec) -> Box { - let model_ids: Vec = models.iter().map(|m| m.id.clone()).collect(); + let model_infos: Vec = models + .iter() + .map(|m| ModelInfo { + id: m.id.clone(), + weight: m.weight.clone(), + }) + .collect(); match mode { - RoutingMode::Random => Box::new(RandomRouter::new(model_ids)), - RoutingMode::Weighted => Box::new(WeightedRouter::new(model_ids)), + RoutingMode::Random => Box::new(RandomRouter::new(model_infos)), + RoutingMode::WRR => Box::new(WeightedRoundRobinRouter::new(model_infos)), } } pub trait Router { fn name(&self) -> &'static str; - fn sample(&self, input: &ResponseRequest) -> ModelId; -} - -pub struct RouterStats { - total_requests: HashMap, -} - -impl RouterStats { - pub fn default() -> Self { - RouterStats { - total_requests: HashMap::new(), - } - } + fn sample(&mut self, input: &ResponseRequest) -> ModelId; } #[cfg(test)] @@ -54,7 +54,7 @@ mod tests { let random_router = construct_router(RoutingMode::Random, model_configs.clone()); assert_eq!(random_router.name(), "RandomRouter"); - let weighted_router = construct_router(RoutingMode::Weighted, model_configs.clone()); + let weighted_router = construct_router(RoutingMode::WRR, model_configs.clone()); assert_eq!(weighted_router.name(), "WeightedRouter"); } } diff --git a/src/router/stats.rs b/src/router/stats.rs new file mode 100644 index 0000000..7b4251a --- /dev/null +++ b/src/router/stats.rs @@ -0,0 +1,24 @@ +use std::collections::HashMap; +use std::sync::atomic::{AtomicUsize, Ordering}; + +use crate::config::ModelId; + +pub struct RouterStats { + requests_per_model: HashMap, +} + +impl RouterStats { + pub fn default() -> Self { + RouterStats { + requests_per_model: HashMap::new(), + } + } + + pub fn increment_request(&mut self, model_id: &ModelId) -> usize { + let counter = self + .requests_per_model + .entry(model_id.clone()) + .or_insert_with(|| AtomicUsize::new(0)); + counter.fetch_add(1, Ordering::Relaxed) + } +} diff --git a/src/router/weight.rs b/src/router/weight.rs index 5f1ed1a..595f6ba 100644 --- a/src/router/weight.rs +++ b/src/router/weight.rs @@ -1,23 +1,96 @@ -use super::router::Router; +use crate::router::router::{ModelInfo, Router}; use crate::{config::ModelId, provider::provider::ResponseRequest}; -pub struct WeightedRouter { - pub model_ids: Vec, +pub struct WeightedRoundRobinRouter { + total_weight: i32, + model_infos: Vec, + // current_weight is ordered by model_infos index. + current_weights: Vec, } -impl WeightedRouter { - pub fn new(model_ids: Vec) -> Self { - Self { model_ids } +impl WeightedRoundRobinRouter { + pub fn new(model_infos: Vec) -> Self { + let total_weight = model_infos.iter().map(|m| m.weight).sum(); + let length = model_infos.len(); + + Self { + model_infos: model_infos, + total_weight: total_weight, + current_weights: vec![0; length], + } } } -impl Router for WeightedRouter { +impl Router for WeightedRoundRobinRouter { fn name(&self) -> &'static str { "WeightedRouter" } - fn sample(&self, _input: &ResponseRequest) -> ModelId { - // TODO: Implement weighted sampling logic - return self.model_ids[0].clone(); + // Use Smooth Weighted Round Robin Algorithm. + fn sample(&mut self, _input: &ResponseRequest) -> ModelId { + // return early if only one model. + if self.model_infos.len() == 1 { + return self.model_infos[0].id.clone(); + } + + self.current_weights + .iter_mut() + .enumerate() + .for_each(|(i, weight)| { + *weight += self.model_infos[i].weight; + }); + + let mut max_index = 0; + for i in 1..self.current_weights.len() { + if self.current_weights[i] > self.current_weights[max_index] { + max_index = i; + } + } + + self.current_weights[max_index] -= self.total_weight; + self.model_infos[max_index].id.clone() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::collections::HashMap; + + #[test] + fn test_weighted_round_robin_sampling() { + let model_infos = vec![ + ModelInfo { + id: "model_x".to_string(), + weight: 1, + }, + ModelInfo { + id: "model_y".to_string(), + weight: 3, + }, + ModelInfo { + id: "model_z".to_string(), + weight: 6, + }, + ]; + let mut wrr = WeightedRoundRobinRouter::new(model_infos.clone()); + let mut counts = HashMap::new(); + for _ in 0..1000 { + let sampled_id = wrr.sample(&ResponseRequest::default()); + *counts.entry(sampled_id.clone()).or_insert(0) += 1; + } + assert!(counts.len() == model_infos.len()); + // Check approximate distribution. + let total_counts: usize = counts.values().sum(); + assert!(total_counts == 1000); + let model_x_counts = *counts.get("model_x").unwrap_or(&0); + let model_y_counts = *counts.get("model_y").unwrap_or(&0); + let model_z_counts = *counts.get("model_z").unwrap_or(&0); + let model_x_ratio = model_x_counts as f64 / total_counts as f64; + let model_y_ratio = model_y_counts as f64 / total_counts as f64; + let model_z_ratio = model_z_counts as f64 / total_counts as f64; + assert!((model_x_ratio - 0.1).abs() < 0.05); + assert!((model_y_ratio - 0.3).abs() < 0.05); + assert!((model_z_ratio - 0.6).abs() < 0.05); } } From 2b5e897e86005569710ce21c07262583e8c6e84e Mon Sep 17 00:00:00 2001 From: kerthcet Date: Tue, 23 Dec 2025 00:02:10 +0800 Subject: [PATCH 3/3] Rename weightedRouter to wrrRouter Signed-off-by: kerthcet --- src/client/client.rs | 2 +- src/lib.rs | 2 +- src/router/router.rs | 4 ++-- src/router/{weight.rs => wrr.rs} | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) rename src/router/{weight.rs => wrr.rs} (98%) diff --git a/src/client/client.rs b/src/client/client.rs index 196b833..a7b9123 100644 --- a/src/client/client.rs +++ b/src/client/client.rs @@ -85,7 +85,7 @@ mod tests { ]) .build() .unwrap(), - expected_router_name: "WeightedRouter", + expected_router_name: "WeightedRoundRobinRouter", }, TestCase { name: "router tracker enabled", diff --git a/src/lib.rs b/src/lib.rs index 3daf995..3f0f656 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,7 +2,7 @@ mod router { mod random; pub mod router; pub mod stats; - mod weight; + mod wrr; } mod client { pub mod client; diff --git a/src/router/router.rs b/src/router/router.rs index 35eb624..ecec90b 100644 --- a/src/router/router.rs +++ b/src/router/router.rs @@ -4,7 +4,7 @@ use std::sync::atomic::AtomicUsize; use crate::config::{ModelConfig, ModelId, RoutingMode}; use crate::provider::provider::ResponseRequest; use crate::router::random::RandomRouter; -use crate::router::weight::WeightedRoundRobinRouter; +use crate::router::wrr::WeightedRoundRobinRouter; #[derive(Debug, Clone)] pub struct ModelInfo { @@ -55,6 +55,6 @@ mod tests { assert_eq!(random_router.name(), "RandomRouter"); let weighted_router = construct_router(RoutingMode::WRR, model_configs.clone()); - assert_eq!(weighted_router.name(), "WeightedRouter"); + assert_eq!(weighted_router.name(), "WeightedRoundRobinRouter"); } } diff --git a/src/router/weight.rs b/src/router/wrr.rs similarity index 98% rename from src/router/weight.rs rename to src/router/wrr.rs index 595f6ba..b1ea9c7 100644 --- a/src/router/weight.rs +++ b/src/router/wrr.rs @@ -23,7 +23,7 @@ impl WeightedRoundRobinRouter { impl Router for WeightedRoundRobinRouter { fn name(&self) -> &'static str { - "WeightedRouter" + "WeightedRoundRobinRouter" } // Use Smooth Weighted Round Robin Algorithm.