Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

33 changes: 20 additions & 13 deletions src/client/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,28 @@ pub struct Client {
}

impl Client {
pub fn new(config: &Config) -> Self {
pub fn new(config: Config) -> Self {
let mut cfg = config;
cfg.finalize().expect("Invalid configuration");

let providers = cfg
.models
.iter()
.map(|m| {
let provider = m
.provider
.as_ref()
.expect("Model provider must be specified");

(m.id.clone(), provider::build_provider(provider, m))
})
.collect();

Self {
config: config.clone(),
config: cfg.clone(),
router_tracker: None,
providers: HashMap::new(),
router: router::build_router(&config.routing_mode, &config.models),
providers: providers,
router: router::build_router(cfg.routing_mode, cfg.models),
}
}

Expand All @@ -43,15 +59,6 @@ impl Client {
}
}

pub fn build(&mut self) {
self.config.models.iter().for_each(|m| {
self.providers.insert(
m.id.clone(),
provider::build_provider(&m.provider.as_ref().unwrap(), m),
);
});
}

pub async fn create_response(
&self,
request: provider::ResponseRequest,
Expand Down
131 changes: 57 additions & 74 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ pub struct ModelConfig {
}

impl ModelConfig {
pub fn new(id: ModelId) -> Self {
fn new(id: ModelId) -> Self {
Self {
base_url: None,
provider: None,
Expand All @@ -53,47 +53,43 @@ impl ModelConfig {
}
}

pub fn base_url(mut self, url: &str) -> Self {
pub fn with_base_url(mut self, url: &str) -> Self {
self.base_url = Some(url.to_string());
self
}

pub fn provider(mut self, provider: &str) -> Self {
pub fn with_provider(mut self, provider: &str) -> Self {
self.provider = Some(provider.to_string());
self
}

pub fn temperature(mut self, temperature: f32) -> Self {
pub fn with_temperature(mut self, temperature: f32) -> Self {
self.temperature = Some(temperature);
self
}

pub fn max_output_tokens(mut self, max_output_tokens: usize) -> Self {
pub fn with_max_output_tokens(mut self, max_output_tokens: usize) -> Self {
self.max_output_tokens = Some(max_output_tokens);
self
}

pub fn weight(mut self, weight: i32) -> Self {
pub fn with_weight(mut self, weight: i32) -> Self {
self.weight = weight;
self
}

pub fn build(self) -> Result<Self, String> {
Ok(self)
}
}

// ------------------ Main Config ------------------
#[derive(Debug, Clone)]
pub struct Config {
// global configs for models, will be overridden by model-specific configs
pub base_url: Option<String>,
pub provider: ProviderName, // "AMRS" by default
pub temperature: f32, // 0.8 by default
pub max_output_tokens: usize, // 1024 by default
pub(crate) base_url: Option<String>,
pub(crate) provider: ProviderName, // "AMRS" by default
pub(crate) temperature: f32, // 0.8 by default
pub(crate) max_output_tokens: usize, // 1024 by default

pub routing_mode: RoutingMode, // Random by default
pub models: Vec<ModelConfig>,
pub(crate) routing_mode: RoutingMode, // Random by default
pub(crate) models: Vec<ModelConfig>,
}

impl Default for Config {
Expand All @@ -110,48 +106,43 @@ impl Default for Config {
}

impl Config {
pub fn new() -> Self {
let cfg = Config::default();
cfg
}

pub fn base_url(mut self, url: &str) -> Self {
pub fn with_base_url(mut self, url: &str) -> Self {
self.base_url = Some(url.to_string());
self
}

pub fn provider(mut self, provider: &str) -> Self {
pub fn with_provider(mut self, provider: &str) -> Self {
self.provider = provider.to_string();
self
}

pub fn temperature(mut self, temperature: f32) -> Self {
pub fn with_temperature(mut self, temperature: f32) -> Self {
self.temperature = temperature;
self
}

pub fn max_output_tokens(mut self, max_output_tokens: usize) -> Self {
pub fn with_max_output_tokens(mut self, max_output_tokens: usize) -> Self {
self.max_output_tokens = max_output_tokens;
self
}

pub fn routing_mode(mut self, mode: RoutingMode) -> Self {
pub fn with_routing_mode(mut self, mode: RoutingMode) -> Self {
self.routing_mode = mode;
self
}

pub fn add_model(mut self, model: ModelConfig) -> Self {
pub fn with_model(mut self, model: ModelConfig) -> Self {
self.models.push(model);
self
}

pub fn build(mut self) -> Result<Self, String> {
self.set_defaults();
self.validate()?;
Ok(self)
pub fn finalize(&mut self) -> Result<(), String> {
self.post_defaults();
self.validate_model_config()?;
Ok(())
}

fn set_defaults(&mut self) {
fn post_defaults(&mut self) {
for model in &mut self.models {
let model_url_exist = model.base_url.is_some();

Expand All @@ -177,11 +168,6 @@ impl Config {
}
}

fn validate(&self) -> Result<(), String> {
self.validate_model_config()?;
Ok(())
}

fn validate_model_config(&self) -> Result<(), String> {
if self.models.is_empty() {
return Err("At least one model must be configured.".to_string());
Expand Down Expand Up @@ -244,61 +230,58 @@ mod tests {
from_filename(".env.test").ok();

// case 1:
let valid_simplest_models_cfg = Config::new()
.provider("OPENAI")
.add_model(ModelConfig::new("gpt-4".to_string()).build().unwrap());
let res = valid_simplest_models_cfg.build();
let mut valid_simplest_models_cfg = Config::default()
.with_provider("OPENAI")
.with_model(ModelConfig::new("gpt-4".to_string()));
let res = valid_simplest_models_cfg.finalize();
assert!(res.clone().is_ok());
assert!(
res.clone().unwrap().models[0].base_url
valid_simplest_models_cfg.models[0].base_url
== Some("https://api.openai.com/v1".to_string())
);
assert!(res.clone().unwrap().models[0].provider == Some("OPENAI".to_string()));
assert!(res.clone().unwrap().models[0].temperature == Some(0.8));
assert!(res.clone().unwrap().models[0].max_output_tokens == Some(1024));
assert!(res.clone().unwrap().models[0].weight == -1);
assert!(valid_simplest_models_cfg.models[0].provider == Some("OPENAI".to_string()));
assert!(valid_simplest_models_cfg.models[0].temperature == Some(0.8));
assert!(valid_simplest_models_cfg.models[0].max_output_tokens == Some(1024));
assert!(valid_simplest_models_cfg.models[0].weight == -1);

// case 2:
let valid_cfg = Config::new()
.provider("OPENAI")
.add_model(
ModelConfig::new("gpt-3.5-turbo".to_string())
.build()
.unwrap(),
)
.add_model(ModelConfig::new("gpt-4".to_string()).build().unwrap());
assert!(valid_cfg.build().is_ok());
let mut valid_cfg = Config::default()
.with_provider("OPENAI")
.with_model(ModelConfig::new("gpt-3.5-turbo".to_string()))
.with_model(ModelConfig::new("gpt-4".to_string()));
assert!(valid_cfg.finalize().is_ok());

// case 3:
let invalid_cfg_with_no_api_key = Config::new()
.provider("unknown_provider")
.add_model(ModelConfig::new("some-model".to_string()).build().unwrap());
assert!(invalid_cfg_with_no_api_key.build().is_err());
let mut invalid_cfg_with_no_api_key = Config::default()
.with_provider("unknown_provider")
.with_model(ModelConfig::new("some-model".to_string()));
assert!(invalid_cfg_with_no_api_key.finalize().is_err());

// case 4:
let valid_cfg_with_customized_provider = Config::new()
.base_url("http://example.ai")
.max_output_tokens(2048)
.add_model(
ModelConfig::new("custom-model".to_string())
.provider("FOO")
.build()
.unwrap(),
);
let res = valid_cfg_with_customized_provider.build();
let mut valid_cfg_with_customized_provider = Config::default()
.with_base_url("http://example.ai")
.with_max_output_tokens(2048)
.with_model(ModelConfig::new("custom-model".to_string()).with_provider("FOO"));
let res = valid_cfg_with_customized_provider.finalize();
assert!(res.is_ok());
assert_eq!(
res.clone().unwrap().models[0].base_url.as_ref().unwrap(),
valid_cfg_with_customized_provider.models[0]
.base_url
.as_ref()
.unwrap(),
"http://example.ai"
);
assert_eq!(
res.clone().unwrap().models[0].provider,
valid_cfg_with_customized_provider.models[0].provider,
Some("FOO".to_string())
);
assert_eq!(res.unwrap().models[0].max_output_tokens, Some(2048));
assert_eq!(
valid_cfg_with_customized_provider.models[0].max_output_tokens,
Some(2048)
);

// case 5:
let invalid_empty_models_cfg = Config::new().provider("OPENAI");
assert!(invalid_empty_models_cfg.build().is_err());
let mut invalid_empty_models_cfg = Config::default().with_provider("OPENAI");
assert!(invalid_empty_models_cfg.finalize().is_err());
}
}
6 changes: 2 additions & 4 deletions src/router/random.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,8 @@ pub struct RandomRouter {
}

impl RandomRouter {
pub fn new(model_ids: &[ModelId]) -> Self {
Self {
model_ids: model_ids.to_vec(),
}
pub fn new(model_ids: Vec<ModelId>) -> Self {
Self { model_ids }
}
}

Expand Down
6 changes: 3 additions & 3 deletions src/router/router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ use crate::provider::provider::ResponseRequest;
use crate::router::random::RandomRouter;
use crate::router::weight::WeightedRouter;

pub fn build_router(mode: &RoutingMode, models: &[ModelConfig]) -> Box<dyn Router> {
pub fn build_router(mode: RoutingMode, models: Vec<ModelConfig>) -> Box<dyn Router> {
let model_ids: Vec<ModelId> = models.iter().map(|m| m.id.clone()).collect();
match mode {
RoutingMode::Random => Box::new(RandomRouter::new(&model_ids)),
RoutingMode::Weighted => Box::new(WeightedRouter::new(&model_ids)),
RoutingMode::Random => Box::new(RandomRouter::new(model_ids)),
RoutingMode::Weighted => Box::new(WeightedRouter::new(model_ids)),
}
}

Expand Down
6 changes: 2 additions & 4 deletions src/router/weight.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,8 @@ pub struct WeightedRouter {
}

impl WeightedRouter {
pub fn new(model_ids: &[ModelId]) -> Self {
Self {
model_ids: model_ids.to_vec(),
}
pub fn new(model_ids: Vec<ModelId>) -> Self {
Self { model_ids }
}
}

Expand Down
Loading