diff --git a/Cargo.lock b/Cargo.lock index 5479934f..79042af7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2526,6 +2526,7 @@ dependencies = [ "mime_guess", "pingora", "ratatui", + "regex", "reqwest", "rust-embed", "serde", diff --git a/Cargo.toml b/Cargo.toml index 14cd96b0..e20218b2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,6 +30,7 @@ serde_json = "1.0.140" tokio = { version = "1.45.1", features = ["full"] } tokio-stream = { version = "0.1.17", features = ["sync"] } url = { version = "2.5.4", features = ["serde"] } +regex = "1.11.1" chrono = { version = "0.4.41", optional = true } crossterm = { version = "0.28.1", features = ["event-stream"], optional = true } diff --git a/resources/ts/components/AgentsList.tsx b/resources/ts/components/AgentsList.tsx index 8b41aee9..da927386 100644 --- a/resources/ts/components/AgentsList.tsx +++ b/resources/ts/components/AgentsList.tsx @@ -21,6 +21,7 @@ export function AgentsList({ agents }: { agents: Array }) { Name + Model Issues Llama.cpp address Last update @@ -54,6 +55,7 @@ export function AgentsList({ agents }: { agents: Array }) { key={agent_id} > {status.agent_name} + {status.model} {status.error && ( <> diff --git a/resources/ts/schemas/Agent.ts b/resources/ts/schemas/Agent.ts index 678bd124..c2eea84d 100644 --- a/resources/ts/schemas/Agent.ts +++ b/resources/ts/schemas/Agent.ts @@ -5,6 +5,7 @@ import { StatusUpdateSchema } from "./StatusUpdate"; export const AgentSchema = z .object({ agent_id: z.string(), + model: z.string().nullable(), last_update: z.object({ nanos_since_epoch: z.number(), secs_since_epoch: z.number(), diff --git a/resources/ts/schemas/StatusUpdate.ts b/resources/ts/schemas/StatusUpdate.ts index 7543d925..52747e5b 100644 --- a/resources/ts/schemas/StatusUpdate.ts +++ b/resources/ts/schemas/StatusUpdate.ts @@ -14,6 +14,7 @@ export const StatusUpdateSchema = z is_unexpected_response_status: z.boolean().nullable(), slots_idle: z.number(), slots_processing: z.number(), + model: 
z.string().nullable(), }) .strict(); diff --git a/src/agent/monitoring_service.rs b/src/agent/monitoring_service.rs index 4837493f..1eb14658 100644 --- a/src/agent/monitoring_service.rs +++ b/src/agent/monitoring_service.rs @@ -23,6 +23,7 @@ pub struct MonitoringService { monitoring_interval: Duration, name: Option, status_update_tx: Sender, + check_model: bool, // Store the check_model flag } impl MonitoringService { @@ -32,6 +33,7 @@ impl MonitoringService { monitoring_interval: Duration, name: Option, status_update_tx: Sender, + check_model: bool, // Include the check_model flag ) -> Result { Ok(MonitoringService { external_llamacpp_addr, @@ -39,6 +41,7 @@ impl MonitoringService { monitoring_interval, name, status_update_tx, + check_model, }) } @@ -50,6 +53,15 @@ impl MonitoringService { .filter(|slot| slot.is_processing) .count(); + let model: Option = if self.check_model { + match self.llamacpp_client.get_model().await { + Ok(model) => model, + Err(_) => None, + } + } else { + Some("".to_string()) + }; + StatusUpdate { agent_name: self.name.to_owned(), error: slots_response.error, @@ -63,6 +75,7 @@ impl MonitoringService { is_unexpected_response_status: slots_response.is_unexpected_response_status, slots_idle: slots_response.slots.len() - slots_processing, slots_processing, + model, } } @@ -109,4 +122,4 @@ impl Service for MonitoringService { fn threads(&self) -> Option { Some(1) } -} +} \ No newline at end of file diff --git a/src/balancer/proxy_service.rs b/src/balancer/proxy_service.rs index cfab2dfd..9eeb1a54 100644 --- a/src/balancer/proxy_service.rs +++ b/src/balancer/proxy_service.rs @@ -6,6 +6,7 @@ use std::time::Duration; use async_trait::async_trait; use bytes::Bytes; use log::error; +use log::info; use pingora::http::RequestHeader; use pingora::proxy::ProxyHttp; use pingora::proxy::Session; @@ -41,6 +42,7 @@ pub struct ProxyService { buffered_request_timeout: Duration, max_buffered_requests: usize, rewrite_host_header: bool, + check_model: bool, 
slots_endpoint_enable: bool, upstream_peer_pool: Arc, } @@ -48,6 +50,7 @@ pub struct ProxyService { impl ProxyService { pub fn new( rewrite_host_header: bool, + check_model: bool, slots_endpoint_enable: bool, upstream_peer_pool: Arc, buffered_request_timeout: Duration, @@ -55,6 +58,7 @@ impl ProxyService { ) -> Self { Self { rewrite_host_header, + check_model, slots_endpoint_enable, upstream_peer_pool, buffered_request_timeout, @@ -73,6 +77,7 @@ impl ProxyHttp for ProxyService { slot_taken: false, upstream_peer_pool: self.upstream_peer_pool.clone(), uses_slots: false, + requested_model: Some("".to_string()), } } @@ -180,10 +185,108 @@ impl ProxyHttp for ProxyService { } "/chat/completions" => true, "/completion" => true, + "/v1/completions" => true, "/v1/chat/completions" => true, _ => false, }; + info!("upstream_peer - {:?} request | rewrite_host_header? {} check_model? {}", session.req_header().method, self.rewrite_host_header, self.check_model); + + // Check if the request method is POST and the content type is JSON + if self.check_model && ctx.uses_slots { + info!("Checking model..."); + ctx.requested_model = None; + if session.req_header().method == "POST" { + // Check if the content type is application/json + if let Some(content_type) = session.get_header("Content-Type") { + if let Ok(content_type_str) = content_type.to_str() { + if content_type_str.contains("application/json") { + // Enable retry buffering to preserve the request body, reference: https://github.com/cloudflare/pingora/issues/349#issuecomment-2377277028 + session.enable_retry_buffering(); + session.read_body_or_idle(false).await.unwrap().unwrap(); + let request_body = session.get_retry_buffer(); + + if let Some(body_bytes) = request_body { + match std::str::from_utf8(&body_bytes) { + Ok(_) => { + // The bytes are valid UTF-8, proceed as normal + if let Ok(json_value) = serde_json::from_slice::(&body_bytes) { + if let Some(model) = json_value.get("model").and_then(|v| v.as_str()) { + 
ctx.requested_model = Some(model.to_string()); + info!("Model in request: {:?}", ctx.requested_model); + } + } else { + info!("Failed to parse JSON payload, trying regex extraction"); + let body_str = String::from_utf8_lossy(&body_bytes).to_string(); + let re = regex::Regex::new(r#""model"\s*:\s*["']([^"']*)["']"#).unwrap(); + if let Some(caps) = re.captures(&body_str) { + if let Some(model) = caps.get(1) { + ctx.requested_model = Some(model.as_str().to_string()); + info!("Model via regex: {:?}", ctx.requested_model); + } + } else { + info!("Failed to extract model using regex"); + } + } + }, + Err(e) => { + // Invalid UTF-8 detected. Truncate to the last valid UTF-8 boundary. + let valid_up_to = e.valid_up_to(); + info!("Invalid UTF-8 detected. Truncating from {} bytes to {} bytes.", body_bytes.len(), valid_up_to); + + // Create a new `Bytes` slice containing only the valid UTF-8 part. + let valid_body_bytes = body_bytes.slice(0..valid_up_to); + + // Now proceed with the (truncated) valid_body_bytes + if let Ok(json_value) = serde_json::from_slice::(&valid_body_bytes) { + if let Some(model) = json_value.get("model").and_then(|v| v.as_str()) { + ctx.requested_model = Some(model.to_string()); + info!("Model in request (after truncation): {:?}", ctx.requested_model); + } + } else { + info!("Failed to parse JSON payload (after truncation), trying regex extraction"); + let body_str = String::from_utf8_lossy(&valid_body_bytes).to_string(); + let re = regex::Regex::new(r#""model"\s*:\s*["']([^"']*)["']"#).unwrap(); + if let Some(caps) = re.captures(&body_str) { + if let Some(model) = caps.get(1) { + ctx.requested_model = Some(model.as_str().to_string()); + info!("Model via regex (after truncation): {:?}", ctx.requested_model); + } + } else { + info!("Failed to extract model using regex (after truncation)"); + } + } + } + } + } else { + info!("Request body is None"); + } + } + } + } + } + // abort if model has not been set + if ctx.requested_model == None { + info!("Model 
missing in request"); + session + .respond_error(pingora::http::StatusCode::BAD_REQUEST.as_u16()) + .await?; + + return Err(Error::new_down(pingora::ErrorType::ConnectRefused)); + } + else if ctx.has_peer_supporting_model() == false { + info!("Model {:?} not supported by upstream", ctx.requested_model); + session + .respond_error(pingora::http::StatusCode::NOT_FOUND.as_u16()) + .await?; + + return Err(Error::new_down(pingora::ErrorType::ConnectRefused)); + } + else { + info!("Model {:?}", ctx.requested_model); + } + } + let peer = tokio::select! { result = async { loop { diff --git a/src/balancer/request_context.rs b/src/balancer/request_context.rs index eb8b1c11..da308717 100644 --- a/src/balancer/request_context.rs +++ b/src/balancer/request_context.rs @@ -2,6 +2,7 @@ use std::sync::Arc; use anyhow::anyhow; use log::error; +use log::info; use pingora::Error; use pingora::Result; @@ -13,6 +14,7 @@ pub struct RequestContext { pub selected_peer: Option, pub upstream_peer_pool: Arc, pub uses_slots: bool, + pub requested_model: Option, } impl RequestContext { @@ -30,16 +32,19 @@ impl RequestContext { } } - pub fn use_best_peer_and_take_slot(&mut self) -> anyhow::Result> { + pub fn use_best_peer_and_take_slot(&mut self, model: Option) -> anyhow::Result> { if let Some(peer) = self.upstream_peer_pool.with_agents_write(|agents| { + let model_str = model.as_deref().unwrap_or(""); for peer in agents.iter_mut() { - if peer.is_usable() { - peer.take_slot()?; + let is_usable = peer.is_usable(); + let is_usable_for_model = peer.is_usable_for_model(model_str); + if is_usable && (model.is_none() || is_usable_for_model) { + info!("Peer {} is usable: {}, usable for model '{}': {}", peer.agent_id, is_usable, model_str, is_usable_for_model); + peer.take_slot()?; return Ok(Some(peer.clone())); } } - Ok(None) })? 
{ self.upstream_peer_pool.restore_integrity()?; @@ -52,11 +57,26 @@ impl RequestContext { } } + pub fn has_peer_supporting_model(&self) -> bool { + let model_str = self.requested_model.as_deref().unwrap_or(""); + match self.upstream_peer_pool.with_agents_read(|agents| { + for peer in agents.iter() { + if peer.supports_model(model_str) { + return Ok(true); + } + } + Ok(false) + }) { + Ok(result) => result, + Err(_) => false, // or handle the error as needed + } + } + pub fn select_upstream_peer(&mut self) -> Result<()> { let result_option_peer = if self.uses_slots && !self.slot_taken { - self.use_best_peer_and_take_slot() + self.use_best_peer_and_take_slot(self.requested_model.clone()) } else { - self.upstream_peer_pool.use_best_peer() + self.upstream_peer_pool.use_best_peer(self.requested_model.clone()) }; self.selected_peer = match result_option_peer { @@ -95,6 +115,7 @@ mod tests { selected_peer: None, upstream_peer_pool, uses_slots: true, + requested_model: Some("llama3".to_string()), } } @@ -105,7 +126,7 @@ mod tests { pool.register_status_update("test_agent", mock_status_update("test_agent", 0, 0))?; - assert!(ctx.use_best_peer_and_take_slot().unwrap().is_none()); + assert!(ctx.use_best_peer_and_take_slot(ctx.requested_model.clone()).unwrap().is_none()); assert!(!ctx.slot_taken); assert_eq!(ctx.selected_peer, None); diff --git a/src/balancer/status_update.rs b/src/balancer/status_update.rs index 6f2c5409..5e811640 100644 --- a/src/balancer/status_update.rs +++ b/src/balancer/status_update.rs @@ -20,6 +20,7 @@ pub struct StatusUpdate { pub is_unexpected_response_status: Option, pub slots_idle: usize, pub slots_processing: usize, + pub model: Option, } impl StatusUpdate { diff --git a/src/balancer/test/mock_status_update.rs b/src/balancer/test/mock_status_update.rs index 07d12a1f..8cc9da41 100644 --- a/src/balancer/test/mock_status_update.rs +++ b/src/balancer/test/mock_status_update.rs @@ -22,5 +22,6 @@ pub fn mock_status_update( is_unexpected_response_status: 
Some(false), slots_idle, slots_processing, + model: Some("llama3".to_string()), } } diff --git a/src/balancer/upstream_peer.rs b/src/balancer/upstream_peer.rs index a8eaff95..2c4ccee9 100644 --- a/src/balancer/upstream_peer.rs +++ b/src/balancer/upstream_peer.rs @@ -13,6 +13,7 @@ use crate::balancer::status_update::StatusUpdate; #[derive(Clone, Debug, Eq, Serialize, Deserialize)] pub struct UpstreamPeer { pub agent_id: String, + pub model: Option, pub last_update: SystemTime, pub quarantined_until: Option, pub slots_taken: usize, @@ -24,6 +25,7 @@ impl UpstreamPeer { pub fn new_from_status_update(agent_id: String, status: StatusUpdate) -> Self { Self { agent_id, + model: status.model.clone(), last_update: SystemTime::now(), quarantined_until: None, slots_taken: 0, @@ -36,6 +38,14 @@ impl UpstreamPeer { !self.status.has_issues() && self.status.slots_idle > 0 && self.quarantined_until.is_none() } + pub fn supports_model(&self, requested_model: &str) -> bool { + requested_model.is_empty() || self.model.as_deref() == Some(requested_model) + } + + pub fn is_usable_for_model(&self, requested_model: &str) -> bool { + self.is_usable() && (requested_model.is_empty() || self.model.as_deref() == Some(requested_model)) + } + pub fn release_slot(&mut self) -> Result<()> { if self.slots_taken < 1 { return Err(anyhow!( @@ -59,6 +69,7 @@ self.last_update = SystemTime::now(); self.quarantined_until = None; self.slots_taken_since_last_status_update = 0; + self.model = status_update.model.clone(); self.status = status_update; } @@ -110,6 +121,7 @@ mod tests { fn create_test_peer() -> UpstreamPeer { UpstreamPeer { agent_id: "test_agent".to_string(), + model: Some("llama3".to_string()), last_update: SystemTime::now(), quarantined_until: None, slots_taken: 0, @@ -177,7 +189,6 @@ mod tests { #[test] fn test_update_status() { let mut peer = create_test_peer(); - let slots: Vec = vec![]; let slots_idle = slots.iter().filter(|slot| !slot.is_processing).count(); @@ -194,6
+205,7 @@ mod tests { is_unexpected_response_status: None, slots_idle, slots_processing: slots.len() - slots_idle, + model: Some("llama3".to_string()) }; peer.update_status(status_update); diff --git a/src/balancer/upstream_peer_pool.rs b/src/balancer/upstream_peer_pool.rs index 2e6924a9..4593bab9 100644 --- a/src/balancer/upstream_peer_pool.rs +++ b/src/balancer/upstream_peer_pool.rs @@ -2,6 +2,7 @@ use std::sync::atomic::AtomicUsize; use std::sync::RwLock; use std::time::Duration; use std::time::SystemTime; +use log::info; use anyhow::anyhow; use anyhow::Result; @@ -159,10 +160,15 @@ impl UpstreamPeerPool { }) } - pub fn use_best_peer(&self) -> Result> { - self.with_agents_read(|agents| { + pub fn use_best_peer(&self, model: Option) -> Result> { + self.with_agents_read(|agents| { for peer in agents.iter() { - if peer.is_usable() { + let model_str = model.as_deref().unwrap_or(""); + let is_usable = peer.is_usable(); + let is_usable_for_model = peer.is_usable_for_model(model_str); + + if is_usable && (model.is_none() || is_usable_for_model) { + info!("Peer {} is usable: {}, usable for model '{}': {}", peer.agent_id, is_usable, model_str, is_usable_for_model); return Ok(Some(peer.clone())); } } @@ -263,7 +269,7 @@ mod tests { pool.register_status_update("test2", mock_status_update("test2", 3, 0))?; pool.register_status_update("test3", mock_status_update("test3", 0, 0))?; - let best_peer = pool.use_best_peer()?.unwrap(); + let best_peer = pool.use_best_peer(None)?.unwrap(); assert_eq!(best_peer.agent_id, "test1"); assert_eq!(best_peer.status.slots_idle, 5); diff --git a/src/cmd/agent.rs b/src/cmd/agent.rs index aef1b711..578f5dbe 100644 --- a/src/cmd/agent.rs +++ b/src/cmd/agent.rs @@ -18,6 +18,7 @@ pub fn handle( management_addr: SocketAddr, monitoring_interval: Duration, name: Option, + check_model: bool, // Include the check_model flag ) -> Result<()> { let (status_update_tx, _status_update_rx) = channel::(1); @@ -29,6 +30,7 @@ pub fn handle( monitoring_interval, 
name, status_update_tx.clone(), + check_model, // Pass the check_model flag )?; let reporting_service = ReportingService::new(management_addr, status_update_tx)?; @@ -45,4 +47,4 @@ pub fn handle( pingora_server.add_service(monitoring_service); pingora_server.add_service(reporting_service); pingora_server.run_forever(); -} +} \ No newline at end of file diff --git a/src/cmd/balancer.rs b/src/cmd/balancer.rs index 4978f0c3..02b7eee8 100644 --- a/src/cmd/balancer.rs +++ b/src/cmd/balancer.rs @@ -2,6 +2,7 @@ use std::net::SocketAddr; use std::sync::Arc; #[cfg(feature = "statsd_reporter")] use std::time::Duration; +use log::info; use anyhow::Result; use pingora::proxy::http_proxy_service; @@ -23,6 +24,7 @@ pub fn handle( max_buffered_requests: usize, reverseproxy_addr: &SocketAddr, rewrite_host_header: bool, + check_model: bool, slots_endpoint_enable: bool, #[cfg(feature = "statsd_reporter")] statsd_addr: Option, #[cfg(feature = "statsd_reporter")] statsd_prefix: String, @@ -44,6 +46,7 @@ pub fn handle( &pingora_server.configuration, ProxyService::new( rewrite_host_header, + check_model, slots_endpoint_enable, upstream_peer_pool.clone(), buffered_request_timeout, @@ -74,5 +77,7 @@ pub fn handle( pingora_server.add_service(statsd_service); } + info!("rewrite_host_header? {} check_model? {} slots_endpoint_enable? 
{}", rewrite_host_header, check_model, slots_endpoint_enable); + pingora_server.run_forever(); } diff --git a/src/llamacpp/llamacpp_client.rs b/src/llamacpp/llamacpp_client.rs index 742025f4..4e802164 100644 --- a/src/llamacpp/llamacpp_client.rs +++ b/src/llamacpp/llamacpp_client.rs @@ -1,16 +1,19 @@ use std::net::SocketAddr; use std::time::Duration; +use anyhow::anyhow; use anyhow::Result; use reqwest::header; use url::Url; use crate::llamacpp::slot::Slot; use crate::llamacpp::slots_response::SlotsResponse; +use crate::llamacpp::models_response::ModelsResponse; pub struct LlamacppClient { client: reqwest::Client, slots_endpoint_url: String, + models_endpoint_url: String, } impl LlamacppClient { @@ -36,6 +39,7 @@ impl LlamacppClient { Ok(Self { client: builder.build()?, slots_endpoint_url: Url::parse(&format!("http://{addr}/slots"))?.to_string(), + models_endpoint_url: Url::parse(&format!("http://{addr}/v1/models"))?.to_string(), }) } @@ -115,4 +119,40 @@ impl LlamacppClient { }, } } + + pub async fn get_model(&self) -> Result> { + let url = self.models_endpoint_url.to_owned(); + + let response = match self.client.get(url.clone()).send().await { + Ok(resp) => resp, + Err(err) => { + return Err(anyhow!( + "Request to '{}' failed: '{}'; connect issue: {}; decode issue: {}; request issue: {}; status issue: {}; status: {:?}", + url, + err, + err.is_connect(), + err.is_decode(), + err.is_request(), + err.is_status(), + err.status() + )); + } + }; + + match response.status() { + reqwest::StatusCode::OK => { + let models_response: ModelsResponse = response.json().await?; + if let Some(models) = models_response.models { + if models.is_empty() { + Ok(None) + } else { + Ok(models.first().and_then(|m| Some(m.model.clone()))) + } + } else { + Ok(None) + } + }, + _ => Err(anyhow!("Unexpected response status")), + } + } } diff --git a/src/llamacpp/mod.rs b/src/llamacpp/mod.rs index 2b1e1d24..68f5cae8 100644 --- a/src/llamacpp/mod.rs +++ b/src/llamacpp/mod.rs @@ -1,3 +1,4 @@ pub 
mod llamacpp_client; pub mod slot; pub mod slots_response; +pub mod models_response; diff --git a/src/llamacpp/models_response.rs b/src/llamacpp/models_response.rs new file mode 100644 index 00000000..d7d04b50 --- /dev/null +++ b/src/llamacpp/models_response.rs @@ -0,0 +1,13 @@ +// paddler/src/llamacpp/models_response.rs +use serde::Deserialize; + +#[derive(Debug, Deserialize)] +pub struct ModelsResponse { + pub models: Option>, +} + +#[derive(Debug, Deserialize)] +pub struct Model { + pub model: String, + // Add other fields as needed +} diff --git a/src/main.rs b/src/main.rs index 820ee91e..b936643b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -87,6 +87,10 @@ enum Commands { #[arg(long)] /// Name of the agent (optional) name: Option, + + #[arg(long)] + /// Flag whether to check the model served by llama.cpp and reject requests for other models + check_model: bool, }, /// Balances incoming requests to llama.cpp instances and optionally provides a web dashboard Balancer { @@ -130,6 +134,10 @@ enum Commands { /// Enable the slots endpoint (not recommended) slots_endpoint_enable: bool, + #[arg(long)] + /// Flag to check the model served by llama.cpp and reject requests for other models + check_model: bool, + #[cfg(feature = "statsd_reporter")] #[arg(long, value_parser = parse_socket_addr)] /// Address of the statsd server to report metrics to @@ -167,6 +175,7 @@ fn main() -> Result<()> { management_addr, monitoring_interval, name, + check_model, }) => cmd::agent::handle( match external_llamacpp_addr { Some(addr) => addr.to_owned(), @@ -177,6 +186,7 @@ fn main() -> Result<()> { management_addr.to_owned(), monitoring_interval.to_owned(), name.to_owned(), + *check_model ), Some(Commands::Balancer { buffered_request_timeout, @@ -187,6 +197,7 @@ fn main() -> Result<()> { max_buffered_requests, reverseproxy_addr, rewrite_host_header, + check_model, slots_endpoint_enable, #[cfg(feature = "statsd_reporter")] statsd_addr, @@ -207,6 +218,7 @@ fn main() -> Result<()> { 
*max_buffered_requests, reverseproxy_addr, rewrite_host_header.to_owned(), + *check_model, slots_endpoint_enable.to_owned(), #[cfg(feature = "statsd_reporter")] statsd_addr.to_owned(),