From 06518315d11aad207e393e2f9c33bbda0ed6119e Mon Sep 17 00:00:00 2001 From: Jonas Krauss Date: Tue, 17 Jun 2025 14:22:49 +0200 Subject: [PATCH 1/8] wip on adding model routing --- src/agent/monitoring_service.rs | 34 +++++++++++---- src/balancer/proxy_service.rs | 32 ++++++++++++++ src/balancer/request_context.rs | 4 +- src/balancer/status_update.rs | 3 ++ src/balancer/test/mock_status_update.rs | 1 + src/balancer/upstream_peer.rs | 10 +++++ src/balancer/upstream_peer_pool.rs | 14 +++++-- src/cmd/agent.rs | 4 +- src/cmd/balancer.rs | 5 +++ src/llamacpp/llamacpp_client.rs | 55 ++++++++++++++++++++++--- src/llamacpp/mod.rs | 1 + src/llamacpp/models_response.rs | 13 ++++++ src/main.rs | 12 ++++++ 13 files changed, 167 insertions(+), 21 deletions(-) create mode 100644 src/llamacpp/models_response.rs diff --git a/src/agent/monitoring_service.rs b/src/agent/monitoring_service.rs index 598e0ca8..bcc259a0 100644 --- a/src/agent/monitoring_service.rs +++ b/src/agent/monitoring_service.rs @@ -3,6 +3,7 @@ use std::net::SocketAddr; use actix_web::web::Bytes; use async_trait::async_trait; use log::debug; +use log::info; use log::error; #[cfg(unix)] use pingora::server::ListenFds; @@ -23,6 +24,7 @@ pub struct MonitoringService { monitoring_interval: Duration, name: Option, status_update_tx: Sender, + check_model: bool, // Store the check_model flag } impl MonitoringService { @@ -32,6 +34,7 @@ impl MonitoringService { monitoring_interval: Duration, name: Option, status_update_tx: Sender, + check_model: bool, // Include the check_model flag ) -> Result { Ok(MonitoringService { external_llamacpp_addr, @@ -39,19 +42,31 @@ impl MonitoringService { monitoring_interval, name, status_update_tx, + check_model, }) } async fn fetch_status(&self) -> Result { match self.llamacpp_client.get_available_slots().await { - Ok(slots_response) => Ok(StatusUpdate::new( - self.name.to_owned(), - None, - self.external_llamacpp_addr.to_owned(), - slots_response.is_authorized, - slots_response.is_slot_endpoint_enabled, - slots_response.slots, - )), + Ok(slots_response) => { + let model = if self.check_model { + self.llamacpp_client.get_model().await? + } else { + None + }; + + info!("Agent: {:?} Model: {:?}", self.name, model); + + Ok(StatusUpdate::new( + self.name.to_owned(), + None, + self.external_llamacpp_addr.to_owned(), + slots_response.is_authorized, + slots_response.is_slot_endpoint_enabled, + slots_response.slots, + model, + )) + }, Err(err) => Ok(StatusUpdate::new( self.name.to_owned(), Some(err.to_string()), @@ -59,6 +74,7 @@ impl MonitoringService { None, None, vec![], + None, )), } } @@ -111,4 +127,4 @@ impl Service for MonitoringService { fn threads(&self) -> Option { Some(1) } -} +} \ No newline at end of file diff --git a/src/balancer/proxy_service.rs b/src/balancer/proxy_service.rs index 41f11e77..8ec1157c 100644 --- a/src/balancer/proxy_service.rs +++ b/src/balancer/proxy_service.rs @@ -4,6 +4,7 @@ use std::time::Duration; use async_trait::async_trait; use bytes::Bytes; use log::error; +use log::info; use pingora::http::RequestHeader; use pingora::proxy::ProxyHttp; use pingora::proxy::Session; @@ -16,6 +17,7 @@ use crate::balancer::upstream_peer_pool::UpstreamPeerPool; pub struct ProxyService { rewrite_host_header: bool, + check_model: bool, slots_endpoint_enable: bool, upstream_peer_pool: Arc, } @@ -23,11 +25,13 @@ pub struct ProxyService { impl ProxyService { pub fn new( rewrite_host_header: bool, + check_model: bool, slots_endpoint_enable: bool, upstream_peer_pool: Arc, ) -> Self { Self { rewrite_host_header, + check_model, slots_endpoint_enable, upstream_peer_pool, } @@ -44,6 +48,7 @@ impl ProxyHttp for ProxyService { slot_taken: false, upstream_peer_pool: self.upstream_peer_pool.clone(), uses_slots: false, + requested_model: Some("".to_string()), } } @@ -135,6 +140,33 @@ impl ProxyHttp for ProxyService { session: &mut Session, ctx: &mut Self::CTX, ) -> Result> { + info!("upstream_peer - {:?} request | rewrite_host_header? {} check_model? {}", session.req_header().method, self.rewrite_host_header, self.check_model); + + // Check if the request method is POST and the content type is JSON + if self.check_model { + if session.req_header().method == "POST" { + // Check if the content type is application/json + if let Some(content_type) = session.get_header("Content-Type") { + if let Ok(content_type_str) = content_type.to_str() { + if content_type_str.contains("application/json") { + // Read the request body + let body = session.read_request_body().await?; + if let Some(body) = body { + // Parse the JSON payload into a serde_json::Value + if let Ok(json_value) = serde_json::from_slice::(&body) { + if let Some(model) = json_value.get("model").and_then(|v| v.as_str()) { + // Set the requested_model field in the RequestContext + ctx.requested_model = Some(model.to_string()); + info!("Model in request: {:?}", ctx.requested_model); + } + } + } + } + } + } + } + } + let upstream_peer = ctx.select_upstream_peer(session.req_header().uri.path(), self.slots_endpoint_enable); diff --git a/src/balancer/request_context.rs b/src/balancer/request_context.rs index 45acb197..e64a04b1 100644 --- a/src/balancer/request_context.rs +++ b/src/balancer/request_context.rs @@ -15,6 +15,7 @@ pub struct RequestContext { pub selected_peer: Option, pub upstream_peer_pool: Arc, pub uses_slots: bool, + pub requested_model: Option, } impl RequestContext { @@ -51,7 +52,7 @@ impl RequestContext { slots_endpoint_enable: bool, ) -> Result> { if self.selected_peer.is_none() { - self.selected_peer = match self.upstream_peer_pool.use_best_peer() { + self.selected_peer = match self.upstream_peer_pool.use_best_peer(self.requested_model.clone()) { Ok(peer) => peer, Err(err) => { error!("Failed to get best peer: {err}"); @@ -115,6 +116,7 @@ mod tests { selected_peer: None, upstream_peer_pool, uses_slots: true, + requested_model: Some("llama3".to_string()), } } diff --git a/src/balancer/status_update.rs b/src/balancer/status_update.rs index 46da8fa7..ab3903dd 100644 --- a/src/balancer/status_update.rs +++ b/src/balancer/status_update.rs @@ -14,6 +14,7 @@ pub struct StatusUpdate { pub is_authorized: Option, pub is_slots_endpoint_enabled: Option, pub processing_slots_count: usize, + pub model: Option, slots: Vec, } @@ -25,6 +26,7 @@ impl StatusUpdate { is_authorized: Option, is_slots_endpoint_enabled: Option, slots: Vec, + model: Option, ) -> Self { let idle_slots_count = slots.iter().filter(|slot| !slot.is_processing).count(); @@ -37,6 +39,7 @@ impl StatusUpdate { is_slots_endpoint_enabled, processing_slots_count: slots.len() - idle_slots_count, slots, + model, } } } diff --git a/src/balancer/test/mock_status_update.rs b/src/balancer/test/mock_status_update.rs index b87e5744..d1b53484 100644 --- a/src/balancer/test/mock_status_update.rs +++ b/src/balancer/test/mock_status_update.rs @@ -38,5 +38,6 @@ pub fn mock_status_update( Some(true), Some(true), slots, + Some("llama3".to_string()), ) } diff --git a/src/balancer/upstream_peer.rs b/src/balancer/upstream_peer.rs index b5c0cd0e..35b15e42 100644 --- a/src/balancer/upstream_peer.rs +++ b/src/balancer/upstream_peer.rs @@ -14,6 +14,7 @@ use crate::errors::result::Result; pub struct UpstreamPeer { pub agent_id: String, pub agent_name: Option, + pub model: Option, pub error: Option, pub external_llamacpp_addr: SocketAddr, /// None means undetermined, probably due to an error @@ -39,6 +40,7 @@ impl UpstreamPeer { is_slots_endpoint_enabled: Option, slots_idle: usize, slots_processing: usize, + model: Option, ) -> Self { UpstreamPeer { agent_id, @@ -53,6 +55,7 @@ impl UpstreamPeer { slots_processing, slots_taken: 0, slots_taken_since_last_status_update: 0, + model, } } @@ -66,6 +69,7 @@ impl UpstreamPeer { status_update.is_slots_endpoint_enabled, status_update.idle_slots_count, status_update.processing_slots_count, + status_update.model, ) } @@ -76,6 +80,10 @@ impl UpstreamPeer { && matches!(self.is_authorized, Some(true)) } + pub fn is_usable_for_model(&self, requested_model: &str) -> bool { + self.is_usable() && (requested_model.is_empty() || self.model.as_deref() == Some(requested_model)) + } + pub fn release_slot(&mut self) -> Result<()> { if self.slots_taken < 1 { return Err("Cannot release a slot when there are no taken slots".into()); @@ -166,6 +174,7 @@ mod tests { Some(true), 5, 0, + Some("llama3".to_string()), ) } @@ -219,6 +228,7 @@ mod tests { Some(true), Some(true), vec![], + Some("llama3".to_string()), ); peer.update_status(status_update); diff --git a/src/balancer/upstream_peer_pool.rs b/src/balancer/upstream_peer_pool.rs index 4aa4beb3..6d852cc2 100644 --- a/src/balancer/upstream_peer_pool.rs +++ b/src/balancer/upstream_peer_pool.rs @@ -1,6 +1,7 @@ use std::sync::RwLock; use std::time::Duration; use std::time::SystemTime; +use log::info; use serde::Deserialize; use serde::Serialize; @@ -117,14 +118,19 @@ impl UpstreamPeerPool { }) } - pub fn use_best_peer(&self) -> Result> { + pub fn use_best_peer(&self, model: Option) -> Result> { self.with_agents_write(|agents| { for peer in agents.iter() { - if peer.is_usable() { + let model_str = model.as_deref().unwrap_or(""); + let is_usable = peer.is_usable(); + let is_usable_for_model = peer.is_usable_for_model(model_str); + + info!("Peer {} is usable: {}, usable for model '{}': {}", peer.agent_id, is_usable, model_str, is_usable_for_model); + + if is_usable && (model.is_none() || is_usable_for_model) { return Ok(Some(peer.clone())); } } - Ok(None) }) } @@ -227,7 +233,7 @@ mod tests { pool.register_status_update("test2", mock_status_update("test2", 3, 0))?; pool.register_status_update("test3", mock_status_update("test3", 0, 0))?; - let best_peer = pool.use_best_peer()?.unwrap(); + let best_peer = pool.use_best_peer(None)?.unwrap(); assert_eq!(best_peer.agent_id, "test1"); assert_eq!(best_peer.slots_idle, 5); diff --git a/src/cmd/agent.rs b/src/cmd/agent.rs index b6eedb87..b9c5dd38 100644 --- a/src/cmd/agent.rs +++ b/src/cmd/agent.rs @@ -18,6 +18,7 @@ pub fn handle( management_addr: SocketAddr, monitoring_interval: Duration, name: Option, + check_model: bool, // Include the check_model flag ) -> Result<()> { let (status_update_tx, _status_update_rx) = channel::(1); @@ -29,6 +30,7 @@ pub fn handle( monitoring_interval, name, status_update_tx.clone(), + check_model, // Pass the check_model flag )?; let reporting_service = ReportingService::new(management_addr, status_update_tx)?; @@ -45,4 +47,4 @@ pub fn handle( pingora_server.add_service(monitoring_service); pingora_server.add_service(reporting_service); pingora_server.run_forever(); -} +} \ No newline at end of file diff --git a/src/cmd/balancer.rs b/src/cmd/balancer.rs index 695778db..1365b9b5 100644 --- a/src/cmd/balancer.rs +++ b/src/cmd/balancer.rs @@ -2,6 +2,7 @@ use std::net::SocketAddr; use std::sync::Arc; #[cfg(feature = "statsd_reporter")] use std::time::Duration; +use log::info; use pingora::proxy::http_proxy_service; use pingora::server::configuration::Opt; @@ -19,6 +20,7 @@ pub fn handle( #[cfg(feature = "web_dashboard")] management_dashboard_enable: bool, reverseproxy_addr: &SocketAddr, rewrite_host_header: bool, + check_model: bool, slots_endpoint_enable: bool, #[cfg(feature = "statsd_reporter")] statsd_addr: Option, #[cfg(feature = "statsd_reporter")] statsd_prefix: String, @@ -40,6 +42,7 @@ pub fn handle( &pingora_server.configuration, ProxyService::new( rewrite_host_header, + check_model, slots_endpoint_enable, upstream_peer_pool.clone(), ), @@ -67,5 +70,7 @@ pub fn handle( pingora_server.add_service(statsd_service); } + info!("rewrite_host_header? {} check_model? {} slots_endpoint_enable? {}", rewrite_host_header, check_model, slots_endpoint_enable); + pingora_server.run_forever(); } diff --git a/src/llamacpp/llamacpp_client.rs b/src/llamacpp/llamacpp_client.rs index 48514512..6df9e6e3 100644 --- a/src/llamacpp/llamacpp_client.rs +++ b/src/llamacpp/llamacpp_client.rs @@ -6,12 +6,14 @@ use reqwest::header; use url::Url; use crate::errors::result::Result; -use crate::llamacpp::slot::Slot; use crate::llamacpp::slots_response::SlotsResponse; +use crate::llamacpp::slot::Slot; +use crate::llamacpp::models_response::ModelsResponse; pub struct LlamacppClient { client: reqwest::Client, slots_endpoint_url: String, + models_endpoint_url: String, } impl LlamacppClient { @@ -37,6 +39,7 @@ impl LlamacppClient { Ok(Self { client: builder.build()?, slots_endpoint_url: Url::parse(&format!("http://{addr}/slots"))?.to_string(), + models_endpoint_url: Url::parse(&format!("http://{addr}/v1/models"))?.to_string(), }) } @@ -61,11 +64,14 @@ impl LlamacppClient { }; match response.status() { - reqwest::StatusCode::OK => Ok(SlotsResponse { - is_authorized: Some(true), - is_slot_endpoint_enabled: Some(true), - slots: response.json::>().await?, - }), + reqwest::StatusCode::OK => { + let slots: Vec = response.json().await?; + Ok(SlotsResponse { + is_authorized: Some(true), + is_slot_endpoint_enabled: Some(true), + slots, + }) + }, reqwest::StatusCode::UNAUTHORIZED => Ok(SlotsResponse { is_authorized: Some(false), is_slot_endpoint_enabled: None, @@ -79,4 +85,41 @@ impl LlamacppClient { _ => Err("Unexpected response status".into()), } } + + pub async fn get_model(&self) -> Result> { + let url = self.models_endpoint_url.to_owned(); + + let response = match self.client.get(url.clone()).send().await { + Ok(resp) => resp, + Err(err) => { + return Err(format!( + "Request to '{}' failed: '{}'; connect issue: {}; decode issue: {}; request issue: {}; status issue: {}; status: {:?}; source: {:?}", + url, + err, + err.is_connect(), + err.is_decode(), + err.is_request(), + err.is_status(), + err.status(), + err.source() + ).into()); + } + }; + + match response.status() { + reqwest::StatusCode::OK => { + let models_response: ModelsResponse = response.json().await?; + if let Some(models) = models_response.models { + if models.is_empty() { + Ok(None) + } else { + Ok(models.first().and_then(|m| Some(m.model.clone()))) + } + } else { + Ok(None) + } + }, + _ => Err("Unexpected response status".into()), + } + } } diff --git a/src/llamacpp/mod.rs b/src/llamacpp/mod.rs index 2b1e1d24..68f5cae8 100644 --- a/src/llamacpp/mod.rs +++ b/src/llamacpp/mod.rs @@ -1,3 +1,4 @@ pub mod llamacpp_client; pub mod slot; pub mod slots_response; +pub mod models_response; diff --git a/src/llamacpp/models_response.rs b/src/llamacpp/models_response.rs new file mode 100644 index 00000000..d7d04b50 --- /dev/null +++ b/src/llamacpp/models_response.rs @@ -0,0 +1,13 @@ +// paddler/src/llamacpp/models_response.rs +use serde::Deserialize; + +#[derive(Debug, Deserialize)] +pub struct ModelsResponse { + pub models: Option>, +} + +#[derive(Debug, Deserialize)] +pub struct Model { + pub model: String, + // Add other fields as needed +} diff --git a/src/main.rs b/src/main.rs index 93999f39..f27fe23d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -80,6 +80,10 @@ enum Commands { #[arg(long)] /// Name of the agent (optional) name: Option, + + #[arg(long)] + /// Flag whether to check the model served by llama.cpp and reject requests for other models + check_model: bool, }, /// Balances incoming requests to llama.cpp instances and optionally provides a web dashboard Balancer { @@ -105,6 +109,10 @@ enum Commands { /// Enable the slots endpoint (not recommended) slots_endpoint_enable: bool, + #[arg(long)] + /// Flag to check the model served by llama.cpp and reject requests for other models + check_model: bool, + #[cfg(feature = "statsd_reporter")] #[arg(long, value_parser = parse_socket_addr)] /// Address of the statsd server to report metrics to @@ -142,6 +150,7 @@ fn main() -> Result<()> { management_addr, monitoring_interval, name, + check_model, }) => cmd::agent::handle( match external_llamacpp_addr { Some(addr) => addr.to_owned(), @@ -152,6 +161,7 @@ fn main() -> Result<()> { management_addr.to_owned(), monitoring_interval.to_owned(), name.to_owned(), + *check_model ), Some(Commands::Balancer { management_addr, @@ -159,6 +169,7 @@ fn main() -> Result<()> { management_dashboard_enable, reverseproxy_addr, rewrite_host_header, + check_model, slots_endpoint_enable, #[cfg(feature = "statsd_reporter")] statsd_addr, @@ -172,6 +183,7 @@ fn main() -> Result<()> { management_dashboard_enable.to_owned(), reverseproxy_addr, rewrite_host_header.to_owned(), + *check_model, slots_endpoint_enable.to_owned(), #[cfg(feature = "statsd_reporter")] statsd_addr.to_owned(), From 84a88f8af9c2d6efcc4e61a88537e81cebce3b54 Mon Sep 17 00:00:00 2001 From: Jonas Krauss Date: Tue, 17 Jun 2025 14:30:52 +0200 Subject: [PATCH 2/8] revert unnecessary change --- src/llamacpp/llamacpp_client.rs | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/src/llamacpp/llamacpp_client.rs b/src/llamacpp/llamacpp_client.rs index 6df9e6e3..cc9e7e10 100644 --- a/src/llamacpp/llamacpp_client.rs +++ b/src/llamacpp/llamacpp_client.rs @@ -64,14 +64,11 @@ impl LlamacppClient { }; match response.status() { - reqwest::StatusCode::OK => { - let slots: Vec = response.json().await?; - Ok(SlotsResponse { - is_authorized: Some(true), - is_slot_endpoint_enabled: Some(true), - slots, - }) - }, + reqwest::StatusCode::OK => Ok(SlotsResponse { + is_authorized: Some(true), + is_slot_endpoint_enabled: Some(true), + slots: response.json::>().await?, + }), reqwest::StatusCode::UNAUTHORIZED => Ok(SlotsResponse { is_authorized: Some(false), is_slot_endpoint_enabled: None, From 4b4e3f6f6393fd010ccdefcb782b90c5c22aeeeb Mon Sep 17 00:00:00 2001 From: Jonas Krauss Date: Tue, 17 Jun 2025 16:31:07 +0200 Subject: [PATCH 3/8] avoid consuming the request body for upstream --- src/balancer/proxy_service.rs | 17 ++++++++++++----- src/balancer/upstream_peer_pool.rs | 4 ++-- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/src/balancer/proxy_service.rs b/src/balancer/proxy_service.rs index 8ec1157c..987f1c02 100644 --- a/src/balancer/proxy_service.rs +++ b/src/balancer/proxy_service.rs @@ -149,17 +149,24 @@ impl ProxyHttp for ProxyService { if let Some(content_type) = session.get_header("Content-Type") { if let Ok(content_type_str) = content_type.to_str() { if content_type_str.contains("application/json") { - // Read the request body - let body = session.read_request_body().await?; - if let Some(body) = body { - // Parse the JSON payload into a serde_json::Value - if let Ok(json_value) = serde_json::from_slice::(&body) { + // Enable retry buffering to preserve the request body, reference: https://github.com/cloudflare/pingora/issues/349#issuecomment-2377277028 + session.enable_retry_buffering(); + session.read_body_or_idle(false).await.unwrap().unwrap(); + let request_body = session.get_retry_buffer(); + + // Parse the JSON payload into a serde_json::Value + if let Some(body_bytes) = request_body { + if let Ok(json_value) = serde_json::from_slice::(&body_bytes) { if let Some(model) = json_value.get("model").and_then(|v| v.as_str()) { // Set the requested_model field in the RequestContext ctx.requested_model = Some(model.to_string()); info!("Model in request: {:?}", ctx.requested_model); } + } else { + error!("Failed to parse JSON payload"); } + } else { + error!("Request body is None"); } } } diff --git a/src/balancer/upstream_peer_pool.rs b/src/balancer/upstream_peer_pool.rs index 6d852cc2..101c1b31 100644 --- a/src/balancer/upstream_peer_pool.rs +++ b/src/balancer/upstream_peer_pool.rs @@ -125,12 +125,12 @@ impl UpstreamPeerPool { let is_usable = peer.is_usable(); let is_usable_for_model = peer.is_usable_for_model(model_str); - info!("Peer {} is usable: {}, usable for model '{}': {}", peer.agent_id, is_usable, model_str, is_usable_for_model); - if is_usable && (model.is_none() || is_usable_for_model) { + info!("Peer {} is usable: {}, usable for model '{}': {}", peer.agent_id, is_usable, model_str, is_usable_for_model); return Ok(Some(peer.clone())); } } + Ok(None) }) } From b29c875c9348e464df9bae091cd68dde4935f9eb Mon Sep 17 00:00:00 2001 From: Jonas Krauss Date: Fri, 20 Jun 2025 17:21:32 +0200 Subject: [PATCH 4/8] add responses for unsupported/missing model parameter, add regex check to handle only first 64k characters available in buffer --- Cargo.lock | 1 + Cargo.toml | 1 + src/balancer/proxy_service.rs | 103 +++++++++++++++++++++----------- src/balancer/request_context.rs | 19 +++++- 4 files changed, 88 insertions(+), 36 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 60032252..bc4ae624 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2449,6 +2449,7 @@ dependencies = [ "mime_guess", "pingora", "ratatui", + "regex", "reqwest", "rust-embed", "serde", diff --git a/Cargo.toml b/Cargo.toml index 78bd06c5..172d4c6c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -45,6 +45,7 @@ itertools = "0.13.0" color-eyre = "0.6.5" time = "0.3.41" chrono = "0.4.41" +regex = "1.11.1" [features] default = ["statsd_reporter", "ratatui_dashboard"] diff --git a/src/balancer/proxy_service.rs b/src/balancer/proxy_service.rs index 9c59f15b..9a89fb13 100644 --- a/src/balancer/proxy_service.rs +++ b/src/balancer/proxy_service.rs @@ -149,40 +149,6 @@ impl ProxyHttp for ProxyService { session: &mut Session, ctx: &mut Self::CTX, ) -> Result> { - info!("upstream_peer - {:?} request | rewrite_host_header? {} check_model? {}", session.req_header().method, self.rewrite_host_header, self.check_model); - - // Check if the request method is POST and the content type is JSON - if self.check_model { - if session.req_header().method == "POST" { - // Check if the content type is application/json - if let Some(content_type) = session.get_header("Content-Type") { - if let Ok(content_type_str) = content_type.to_str() { - if content_type_str.contains("application/json") { - // Enable retry buffering to preserve the request body, reference: https://github.com/cloudflare/pingora/issues/349#issuecomment-2377277028 - session.enable_retry_buffering(); - session.read_body_or_idle(false).await.unwrap().unwrap(); - let request_body = session.get_retry_buffer(); - - // Parse the JSON payload into a serde_json::Value - if let Some(body_bytes) = request_body { - if let Ok(json_value) = serde_json::from_slice::(&body_bytes) { - if let Some(model) = json_value.get("model").and_then(|v| v.as_str()) { - // Set the requested_model field in the RequestContext - ctx.requested_model = Some(model.to_string()); - info!("Model in request: {:?}", ctx.requested_model); - } - } else { - error!("Failed to parse JSON payload"); - } - } else { - error!("Request body is None"); - } - } - } - } - } - } - let Some(_req_guard) = RequestBufferGuard::increment( &self.upstream_peer_pool.request_buffer_length, self.max_requests, @@ -213,10 +179,79 @@ impl ProxyHttp for ProxyService { } "/chat/completions" => true, "/completion" => true, + "/v1/completions" => true, "/v1/chat/completions" => true, _ => false, }; + info!("upstream_peer - {:?} request | rewrite_host_header? {} check_model? {}", session.req_header().method, self.rewrite_host_header, self.check_model); + + // Check if the request method is POST and the content type is JSON // cnbREaxdMcQVBS + if self.check_model && ctx.uses_slots { + info!("Checking model..."); + ctx.requested_model = None; + if session.req_header().method == "POST" { + // Check if the content type is application/json + if let Some(content_type) = session.get_header("Content-Type") { + if let Ok(content_type_str) = content_type.to_str() { + if content_type_str.contains("application/json") { + // Enable retry buffering to preserve the request body, reference: https://github.com/cloudflare/pingora/issues/349#issuecomment-2377277028 + session.enable_retry_buffering(); + session.read_body_or_idle(false).await.unwrap().unwrap(); + let request_body = session.get_retry_buffer(); + + // Parse the JSON payload into a serde_json::Value + if let Some(body_bytes) = request_body { + if let Ok(json_value) = serde_json::from_slice::(&body_bytes) { + if let Some(model) = json_value.get("model").and_then(|v| v.as_str()) { + // Set the requested_model field in the RequestContext + ctx.requested_model = Some(model.to_string()); + info!("Model in request: {:?}", ctx.requested_model); + } + } else { + info!("Failed to parse JSON payload, trying regex extraction"); + + // Try extracting the model using regex + let body_str = String::from_utf8(body_bytes.to_vec()).unwrap(); + let re = regex::Regex::new(r#""model"\s*:\s*["']([^"']*)["']"#).unwrap(); + if let Some(caps) = re.captures(&body_str) { + if let Some(model) = caps.get(1) { + ctx.requested_model = Some(model.as_str().to_string()); + info!("Model via regex: {:?}", ctx.requested_model); + } + } else { + info!("Failed to extract model using regex"); + } + } + } else { + info!("Request body is None"); + } + } + } + } + } + // abort if model has not been set + if ctx.requested_model == None { + info!("Model missing in request"); + session + .respond_error(pingora::http::StatusCode::BAD_REQUEST.as_u16()) + .await?; + + return Err(Error::new_down(pingora::ErrorType::ConnectRefused)); + } + else if ctx.has_peer_supporting_model() == false { + info!("Model {:?} not supported by upstream", ctx.requested_model); + session + .respond_error(pingora::http::StatusCode::NOT_FOUND.as_u16()) + .await?; + + return Err(Error::new_down(pingora::ErrorType::ConnectRefused)); + } + else { + info!("Model {:?}", ctx.requested_model); + } + } + let peer = tokio::select! { result = async { loop { diff --git a/src/balancer/request_context.rs b/src/balancer/request_context.rs index 47abb9ed..aedcd10d 100644 --- a/src/balancer/request_context.rs +++ b/src/balancer/request_context.rs @@ -60,6 +60,21 @@ impl RequestContext { ) } + pub fn has_peer_supporting_model(&self) -> bool { + let model_str = self.requested_model.as_deref().unwrap_or(""); + match self.upstream_peer_pool.with_agents_read(|agents| { + for peer in agents.iter() { + if peer.is_usable_for_model(model_str) { + return Ok(true); + } + } + Ok(false) + }) { + Ok(result) => result, + Err(_) => false, // or handle the error as needed + } + } + pub fn select_upstream_peer(&mut self) -> Result<()> { let result_option_peer = if self.uses_slots && !self.slot_taken { self.use_best_peer_and_take_slot(self.requested_model.clone()) @@ -114,7 +129,7 @@ mod tests { pool.register_status_update("test_agent", mock_status_update("test_agent", 0, 0))?; - assert!(ctx.use_best_peer_and_take_slot().unwrap().is_none()); + assert!(ctx.use_best_peer_and_take_slot(ctx.requested_model.clone()).unwrap().is_none()); assert!(!ctx.slot_taken); assert_eq!(ctx.selected_peer, None); @@ -139,7 +154,7 @@ mod tests { "127.0.0.1:8080" ); - ctx.use_best_peer_and_take_slot()?; + ctx.use_best_peer_and_take_slot(ctx.requested_model.clone())?; assert!(ctx.slot_taken); From b328bab9c9bc6965f7b82aee1e437a080ee51469 Mon Sep 17 00:00:00 2001 From: Jonas Krauss Date: Mon, 23 Jun 2025 14:54:42 +0200 Subject: [PATCH 5/8] fix mock_status_update.rs --- src/balancer/test/mock_status_update.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/balancer/test/mock_status_update.rs b/src/balancer/test/mock_status_update.rs index 2f22e95e..7d61139f 100644 --- a/src/balancer/test/mock_status_update.rs +++ b/src/balancer/test/mock_status_update.rs @@ -46,7 +46,7 @@ pub fn mock_status_update( is_authorized: Some(true), is_slots_endpoint_enabled: Some(true), processing_slots_count: slots.len() - idle_slots_count, + model: Some("llama3".to_string()), slots: slots, - Some("llama3".to_string()), } } From c2efc634edf909f7c6c0ab40fa50de019cc6b96e Mon Sep 17 00:00:00 2001 From: Jonas Krauss Date: Fri, 27 Jun 2025 15:50:55 +0200 Subject: [PATCH 6/8] fix returning model not supported on all slots for model currently blocked --- src/balancer/proxy_service.rs | 2 +- src/balancer/request_context.rs | 2 +- src/balancer/upstream_peer.rs | 4 ++++ 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/balancer/proxy_service.rs b/src/balancer/proxy_service.rs index 2e75e6b4..eea439bd 100644 --- a/src/balancer/proxy_service.rs +++ b/src/balancer/proxy_service.rs @@ -192,7 +192,7 @@ impl ProxyHttp for ProxyService { info!("upstream_peer - {:?} request | rewrite_host_header? {} check_model? {}", session.req_header().method, self.rewrite_host_header, self.check_model); - // Check if the request method is POST and the content type is JSON // cnbREaxdMcQVBS + // Check if the request method is POST and the content type is JSON if self.check_model && ctx.uses_slots { info!("Checking model..."); ctx.requested_model = None; diff --git a/src/balancer/request_context.rs b/src/balancer/request_context.rs index 70bc042a..ef03ce48 100644 --- a/src/balancer/request_context.rs +++ b/src/balancer/request_context.rs @@ -61,7 +61,7 @@ impl RequestContext { let model_str = self.requested_model.as_deref().unwrap_or(""); match self.upstream_peer_pool.with_agents_read(|agents| { for peer in agents.iter() { - if peer.is_usable_for_model(model_str) { + if peer.supports_model(model_str) { return Ok(true); } } diff --git a/src/balancer/upstream_peer.rs b/src/balancer/upstream_peer.rs index 6a004e77..d73f6a6e 100644 --- a/src/balancer/upstream_peer.rs +++ b/src/balancer/upstream_peer.rs @@ -37,6 +37,10 @@ impl UpstreamPeer { !self.status.has_issues() && self.status.slots_idle > 0 && self.quarantined_until.is_none() } + pub fn supports_model(&self, requested_model: &str) -> bool { + requested_model.is_empty() || self.model.as_deref() == Some(requested_model) + } + pub fn is_usable_for_model(&self, requested_model: &str) -> bool { self.is_usable() && (requested_model.is_empty() || self.model.as_deref() == Some(requested_model)) } From eb0c275e1606d7675690a18293b0454f6691f109 Mon Sep 17 00:00:00 2001 From: Jonas Krauss Date: Tue, 1 Jul 2025 11:00:58 +0200 Subject: [PATCH 7/8] make model nullable for dashboard --- resources/ts/schemas/Agent.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/resources/ts/schemas/Agent.ts b/resources/ts/schemas/Agent.ts index 4a1fef54..c2eea84d 100644 --- a/resources/ts/schemas/Agent.ts +++ b/resources/ts/schemas/Agent.ts @@ -5,7 +5,7 @@ import { StatusUpdateSchema } from "./StatusUpdate"; export const AgentSchema = z .object({ agent_id: z.string(), - model: z.string(), + model: z.string().nullable(), last_update: z.object({ nanos_since_epoch: z.number(), secs_since_epoch: z.number(), From 961d4b4bbff8210e4a23d6ddfb3d763300ae35a8 Mon Sep 17 00:00:00 2001 From: Jonas Krauss Date: Sat, 12 Jul 2025 10:44:09 +0200 Subject: [PATCH 8/8] add handling of non-utf8 chunk in first 64k bytes of request body --- src/balancer/proxy_service.rs | 67 +++++++++++++++++++++++++---------- 1 file changed, 48 insertions(+), 19 deletions(-) diff --git a/src/balancer/proxy_service.rs b/src/balancer/proxy_service.rs index eea439bd..9eeb1a54 100644 --- a/src/balancer/proxy_service.rs +++ b/src/balancer/proxy_service.rs @@ -206,27 +206,56 @@ impl ProxyHttp for ProxyService { session.read_body_or_idle(false).await.unwrap().unwrap(); let request_body = session.get_retry_buffer(); - // Parse the JSON payload into a serde_json::Value if let Some(body_bytes) = request_body { - if let Ok(json_value) = serde_json::from_slice::(&body_bytes) { - if let Some(model) = json_value.get("model").and_then(|v| v.as_str()) { - // Set the requested_model field in the RequestContext - ctx.requested_model = Some(model.to_string()); - info!("Model in request: {:?}", ctx.requested_model); - } - } else { - info!("Failed to parse JSON payload, trying regex extraction"); - - // Try extracting the model using regex - let body_str = String::from_utf8(body_bytes.to_vec()).unwrap(); - let re = regex::Regex::new(r#""model"\s*:\s*["']([^"']*)["']"#).unwrap(); - if let Some(caps) = re.captures(&body_str) { - if let Some(model) = caps.get(1) { - ctx.requested_model = Some(model.as_str().to_string()); - info!("Model via regex: {:?}", ctx.requested_model); + match std::str::from_utf8(&body_bytes) { + Ok(_) => { + // The bytes are valid UTF-8, proceed as normal + if let Ok(json_value) = serde_json::from_slice::(&body_bytes) { + if let Some(model) = json_value.get("model").and_then(|v| v.as_str()) { + ctx.requested_model = Some(model.to_string()); + info!("Model in request: {:?}", ctx.requested_model); + } + } else { + info!("Failed to parse JSON payload, trying regex extraction"); + let body_str = String::from_utf8_lossy(&body_bytes).to_string(); + let re = regex::Regex::new(r#""model"\s*:\s*["']([^"']*)["']"#).unwrap(); + if let Some(caps) = re.captures(&body_str) { + if let Some(model) = caps.get(1) { + ctx.requested_model = Some(model.as_str().to_string()); + info!("Model via regex: {:?}", ctx.requested_model); + } + } else { + info!("Failed to extract model using regex"); + } + } + }, + Err(e) => { + // Invalid UTF-8 detected. Truncate to the last valid UTF-8 boundary. + let valid_up_to = e.valid_up_to(); + info!("Invalid UTF-8 detected. Truncating from {} bytes to {} bytes.", body_bytes.len(), valid_up_to); + + // Create a new `Bytes` slice containing only the valid UTF-8 part. + let valid_body_bytes = body_bytes.slice(0..valid_up_to); + + // Now proceed with the (truncated) valid_body_bytes + if let Ok(json_value) = serde_json::from_slice::(&valid_body_bytes) { + if let Some(model) = json_value.get("model").and_then(|v| v.as_str()) { + ctx.requested_model = Some(model.to_string()); + info!("Model in request (after truncation): {:?}", ctx.requested_model); + } + } else { + info!("Failed to parse JSON payload (after truncation), trying regex extraction"); + let body_str = String::from_utf8_lossy(&valid_body_bytes).to_string(); + let re = regex::Regex::new(r#""model"\s*:\s*["']([^"']*)["']"#).unwrap(); + if let Some(caps) = re.captures(&body_str) { + if let Some(model) = caps.get(1) { + ctx.requested_model = Some(model.as_str().to_string()); + info!("Model via regex (after truncation): {:?}", ctx.requested_model); + } + } else { + info!("Failed to extract model using regex (after truncation)"); + } } - } else { - info!("Failed to extract model using regex"); } } } else {