From b80c7a1ebf7bb81fc5452f7f3005c59db97a2d87 Mon Sep 17 00:00:00 2001 From: Leonidas Loucas Date: Tue, 10 Dec 2024 23:09:58 -0800 Subject: [PATCH 01/27] fix: Using 2 structs can remove some is_none Some checks Remove unrepresentable state Also in the pure rust case, i believe this means the original connection is consumed, and not re-usable , so you must use the new one. In python due to the need for a clone im not sure this holds --- .../client/examples/python_client_example.py | 8 +- .../client/examples/rust_client_example.rs | 6 +- rust/lib/srpc/client/src/lib.rs | 107 +++++++++--------- rust/lib/srpc/client/src/python_bindings.rs | 20 +++- rust/lib/srpc/client/srpc_client.pyi | 12 ++ 5 files changed, 88 insertions(+), 65 deletions(-) create mode 100644 rust/lib/srpc/client/srpc_client.pyi diff --git a/rust/lib/srpc/client/examples/python_client_example.py b/rust/lib/srpc/client/examples/python_client_example.py index 11bc6b1c..cb9acdb2 100644 --- a/rust/lib/srpc/client/examples/python_client_example.py +++ b/rust/lib/srpc/client/examples/python_client_example.py @@ -9,12 +9,12 @@ import asyncio import json -from srpc_client import SrpcClient +from srpc_client import SrpcClientConfig async def main(): - client = SrpcClient("", 6976, "/_SRPC_/TLS/JSON", "", "") - - await client.connect() + client = SrpcClientConfig("", 6976, "/_SRPC_/TLS/JSON", "", "") + + client = await client.connect() print("Connected to server") message = "Hypervisor.StartVm\n" diff --git a/rust/lib/srpc/client/examples/rust_client_example.rs b/rust/lib/srpc/client/examples/rust_client_example.rs index 2e723dbb..97a4a462 100644 --- a/rust/lib/srpc/client/examples/rust_client_example.rs +++ b/rust/lib/srpc/client/examples/rust_client_example.rs @@ -1,11 +1,11 @@ -use srpc_client::Client; +use srpc_client::ClientConfig; use tokio; use serde_json::json; #[tokio::main] async fn main() -> Result<(), Box> { // Create a new Client instance - let client = Client::new( + let client = ClientConfig::new( "", 6976, "/_SRPC_/TLS/JSON", @@ -14,7 +14,7 @@ async fn main() -> Result<(), Box> { ); // Connect to the server - client.connect().await?; + let client = client.connect().await?; println!("Connected to server"); // Send a message diff --git a/rust/lib/srpc/client/src/lib.rs b/rust/lib/srpc/client/src/lib.rs index 46af1760..69d8dfc7 100644 --- a/rust/lib/srpc/client/src/lib.rs +++ b/rust/lib/srpc/client/src/lib.rs @@ -22,28 +22,33 @@ impl fmt::Display for CustomError { impl Error for CustomError {} -pub struct Client { +#[derive(Clone)] +pub struct ClientConfig { host: String, port: u16, path: String, cert: String, key: String, - stream: Arc>>>, } -impl Client { +pub struct ConnectedClient { + pub connection_params: ClientConfig, + stream: Connected, +} + +type Connected = Arc>>; +impl ClientConfig { pub fn new(host: &str, port: u16, path: &str, cert: &str, key: &str) -> Self { - Client { + ClientConfig { host: host.to_string(), port, path: path.to_string(), cert: cert.to_string(), key: key.to_string(), - stream: Arc::new(Mutex::new(None)), } } - pub async fn connect(&self) -> Result<(), Box> { + pub async fn connect(self) -> Result> { // println!("Attempting to connect to {}:{}...", self.host, self.port); let connect_timeout = Duration::from_secs(10); @@ -76,11 +81,14 @@ impl Client { Pin::new(&mut stream).connect().await?; // println!("TLS handshake completed"); - let mut lock = self.stream.lock().await; - *lock = Some(stream); + // let mut lock = self.stream.lock().await; + // *lock = Some(stream); // println!("Connection fully established"); - Ok(()) + Ok(ConnectedClient { + connection_params: self, + stream: Arc::new(Mutex::new(stream)), + }) } async fn do_http_connect(&self, stream: &TcpStream) -> Result<(), Box> { @@ -127,16 +135,15 @@ impl Client { } } +} + +impl ConnectedClient { pub async fn send_message(&self, message: &str) -> Result<(), Box> { - let mut lock = self.stream.lock().await; - if let Some(stream) = lock.as_mut() { - let mut pinned = Pin::new(stream); - pinned.as_mut().write_all(message.as_bytes()).await?; - pinned.as_mut().flush().await?; - Ok(()) - } else { - Err("Not connected".into()) - } + let stream = self.stream.lock().await; + let mut pinned = Pin::new(stream); + pinned.as_mut().write_all(message.as_bytes()).await?; + pinned.as_mut().flush().await?; + Ok(()) } pub async fn receive_message(&self, expect_empty: bool, mut should_continue: F) -> Result>>, Box> @@ -148,44 +155,39 @@ impl Client { tokio::spawn(async move { loop { - let mut lock = stream_clone.lock().await; - if let Some(stream) = lock.as_mut() { - let mut response = String::new(); - loop { - let mut buf = [0; 1024]; - match stream.read(&mut buf).await { - Ok(0) => { - let _ = tx.send(Ok(String::new())).await; - return; - } - Ok(n) => { - response.push_str(&String::from_utf8_lossy(&buf[..n])); - if response.ends_with('\n') { - break; - } - } - Err(e) => { - let _ = tx.send(Err(Box::new(e) as Box)).await; - return; + let mut stream = stream_clone.lock().await; + let mut response = String::new(); + loop { + let mut buf = [0; 1024]; + match stream.read(&mut buf).await { + Ok(0) => { + let _ = tx.send(Ok(String::new())).await; + return; + } + Ok(n) => { + response.push_str(&String::from_utf8_lossy(&buf[..n])); + if response.ends_with('\n') { + break; } } + Err(e) => { + let _ = tx.send(Err(Box::new(e) as Box)).await; + return; + } } - let response = response.trim().to_string(); - - if expect_empty && !response.is_empty() { - let _ = tx.send(Err(Box::new(CustomError(format!("Expected empty string, got: {:?}", response))) as Box)).await; - return; - } - - let _ = tx.send(Ok(response.clone())).await; - - if !should_continue(&response) { - break; - } - } else { - let _ = tx.send(Err(Box::new(CustomError("Not connected".to_string())) as Box)).await; + } + let response = response.trim().to_string(); + + if expect_empty && !response.is_empty() { + let _ = tx.send(Err(Box::new(CustomError(format!("Expected empty string, got: {:?}", response))) as Box)).await; return; } + + let _ = tx.send(Ok(response.clone())).await; + + if !should_continue(&response) { + break; + } } }); @@ -240,6 +242,7 @@ use pyo3::prelude::*; #[cfg(feature = "python")] #[pymodule] fn srpc_client(_py: Python, m: &PyModule) -> PyResult<()> { - m.add_class::()?; + m.add_class::()?; + m.add_class::()?; Ok(()) } diff --git a/rust/lib/srpc/client/src/python_bindings.rs b/rust/lib/srpc/client/src/python_bindings.rs index 9c867db2..179c0e70 100644 --- a/rust/lib/srpc/client/src/python_bindings.rs +++ b/rust/lib/srpc/client/src/python_bindings.rs @@ -1,4 +1,4 @@ -use crate::Client; +use crate::{ClientConfig, ConnectedClient}; use pyo3::prelude::*; use pyo3::exceptions::PyRuntimeError; use pyo3_asyncio; @@ -7,22 +7,30 @@ use std::sync::Arc; use tokio::sync::Mutex; #[pyclass] -pub struct SrpcClient(Arc>); +pub struct SrpcClientConfig(ClientConfig); + +#[pyclass] +pub struct ConnectedSrpcClient(Arc>); #[pymethods] -impl SrpcClient { +impl SrpcClientConfig { #[new] pub fn new(host: &str, port: u16, path: &str, cert: &str, key: &str) -> Self { - SrpcClient(Arc::new(Mutex::new(Client::new(host, port, path, cert, key)))) + SrpcClientConfig(ClientConfig::new(host, port, path, cert, key)) } pub fn connect<'p>(&self, py: Python<'p>) -> PyResult<&'p PyAny> { let client = self.0.clone(); - pyo3_asyncio::tokio::future_into_py(py, async move { - client.lock().await.connect().await.map_err(|e| PyRuntimeError::new_err(e.to_string())) + pyo3_asyncio::tokio::future_into_py(py, async { + client.connect().await.map_err(|e| PyRuntimeError::new_err(e.to_string())) + .map(|c| ConnectedSrpcClient(Arc::new(Mutex::new(c)))) }) } +} + +#[pymethods] +impl ConnectedSrpcClient { pub fn send_message<'p>(&self, py: Python<'p>, message: String) -> PyResult<&'p PyAny> { let client = self.0.clone(); pyo3_asyncio::tokio::future_into_py(py, async move { diff --git a/rust/lib/srpc/client/srpc_client.pyi b/rust/lib/srpc/client/srpc_client.pyi new file mode 100644 index 00000000..100bc55d --- /dev/null +++ b/rust/lib/srpc/client/srpc_client.pyi @@ -0,0 +1,12 @@ +from typing import List + + +class SrpcClientConfig: + def __init__(self, host: str, port: int, cert: str, key: str) -> None: ... + async def connect(self) -> "ConnectedSrpcClient": ... + +class ConnectedSrpcClient: + async def send_message(self, message: str) -> None: ... + async def receive_message(self, expect_empty: bool) -> List[str]: ... + async def send_json(self, payload: str) -> None: ... + async def receive_json(self) -> List[str]: ... From 84a51dba94be359f7ec0c9ec3c9cf21425fda581 Mon Sep 17 00:00:00 2001 From: Leonidas Loucas Date: Wed, 11 Dec 2024 14:23:59 -0800 Subject: [PATCH 02/27] feat: add tracing for logs --- rust/lib/srpc/client/Cargo.toml | 1 + rust/lib/srpc/client/src/lib.rs | 25 +++++++++++++------------ 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/rust/lib/srpc/client/Cargo.toml b/rust/lib/srpc/client/Cargo.toml index 17ff717b..a3f54795 100644 --- a/rust/lib/srpc/client/Cargo.toml +++ b/rust/lib/srpc/client/Cargo.toml @@ -12,6 +12,7 @@ tokio = { version = "1.0", features = ["full"] } openssl = "0.10" serde_json = "1.0" tokio-openssl = "0.6" +tracing = "0.1" [dependencies.pyo3] version = "0.18" diff --git a/rust/lib/srpc/client/src/lib.rs b/rust/lib/srpc/client/src/lib.rs index 69d8dfc7..b78c6965 100644 --- a/rust/lib/srpc/client/src/lib.rs +++ b/rust/lib/srpc/client/src/lib.rs @@ -9,6 +9,7 @@ use tokio::time::{timeout, Duration}; use std::sync::Arc; use tokio::sync::{Mutex, mpsc}; use std::pin::Pin; +use tracing::debug; // Custom error type #[derive(Debug)] @@ -49,7 +50,7 @@ impl ClientConfig { } pub async fn connect(self) -> Result> { - // println!("Attempting to connect to {}:{}...", self.host, self.port); + debug!("Attempting to connect to {}:{}...", self.host, self.port); let connect_timeout = Duration::from_secs(10); let tcp_stream = match timeout(connect_timeout, @@ -59,13 +60,13 @@ impl ClientConfig { Ok(Err(e)) => return Err(format!("Failed to connect: {}", e).into()), Err(_) => return Err("Connection attempt timed out".into()), }; - // println!("TCP connection established"); + debug!("TCP connection established"); - // println!("Performing HTTP CONNECT..."); + debug!("Performing HTTP CONNECT..."); self.do_http_connect(&tcp_stream).await?; - // println!("HTTP CONNECT successful"); + debug!("HTTP CONNECT successful"); - // println!("Starting TLS handshake..."); + debug!("Starting TLS handshake..."); let mut connector = SslConnector::builder(SslMethod::tls())?; connector.set_verify(SslVerifyMode::NONE); @@ -77,13 +78,13 @@ impl ClientConfig { let ssl = Ssl::new(connector.build().context())?; let mut stream = SslStream::new(ssl, tcp_stream)?; - // println!("Performing TLS handshake..."); + debug!("Performing TLS handshake..."); Pin::new(&mut stream).connect().await?; - // println!("TLS handshake completed"); + debug!("TLS handshake completed"); // let mut lock = self.stream.lock().await; // *lock = Some(stream); - // println!("Connection fully established"); + debug!("Connection fully established"); Ok(ConnectedClient { connection_params: self, @@ -93,9 +94,9 @@ impl ClientConfig { async fn do_http_connect(&self, stream: &TcpStream) -> Result<(), Box> { let connect_request = format!("CONNECT {} HTTP/1.0\r\n\r\n", self.path); - // println!("Sending HTTP CONNECT request: {:?}", connect_request); + debug!("Sending HTTP CONNECT request: {:?}", connect_request); stream.try_write(connect_request.as_bytes())?; - // println!("HTTP CONNECT request sent"); + debug!("HTTP CONNECT request sent"); let read_timeout = Duration::from_secs(10); let start_time = std::time::Instant::now(); @@ -126,9 +127,9 @@ impl ClientConfig { } let response = String::from_utf8_lossy(&buffer); - // println!("Received HTTP CONNECT response: {:?}", response); + debug!("Received HTTP CONNECT response: {:?}", response); if response.starts_with("HTTP/1.0 200") || response.starts_with("HTTP/1.1 200") { - // println!("HTTP CONNECT completed successfully"); + debug!("HTTP CONNECT completed successfully"); Ok(()) } else { Err(format!("Unexpected HTTP response: {}", response).into()) From 53f92951885b1491f6e9374069eaf1a5434fbe85 Mon Sep 17 00:00:00 2001 From: Leonidas Loucas Date: Wed, 11 Dec 2024 14:29:28 -0800 Subject: [PATCH 03/27] feat: Add tracing to example --- rust/lib/srpc/client/Cargo.toml | 1 + .../client/examples/rust_client_example.rs | 27 ++++++++++++------- rust/lib/srpc/client/src/lib.rs | 8 ++++++ 3 files changed, 27 insertions(+), 9 deletions(-) diff --git a/rust/lib/srpc/client/Cargo.toml b/rust/lib/srpc/client/Cargo.toml index a3f54795..5c925e16 100644 --- a/rust/lib/srpc/client/Cargo.toml +++ b/rust/lib/srpc/client/Cargo.toml @@ -13,6 +13,7 @@ openssl = "0.10" serde_json = "1.0" tokio-openssl = "0.6" tracing = "0.1" +tracing-subscriber = { version = "0.3", features = ["env-filter", "fmt"] } [dependencies.pyo3] version = "0.18" diff --git a/rust/lib/srpc/client/examples/rust_client_example.rs b/rust/lib/srpc/client/examples/rust_client_example.rs index 97a4a462..0cf6dacc 100644 --- a/rust/lib/srpc/client/examples/rust_client_example.rs +++ b/rust/lib/srpc/client/examples/rust_client_example.rs @@ -1,9 +1,18 @@ use srpc_client::ClientConfig; use tokio; use serde_json::json; +use tracing::{error, info, level_filters::LevelFilter}; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; #[tokio::main] async fn main() -> Result<(), Box> { + tracing_subscriber::registry() + .with(tracing_subscriber::EnvFilter::builder().with_default_directive(LevelFilter::INFO.into()).from_env_lossy()) + .with(tracing_subscriber::fmt::Layer::default().compact()) + .init(); + + info!("Starting client..."); + // Create a new Client instance let client = ClientConfig::new( "", @@ -15,20 +24,20 @@ async fn main() -> Result<(), Box> { // Connect to the server let client = client.connect().await?; - println!("Connected to server"); + info!("Connected to server"); // Send a message let message = "Hypervisor.ProbeVmPort\n"; - println!("Sending message: {:?}", message); + info!("Sending message: {:?}", message); client.send_message(message).await?; // Receive an empty response - println!("Waiting for empty string response..."); + info!("Waiting for empty string response..."); let mut rx = client.receive_message(true, |_| false).await?; while let Some(result) = rx.recv().await { match result { - Ok(response) => println!("Received response: {:?}", response), - Err(e) => eprintln!("Error receiving message: {:?}", e), + Ok(response) => info!("Received response: {:?}", response), + Err(e) => error!("Error receiving message: {:?}", e), } } @@ -38,16 +47,16 @@ async fn main() -> Result<(), Box> { "PortNumber": 22 }); - println!("Sending JSON payload: {:?}", json_payload); + info!("Sending JSON payload: {:?}", json_payload); client.send_json(&json_payload).await?; // Receive and parse JSON response - println!("Waiting for JSON response..."); + info!("Waiting for JSON response..."); let mut rx = client.receive_json(|_| false).await?; while let Some(result) = rx.recv().await { match result { - Ok(json_response) => println!("Received JSON response: {:?}", json_response), - Err(e) => eprintln!("Error receiving JSON: {:?}", e), + Ok(json_response) => info!("Received JSON response: {:?}", json_response), + Err(e) => error!("Error receiving JSON: {:?}", e), } } diff --git a/rust/lib/srpc/client/src/lib.rs b/rust/lib/srpc/client/src/lib.rs index b78c6965..631ad4f3 100644 --- a/rust/lib/srpc/client/src/lib.rs +++ b/rust/lib/srpc/client/src/lib.rs @@ -243,6 +243,14 @@ use pyo3::prelude::*; #[cfg(feature = "python")] #[pymodule] fn srpc_client(_py: Python, m: &PyModule) -> PyResult<()> { + use tracing::level_filters::LevelFilter; + use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; + + tracing_subscriber::registry() + .with(tracing_subscriber::EnvFilter::builder().with_default_directive(LevelFilter::INFO.into()).from_env_lossy()) + .with(tracing_subscriber::fmt::Layer::default().compact()) + .init(); + m.add_class::()?; m.add_class::()?; Ok(()) From a08c5098831080fbeafeef3863d993b3333c8a98 Mon Sep 17 00:00:00 2001 From: Leonidas Loucas Date: Wed, 11 Dec 2024 14:30:04 -0800 Subject: [PATCH 04/27] style: remove dead code --- rust/lib/srpc/client/src/lib.rs | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/rust/lib/srpc/client/src/lib.rs b/rust/lib/srpc/client/src/lib.rs index 631ad4f3..01a4b4ce 100644 --- a/rust/lib/srpc/client/src/lib.rs +++ b/rust/lib/srpc/client/src/lib.rs @@ -81,11 +81,9 @@ impl ClientConfig { debug!("Performing TLS handshake..."); Pin::new(&mut stream).connect().await?; debug!("TLS handshake completed"); - - // let mut lock = self.stream.lock().await; - // *lock = Some(stream); + debug!("Connection fully established"); - + Ok(ConnectedClient { connection_params: self, stream: Arc::new(Mutex::new(stream)), From df2ac14afb3d6f6c29a901a427802a1331847284 Mon Sep 17 00:00:00 2001 From: Leonidas Loucas Date: Wed, 11 Dec 2024 14:41:17 -0800 Subject: [PATCH 05/27] style: Remove unused imports --- rust/lib/srpc/client/examples/rust_client_example.rs | 1 - rust/lib/srpc/client/src/python_bindings.rs | 1 - 2 files changed, 2 deletions(-) diff --git a/rust/lib/srpc/client/examples/rust_client_example.rs b/rust/lib/srpc/client/examples/rust_client_example.rs index 0cf6dacc..63c5821e 100644 --- a/rust/lib/srpc/client/examples/rust_client_example.rs +++ b/rust/lib/srpc/client/examples/rust_client_example.rs @@ -1,5 +1,4 @@ use srpc_client::ClientConfig; -use tokio; use serde_json::json; use tracing::{error, info, level_filters::LevelFilter}; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; diff --git a/rust/lib/srpc/client/src/python_bindings.rs b/rust/lib/srpc/client/src/python_bindings.rs index 179c0e70..9af082b8 100644 --- a/rust/lib/srpc/client/src/python_bindings.rs +++ b/rust/lib/srpc/client/src/python_bindings.rs @@ -1,7 +1,6 @@ use crate::{ClientConfig, ConnectedClient}; use pyo3::prelude::*; use pyo3::exceptions::PyRuntimeError; -use pyo3_asyncio; use serde_json::Value; use std::sync::Arc; use tokio::sync::Mutex; From b2dbb21f592b5dbec512e9aa5fee336e2d4171ae Mon Sep 17 00:00:00 2001 From: Leonidas Loucas Date: Wed, 11 Dec 2024 14:41:38 -0800 Subject: [PATCH 06/27] style: run cargo fmt --- .../client/examples/rust_client_example.rs | 4 +- rust/lib/srpc/client/src/lib.rs | 87 +++++++++++-------- rust/lib/srpc/client/src/python_bindings.rs | 41 ++++++--- 3 files changed, 82 insertions(+), 50 deletions(-) diff --git a/rust/lib/srpc/client/examples/rust_client_example.rs b/rust/lib/srpc/client/examples/rust_client_example.rs index 63c5821e..c2553491 100644 --- a/rust/lib/srpc/client/examples/rust_client_example.rs +++ b/rust/lib/srpc/client/examples/rust_client_example.rs @@ -1,5 +1,5 @@ -use srpc_client::ClientConfig; use serde_json::json; +use srpc_client::ClientConfig; use tracing::{error, info, level_filters::LevelFilter}; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; @@ -18,7 +18,7 @@ async fn main() -> Result<(), Box> { 6976, "/_SRPC_/TLS/JSON", "", - "" + "", ); // Connect to the server diff --git a/rust/lib/srpc/client/src/lib.rs b/rust/lib/srpc/client/src/lib.rs index 01a4b4ce..862cc705 100644 --- a/rust/lib/srpc/client/src/lib.rs +++ b/rust/lib/srpc/client/src/lib.rs @@ -1,14 +1,14 @@ -use tokio::net::TcpStream; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use openssl::ssl::{Ssl, SslConnector, SslMethod, SslVerifyMode}; +use serde_json::Value; use std::error::Error; use std::fmt; -use openssl::ssl::{SslMethod, SslConnector, SslVerifyMode, Ssl}; -use serde_json::Value; -use tokio_openssl::SslStream; -use tokio::time::{timeout, Duration}; -use std::sync::Arc; -use tokio::sync::{Mutex, mpsc}; use std::pin::Pin; +use std::sync::Arc; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpStream; +use tokio::sync::{mpsc, Mutex}; +use tokio::time::{timeout, Duration}; +use tokio_openssl::SslStream; use tracing::debug; // Custom error type @@ -51,21 +51,24 @@ impl ClientConfig { pub async fn connect(self) -> Result> { debug!("Attempting to connect to {}:{}...", self.host, self.port); - + let connect_timeout = Duration::from_secs(10); - let tcp_stream = match timeout(connect_timeout, - TcpStream::connect(format!("{}:{}", self.host, self.port)) - ).await { + let tcp_stream = match timeout( + connect_timeout, + TcpStream::connect(format!("{}:{}", self.host, self.port)), + ) + .await + { Ok(Ok(stream)) => stream, Ok(Err(e)) => return Err(format!("Failed to connect: {}", e).into()), Err(_) => return Err("Connection attempt timed out".into()), }; debug!("TCP connection established"); - + debug!("Performing HTTP CONNECT..."); self.do_http_connect(&tcp_stream).await?; debug!("HTTP CONNECT successful"); - + debug!("Starting TLS handshake..."); let mut connector = SslConnector::builder(SslMethod::tls())?; connector.set_verify(SslVerifyMode::NONE); @@ -74,10 +77,10 @@ impl ClientConfig { connector.set_certificate_file(&self.cert, openssl::ssl::SslFiletype::PEM)?; connector.set_private_key_file(&self.key, openssl::ssl::SslFiletype::PEM)?; } - + let ssl = Ssl::new(connector.build().context())?; let mut stream = SslStream::new(ssl, tcp_stream)?; - + debug!("Performing TLS handshake..."); Pin::new(&mut stream).connect().await?; debug!("TLS handshake completed"); @@ -95,11 +98,11 @@ impl ClientConfig { debug!("Sending HTTP CONNECT request: {:?}", connect_request); stream.try_write(connect_request.as_bytes())?; debug!("HTTP CONNECT request sent"); - + let read_timeout = Duration::from_secs(10); let start_time = std::time::Instant::now(); let mut buffer = Vec::new(); - + while start_time.elapsed() < read_timeout { match stream.try_read_buf(&mut buffer) { Ok(0) => { @@ -119,11 +122,11 @@ impl ClientConfig { Err(e) => return Err(format!("Error reading HTTP CONNECT response: {}", e).into()), } } - + if buffer.is_empty() { return Err("Timeout while waiting for HTTP CONNECT response".into()); } - + let response = String::from_utf8_lossy(&buffer); debug!("Received HTTP CONNECT response: {:?}", response); if response.starts_with("HTTP/1.0 200") || response.starts_with("HTTP/1.1 200") { @@ -133,7 +136,6 @@ impl ClientConfig { Err(format!("Unexpected HTTP response: {}", response).into()) } } - } impl ConnectedClient { @@ -145,7 +147,11 @@ impl ConnectedClient { Ok(()) } - pub async fn receive_message(&self, expect_empty: bool, mut should_continue: F) -> Result>>, Box> + pub async fn receive_message( + &self, + expect_empty: bool, + mut should_continue: F, + ) -> Result>>, Box> where F: FnMut(&str) -> bool + Send + 'static, { @@ -176,14 +182,19 @@ impl ConnectedClient { } } let response = response.trim().to_string(); - + if expect_empty && !response.is_empty() { - let _ = tx.send(Err(Box::new(CustomError(format!("Expected empty string, got: {:?}", response))) as Box)).await; + let _ = tx + .send(Err(Box::new(CustomError(format!( + "Expected empty string, got: {:?}", + response + ))) as Box)) + .await; return; } - + let _ = tx.send(Ok(response.clone())).await; - + if !should_continue(&response) { break; } @@ -198,7 +209,10 @@ impl ConnectedClient { self.send_message(&json_string).await } - pub async fn receive_json(&self, should_continue: F) -> Result>>, Box> + pub async fn receive_json( + &self, + should_continue: F, + ) -> Result>>, Box> where F: FnMut(&str) -> bool + Send + 'static, { @@ -208,18 +222,16 @@ impl ConnectedClient { tokio::spawn(async move { while let Some(result) = rx.recv().await { match result { - Ok(json_str) => { - match serde_json::from_str(&json_str) { - Ok(json_value) => { - if let Err(_) = tx.send(Ok(json_value)).await { - break; - } - } - Err(e) => { - let _ = tx.send(Err(Box::new(e) as Box)).await; + Ok(json_str) => match serde_json::from_str(&json_str) { + Ok(json_value) => { + if let Err(_) = tx.send(Ok(json_value)).await { + break; } } - } + Err(e) => { + let _ = tx.send(Err(Box::new(e) as Box)).await; + } + }, Err(e) => { let _ = tx.send(Err(e)).await; } @@ -229,7 +241,6 @@ impl ConnectedClient { Ok(new_rx) } - } #[cfg(feature = "python")] diff --git a/rust/lib/srpc/client/src/python_bindings.rs b/rust/lib/srpc/client/src/python_bindings.rs index 9af082b8..aaf78379 100644 --- a/rust/lib/srpc/client/src/python_bindings.rs +++ b/rust/lib/srpc/client/src/python_bindings.rs @@ -1,6 +1,6 @@ use crate::{ClientConfig, ConnectedClient}; -use pyo3::prelude::*; use pyo3::exceptions::PyRuntimeError; +use pyo3::prelude::*; use serde_json::Value; use std::sync::Arc; use tokio::sync::Mutex; @@ -21,11 +21,13 @@ impl SrpcClientConfig { pub fn connect<'p>(&self, py: Python<'p>) -> PyResult<&'p PyAny> { let client = self.0.clone(); pyo3_asyncio::tokio::future_into_py(py, async { - client.connect().await.map_err(|e| PyRuntimeError::new_err(e.to_string())) + client + .connect() + .await + .map_err(|e| PyRuntimeError::new_err(e.to_string())) .map(|c| ConnectedSrpcClient(Arc::new(Mutex::new(c)))) }) } - } #[pymethods] @@ -33,16 +35,25 @@ impl ConnectedSrpcClient { pub fn send_message<'p>(&self, py: Python<'p>, message: String) -> PyResult<&'p PyAny> { let client = self.0.clone(); pyo3_asyncio::tokio::future_into_py(py, async move { - client.lock().await.send_message(&message).await.map_err(|e| PyRuntimeError::new_err(e.to_string())) + client + .lock() + .await + .send_message(&message) + .await + .map_err(|e| PyRuntimeError::new_err(e.to_string())) }) } pub fn receive_message<'p>(&self, py: Python<'p>, expect_empty: bool) -> PyResult<&'p PyAny> { let client = self.0.clone(); pyo3_asyncio::tokio::future_into_py(py, async move { - let mut rx = client.lock().await.receive_message(expect_empty, |_| false).await + let mut rx = client + .lock() + .await + .receive_message(expect_empty, |_| false) + .await .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; - + let mut results = Vec::new(); while let Some(result) = rx.recv().await { match result { @@ -57,17 +68,27 @@ impl ConnectedSrpcClient { pub fn send_json<'p>(&self, py: Python<'p>, payload: String) -> PyResult<&'p PyAny> { let client = self.0.clone(); pyo3_asyncio::tokio::future_into_py(py, async move { - let value: Value = serde_json::from_str(&payload).map_err(|e| PyRuntimeError::new_err(e.to_string()))?; - client.lock().await.send_json(&value).await.map_err(|e| PyRuntimeError::new_err(e.to_string())) + let value: Value = serde_json::from_str(&payload) + .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + client + .lock() + .await + .send_json(&value) + .await + .map_err(|e| PyRuntimeError::new_err(e.to_string())) }) } pub fn receive_json<'p>(&self, py: Python<'p>) -> PyResult<&'p PyAny> { let client = self.0.clone(); pyo3_asyncio::tokio::future_into_py(py, async move { - let mut rx = client.lock().await.receive_json(|_| false).await + let mut rx = client + .lock() + .await + .receive_json(|_| false) + .await .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; - + let mut results = Vec::new(); while let Some(result) = rx.recv().await { match result { From 613c798ac38a40cfd7e2dfea3adb50cd552ee2f4 Mon Sep 17 00:00:00 2001 From: Leonidas Loucas Date: Wed, 11 Dec 2024 14:42:03 -0800 Subject: [PATCH 07/27] style: run black --- .../srpc/client/examples/python_client_example.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/rust/lib/srpc/client/examples/python_client_example.py b/rust/lib/srpc/client/examples/python_client_example.py index cb9acdb2..145b83b4 100644 --- a/rust/lib/srpc/client/examples/python_client_example.py +++ b/rust/lib/srpc/client/examples/python_client_example.py @@ -11,8 +11,15 @@ import json from srpc_client import SrpcClientConfig + async def main(): - client = SrpcClientConfig("", 6976, "/_SRPC_/TLS/JSON", "", "") + client = SrpcClientConfig( + "", + 6976, + "/_SRPC_/TLS/JSON", + "", + "", + ) client = await client.connect() print("Connected to server") @@ -25,9 +32,7 @@ async def main(): for response in responses: print(f"Received response: {response}") - json_payload = { - "IpAddress": "" - } + json_payload = {"IpAddress": ""} json_string = json.dumps(json_payload) await client.send_json(json_string) print(f"Sent JSON payload: {json_payload}") @@ -36,5 +41,6 @@ async def main(): for json_response in json_responses: print(f"Received JSON response: {json.loads(json_response)}") + if __name__ == "__main__": asyncio.run(main()) From cd2f909c8a56bebc2c98ae1cae050df3319075b8 Mon Sep 17 00:00:00 2001 From: Leonidas Loucas Date: Wed, 11 Dec 2024 22:19:39 -0800 Subject: [PATCH 08/27] refactor: Reduce 1 level of nesting with `and_then` --- rust/lib/srpc/client/src/lib.rs | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/rust/lib/srpc/client/src/lib.rs b/rust/lib/srpc/client/src/lib.rs index 862cc705..73d9706c 100644 --- a/rust/lib/srpc/client/src/lib.rs +++ b/rust/lib/srpc/client/src/lib.rs @@ -221,17 +221,15 @@ impl ConnectedClient { tokio::spawn(async move { while let Some(result) = rx.recv().await { - match result { - Ok(json_str) => match serde_json::from_str(&json_str) { - Ok(json_value) => { - if let Err(_) = tx.send(Ok(json_value)).await { - break; - } - } - Err(e) => { - let _ = tx.send(Err(Box::new(e) as Box)).await; + match result.and_then(|json_str| { + serde_json::from_str(&json_str) + .map_err(|e| Box::new(e) as Box) + }) { + Ok(json_value) => { + if let Err(_) = tx.send(Ok(json_value)).await { + break; } - }, + } Err(e) => { let _ = tx.send(Err(e)).await; } From bd78a4c576eef55c5c911f9793192d9028ab2f75 Mon Sep 17 00:00:00 2001 From: Leonidas Loucas Date: Thu, 12 Dec 2024 18:46:31 -0800 Subject: [PATCH 09/27] chore: add log line when processing buffer --- rust/lib/srpc/client/src/lib.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/rust/lib/srpc/client/src/lib.rs b/rust/lib/srpc/client/src/lib.rs index 73d9706c..41b77912 100644 --- a/rust/lib/srpc/client/src/lib.rs +++ b/rust/lib/srpc/client/src/lib.rs @@ -170,7 +170,9 @@ impl ConnectedClient { return; } Ok(n) => { - response.push_str(&String::from_utf8_lossy(&buf[..n])); + let res = String::from_utf8_lossy(&buf[..n]); + response.push_str(&res); + debug!("ResponseT: {:?}", res); if response.ends_with('\n') { break; } From b6df1e42fdc63af73b2c6f9fda772b941ee2d05b Mon Sep 17 00:00:00 2001 From: Leonidas Loucas Date: Thu, 12 Dec 2024 18:49:24 -0800 Subject: [PATCH 10/27] fix: @rgooch says this shouldnt happen and is handle in next case --- rust/lib/srpc/client/src/lib.rs | 4 ---- 1 file changed, 4 deletions(-) diff --git a/rust/lib/srpc/client/src/lib.rs b/rust/lib/srpc/client/src/lib.rs index 41b77912..7bcfcf88 100644 --- a/rust/lib/srpc/client/src/lib.rs +++ b/rust/lib/srpc/client/src/lib.rs @@ -165,10 +165,6 @@ impl ConnectedClient { loop { let mut buf = [0; 1024]; match stream.read(&mut buf).await { - Ok(0) => { - let _ = tx.send(Ok(String::new())).await; - return; - } Ok(n) => { let res = String::from_utf8_lossy(&buf[..n]); response.push_str(&res); From 386166ebeb5a76a95a90ec02d4352883801d254b Mon Sep 17 00:00:00 2001 From: Leonidas Loucas Date: Thu, 12 Dec 2024 19:40:33 -0800 Subject: [PATCH 11/27] feat: Allow streaming responses to python --- rust/lib/srpc/client/Cargo.toml | 1 + rust/lib/srpc/client/src/lib.rs | 6 +- rust/lib/srpc/client/src/python_bindings.rs | 152 ++++++++++++++++---- 3 files changed, 133 insertions(+), 26 deletions(-) diff --git a/rust/lib/srpc/client/Cargo.toml b/rust/lib/srpc/client/Cargo.toml index 5c925e16..b16e6a1b 100644 --- a/rust/lib/srpc/client/Cargo.toml +++ b/rust/lib/srpc/client/Cargo.toml @@ -8,6 +8,7 @@ name = "srpc_client" crate-type = ["cdylib", "rlib"] [dependencies] +futures = "0.3.28" tokio = { version = "1.0", features = ["full"] } openssl = "0.10" serde_json = "1.0" diff --git a/rust/lib/srpc/client/src/lib.rs b/rust/lib/srpc/client/src/lib.rs index 7bcfcf88..39acb02f 100644 --- a/rust/lib/srpc/client/src/lib.rs +++ b/rust/lib/srpc/client/src/lib.rs @@ -252,7 +252,11 @@ fn srpc_client(_py: Python, m: &PyModule) -> PyResult<()> { use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; tracing_subscriber::registry() - .with(tracing_subscriber::EnvFilter::builder().with_default_directive(LevelFilter::INFO.into()).from_env_lossy()) + .with( + tracing_subscriber::EnvFilter::builder() + .with_default_directive(LevelFilter::INFO.into()) + .from_env_lossy(), + ) .with(tracing_subscriber::fmt::Layer::default().compact()) .init(); diff --git a/rust/lib/srpc/client/src/python_bindings.rs b/rust/lib/srpc/client/src/python_bindings.rs index aaf78379..fa4daa62 100644 --- a/rust/lib/srpc/client/src/python_bindings.rs +++ b/rust/lib/srpc/client/src/python_bindings.rs @@ -1,9 +1,14 @@ use crate::{ClientConfig, ConnectedClient}; -use pyo3::exceptions::PyRuntimeError; +use futures::{Stream, StreamExt}; +use pyo3::exceptions::{PyRuntimeError, PyStopAsyncIteration}; use pyo3::prelude::*; use serde_json::Value; -use std::sync::Arc; -use tokio::sync::Mutex; +use std::{ + pin::Pin, + sync::Arc, + task::{Context, Poll}, +}; +use tokio::sync::{mpsc, Mutex}; #[pyclass] pub struct SrpcClientConfig(ClientConfig); @@ -30,6 +35,110 @@ impl SrpcClientConfig { } } +struct Streamer { + rx: mpsc::Receiver>>, +} + +impl Streamer { + fn new(rx: mpsc::Receiver>>) -> Self { + Streamer { rx } + } +} + +impl Stream for Streamer { + type Item = Result>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + self.get_mut().rx.poll_recv(cx) + } +} + +#[pyo3::pyclass] +struct PyStream { + pub streamer: Arc>, +} + +impl PyStream { + fn new(streamer: Streamer) -> Self { + PyStream { + streamer: Arc::new(Mutex::new(streamer)), + } + } +} + +#[pymethods] +impl PyStream { + fn __aiter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> { + slf + } + + fn __anext__(&self, py: Python) -> PyResult> { + let streamer = self.streamer.clone(); + let future = pyo3_asyncio::tokio::future_into_py(py, async move { + let val = streamer.lock().await.next().await; + match val { + Some(Ok(val)) => Ok(val), + Some(Err(val)) => Err(PyRuntimeError::new_err(val.to_string())), + None => Err(PyStopAsyncIteration::new_err("The iterator is exhausted")), + } + }); + Ok(Some(future?.into())) + } +} + +struct ValueStreamer { + rx: mpsc::Receiver>>, +} + +impl ValueStreamer { + fn new( + rx: mpsc::Receiver>>, + ) -> Self { + ValueStreamer { rx } + } +} + +impl Stream for ValueStreamer { + type Item = Result>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + self.get_mut().rx.poll_recv(cx) + } +} + +#[pyo3::pyclass] +struct PyValueStream { + pub streamer: Arc>, +} + +impl PyValueStream { + fn new(streamer: ValueStreamer) -> Self { + PyValueStream { + streamer: Arc::new(Mutex::new(streamer)), + } + } +} + +#[pymethods] +impl PyValueStream { + fn __aiter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> { + slf + } + + fn __anext__(&self, py: Python) -> PyResult> { + let streamer = self.streamer.clone(); + let future = pyo3_asyncio::tokio::future_into_py(py, async move { + let val = streamer.lock().await.next().await; + match val { + Some(Ok(val)) => Ok(val.to_string()), + Some(Err(val)) => Err(PyRuntimeError::new_err(val.to_string())), + None => Err(PyStopAsyncIteration::new_err("The iterator is exhausted")), + } + }); + Ok(Some(future?.into())) + } +} + #[pymethods] impl ConnectedSrpcClient { pub fn send_message<'p>(&self, py: Python<'p>, message: String) -> PyResult<&'p PyAny> { @@ -44,24 +153,22 @@ impl ConnectedSrpcClient { }) } - pub fn receive_message<'p>(&self, py: Python<'p>, expect_empty: bool) -> PyResult<&'p PyAny> { + pub fn receive_message<'p>( + &self, + py: Python<'p>, + expect_empty: bool, + should_continue: bool, + ) -> PyResult<&'p PyAny> { let client = self.0.clone(); pyo3_asyncio::tokio::future_into_py(py, async move { - let mut rx = client + let rx = client .lock() .await - .receive_message(expect_empty, |_| false) + .receive_message(expect_empty, move |_| should_continue) .await .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; - let mut results = Vec::new(); - while let Some(result) = rx.recv().await { - match result { - Ok(message) => results.push(message), - Err(e) => return Err(PyRuntimeError::new_err(e.to_string())), - } - } - Ok(Python::with_gil(|py| results.to_object(py))) + Ok(Python::with_gil(|_py| PyStream::new(Streamer::new(rx)))) }) } @@ -79,24 +186,19 @@ impl ConnectedSrpcClient { }) } - pub fn receive_json<'p>(&self, py: Python<'p>) -> PyResult<&'p PyAny> { + pub fn receive_json<'p>(&self, py: Python<'p>, should_continue: bool) -> PyResult<&'p PyAny> { let client = self.0.clone(); pyo3_asyncio::tokio::future_into_py(py, async move { - let mut rx = client + let rx = client .lock() .await - .receive_json(|_| false) + .receive_json(move |_| should_continue) .await .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; - let mut results = Vec::new(); - while let Some(result) = rx.recv().await { - match result { - Ok(json_value) => results.push(json_value.to_string()), - Err(e) => return Err(PyRuntimeError::new_err(e.to_string())), - } - } - Ok(Python::with_gil(|py| results.to_object(py))) + Ok(Python::with_gil(|_py| { + PyValueStream::new(ValueStreamer::new(rx)) + })) }) } } From 7e7147aff954a888e38618567f9d7c9749da6076 Mon Sep 17 00:00:00 2001 From: Leonidas Loucas Date: Thu, 12 Dec 2024 19:45:03 -0800 Subject: [PATCH 12/27] chore: Examples for streaming responses --- rust/lib/srpc/client/Cargo.toml | 5 ++ .../client/examples/python_client_example2.py | 50 ++++++++++++++++ .../client/examples/rust_client_example2.rs | 57 +++++++++++++++++++ 3 files changed, 112 insertions(+) create mode 100644 rust/lib/srpc/client/examples/python_client_example2.py create mode 100644 rust/lib/srpc/client/examples/rust_client_example2.rs diff --git a/rust/lib/srpc/client/Cargo.toml b/rust/lib/srpc/client/Cargo.toml index b16e6a1b..aaac6912 100644 --- a/rust/lib/srpc/client/Cargo.toml +++ b/rust/lib/srpc/client/Cargo.toml @@ -35,5 +35,10 @@ name = "rust_client_example" path = "examples/rust_client_example.rs" required-features = [] +[[example]] +name = "rust_client_example2" +path = "examples/rust_client_example2.rs" +required-features = [] + [dev-dependencies] tokio = { version = "1.0", features = ["full", "macros"] } diff --git a/rust/lib/srpc/client/examples/python_client_example2.py b/rust/lib/srpc/client/examples/python_client_example2.py new file mode 100644 index 00000000..7dcf3f66 --- /dev/null +++ b/rust/lib/srpc/client/examples/python_client_example2.py @@ -0,0 +1,50 @@ +""" +This example demonstrates how to use the srpc_client Python bindings. + +To run this example: +1. Build the Rust library: maturin build --features python +2. Install the wheel: pip install target/wheels/srpc_client-*.whl +3. Run this script: python examples/python_client_example.py +""" + +import asyncio +import json +from srpc_client import SrpcClientConfig + + +async def main(): + print("Starting client..") + + # Create a new ClientConfig instance + client = SrpcClientConfig( + "", + 6976, + "/_SRPC_/TLS/JSON", + "", + "", + ) + + # Connect to the server + client = await client.connect() + print("Connected to server") + + # Send a message + message = "Hypervisor.GetUpdates\n" + print(f"Sending message: {message}") + await client.send_message(message) + print(f"Sent message: {message}") + + # Receive an empty response + print("Waiting for empty string response...") + responses = await client.receive_message(expect_empty=True) + async for response in responses: + print(f"Received response: {response}") + + # Receive responses + responses = await client.receive_json() + async for response in responses: + print(f"Received response: {json.loads(response)}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/rust/lib/srpc/client/examples/rust_client_example2.rs b/rust/lib/srpc/client/examples/rust_client_example2.rs new file mode 100644 index 00000000..4fccebd8 --- /dev/null +++ b/rust/lib/srpc/client/examples/rust_client_example2.rs @@ -0,0 +1,57 @@ +use srpc_client::ClientConfig; +use tracing::{error, info, level_filters::LevelFilter}; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; + +#[tokio::main] +async fn main() -> Result<(), Box> { + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::builder() + .with_default_directive(LevelFilter::INFO.into()) + .from_env_lossy(), + ) + .with(tracing_subscriber::fmt::Layer::default().compact()) + .init(); + + info!("Starting client..."); + + // Create a new ClientConfig instance + let config = ClientConfig::new( + "", + 6976, + "/_SRPC_/TLS/JSON", + "", + "", + ); + + // Connect to the server + let client = config.connect().await?; + info!("Connected to server"); + + // Send a message + let message = "Hypervisor.GetUpdates\n"; + info!("Sending message: {:?}", message); + client.send_message(message).await?; + info!("Sent message: {:?}", message); + + // Receive an empty response + info!("Waiting for empty string response..."); + let mut rx = client.receive_message(true, |_| false).await?; + while let Some(result) = rx.recv().await { + match result { + Ok(response) => info!("Received response: {:?}", response), + Err(e) => error!("Error receiving message: {:?}", e), + } + } + + // Receive responses + let mut rx = client.receive_json(|_| false).await?; + while let Some(result) = rx.recv().await { + match result { + Ok(response) => info!("Received response: {:?}", response), + Err(e) => error!("Error receiving message: {:?}", e), + } + } + + Ok(()) +} From 7739f583b8a8e1f892f3cb3d7507b15de2d23f8b Mon Sep 17 00:00:00 2001 From: Leonidas Loucas Date: Thu, 12 Dec 2024 20:17:04 -0800 Subject: [PATCH 13/27] feat: Allow passing a callback from python --- .../client/examples/python_client_example2.py | 4 +- .../client/examples/rust_client_example.rs | 6 ++- rust/lib/srpc/client/src/python_bindings.rs | 53 +++++++++++++++++++ rust/lib/srpc/client/srpc_client.pyi | 8 +-- 4 files changed, 65 insertions(+), 6 deletions(-) diff --git a/rust/lib/srpc/client/examples/python_client_example2.py b/rust/lib/srpc/client/examples/python_client_example2.py index 7dcf3f66..480bebe6 100644 --- a/rust/lib/srpc/client/examples/python_client_example2.py +++ b/rust/lib/srpc/client/examples/python_client_example2.py @@ -36,12 +36,12 @@ async def main(): # Receive an empty response print("Waiting for empty string response...") - responses = await client.receive_message(expect_empty=True) + responses = await client.receive_message(expect_empty=True, should_continue=False) async for response in responses: print(f"Received response: {response}") # Receive responses - responses = await client.receive_json() + responses = await client.receive_json_cb(should_continue=lambda _: True) async for response in responses: print(f"Received response: {json.loads(response)}") diff --git a/rust/lib/srpc/client/examples/rust_client_example.rs b/rust/lib/srpc/client/examples/rust_client_example.rs index c2553491..55fb8ba1 100644 --- a/rust/lib/srpc/client/examples/rust_client_example.rs +++ b/rust/lib/srpc/client/examples/rust_client_example.rs @@ -6,7 +6,11 @@ use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; #[tokio::main] async fn main() -> Result<(), Box> { tracing_subscriber::registry() - .with(tracing_subscriber::EnvFilter::builder().with_default_directive(LevelFilter::INFO.into()).from_env_lossy()) + .with( + tracing_subscriber::EnvFilter::builder() + .with_default_directive(LevelFilter::INFO.into()) + .from_env_lossy(), + ) .with(tracing_subscriber::fmt::Layer::default().compact()) .init(); diff --git a/rust/lib/srpc/client/src/python_bindings.rs b/rust/lib/srpc/client/src/python_bindings.rs index fa4daa62..c8b612fb 100644 --- a/rust/lib/srpc/client/src/python_bindings.rs +++ b/rust/lib/srpc/client/src/python_bindings.rs @@ -172,6 +172,34 @@ impl ConnectedSrpcClient { }) } + pub fn receive_message_cb<'p>( + &self, + py: Python<'p>, + expect_empty: bool, + should_continue: &PyAny, + ) -> PyResult<&'p PyAny> { + let client = self.0.clone(); + let should_continue = should_continue.to_object(py); + + pyo3_asyncio::tokio::future_into_py(py, async move { + let should_continue = move |response: &str| -> bool { + Python::with_gil(|py| { + let func = should_continue.as_ref(py); + + func.call1((response,)).and_then(|v| v.extract::()).unwrap_or(false) + }) + }; + let rx = client + .lock() + .await + .receive_message(expect_empty, should_continue) + .await + .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + + Ok(Python::with_gil(|_py| PyStream::new(Streamer::new(rx)))) + }) + } + pub fn send_json<'p>(&self, py: Python<'p>, payload: String) -> PyResult<&'p PyAny> { let client = self.0.clone(); pyo3_asyncio::tokio::future_into_py(py, async move { @@ -201,4 +229,29 @@ impl ConnectedSrpcClient { })) }) } + + pub fn receive_json_cb<'p>(&self, py: Python<'p>, should_continue: &PyAny) -> PyResult<&'p PyAny> { + let client = self.0.clone(); + let should_continue = should_continue.to_object(py); + + pyo3_asyncio::tokio::future_into_py(py, async move { + let should_continue = move |response: &str| -> bool { + Python::with_gil(|py| { + let func = should_continue.as_ref(py); + + func.call1((response,)).and_then(|v| v.extract::()).unwrap_or(false) + }) + }; + let rx = client + .lock() + .await + .receive_json(should_continue) + .await + .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + + Ok(Python::with_gil(|_py| { + PyValueStream::new(ValueStreamer::new(rx)) + })) + }) + } } diff --git a/rust/lib/srpc/client/srpc_client.pyi b/rust/lib/srpc/client/srpc_client.pyi index 100bc55d..28635435 100644 --- a/rust/lib/srpc/client/srpc_client.pyi +++ b/rust/lib/srpc/client/srpc_client.pyi @@ -1,4 +1,4 @@ -from typing import List +from typing import Callable, List class SrpcClientConfig: @@ -7,6 +7,8 @@ class SrpcClientConfig: class ConnectedSrpcClient: async def send_message(self, message: str) -> None: ... - async def receive_message(self, expect_empty: bool) -> List[str]: ... + async def receive_message(self, expect_empty: bool, should_continue: bool) -> List[str]: ... + async def receive_message_cb(self, expect_empty: bool, should_continue: Callable[[str], bool]) -> List[str]: ... async def send_json(self, payload: str) -> None: ... - async def receive_json(self) -> List[str]: ... + async def receive_json(self, should_continue: bool) -> List[str]: ... + async def receive_json_cb(self, should_continue: Callable[[str], bool]) -> List[str]: ... From 90dd80baea49dd108f3ddf9c6341c91424dd1116 Mon Sep 17 00:00:00 2001 From: Leonidas Loucas Date: Fri, 13 Dec 2024 17:06:28 -0800 Subject: [PATCH 14/27] feat: switch to a more idiomatic FrameRead *LineCodec Also add ReceiveOptions for future testing --- rust/lib/srpc/client/Cargo.toml | 2 + .../client/examples/rust_client_example.rs | 10 ++- .../client/examples/rust_client_example2.rs | 10 ++- rust/lib/srpc/client/src/chunk_limiter.rs | 39 ++++++++ rust/lib/srpc/client/src/lib.rs | 89 ++++++++++++++++++- rust/lib/srpc/client/src/python_bindings.rs | 28 ++++-- 6 files changed, 160 insertions(+), 18 deletions(-) create mode 100644 rust/lib/srpc/client/src/chunk_limiter.rs diff --git a/rust/lib/srpc/client/Cargo.toml b/rust/lib/srpc/client/Cargo.toml index aaac6912..9be16f08 100644 --- a/rust/lib/srpc/client/Cargo.toml +++ b/rust/lib/srpc/client/Cargo.toml @@ -8,11 +8,13 @@ name = "srpc_client" crate-type = ["cdylib", "rlib"] [dependencies] +bytes = "1.0" futures = "0.3.28" tokio = { version = "1.0", features = ["full"] } openssl = "0.10" serde_json = "1.0" tokio-openssl = "0.6" +tokio-util = { version = "0.7.13", features = ["codec"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter", "fmt"] } diff --git a/rust/lib/srpc/client/examples/rust_client_example.rs b/rust/lib/srpc/client/examples/rust_client_example.rs index 55fb8ba1..f176c15c 100644 --- a/rust/lib/srpc/client/examples/rust_client_example.rs +++ b/rust/lib/srpc/client/examples/rust_client_example.rs @@ -1,5 +1,5 @@ use serde_json::json; -use srpc_client::ClientConfig; +use srpc_client::{ClientConfig, ReceiveOptions}; use tracing::{error, info, level_filters::LevelFilter}; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; @@ -36,7 +36,9 @@ async fn main() -> Result<(), Box> { // Receive an empty response info!("Waiting for empty string response..."); - let mut rx = client.receive_message(true, |_| false).await?; + let mut rx = client + .receive_message(true, |_| false, &ReceiveOptions::default()) + .await?; while let Some(result) = rx.recv().await { match result { Ok(response) => info!("Received response: {:?}", response), @@ -55,7 +57,9 @@ async fn main() -> Result<(), Box> { // Receive and parse JSON response info!("Waiting for JSON response..."); - let mut rx = client.receive_json(|_| false).await?; + let mut rx = client + .receive_json(|_| false, &ReceiveOptions::default()) + .await?; while let Some(result) = rx.recv().await { match result { Ok(json_response) => info!("Received JSON response: {:?}", json_response), diff --git a/rust/lib/srpc/client/examples/rust_client_example2.rs b/rust/lib/srpc/client/examples/rust_client_example2.rs index 4fccebd8..85d287f9 100644 --- a/rust/lib/srpc/client/examples/rust_client_example2.rs +++ b/rust/lib/srpc/client/examples/rust_client_example2.rs @@ -1,4 +1,4 @@ -use srpc_client::ClientConfig; +use srpc_client::{ClientConfig, ReceiveOptions}; use tracing::{error, info, level_filters::LevelFilter}; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; @@ -36,7 +36,9 @@ async fn main() -> Result<(), Box> { // Receive an empty response info!("Waiting for empty string response..."); - let mut rx = client.receive_message(true, |_| false).await?; + let mut rx = client + .receive_message(true, |_| false, &ReceiveOptions::default()) + .await?; while let Some(result) = rx.recv().await { match result { Ok(response) => info!("Received response: {:?}", response), @@ -45,7 +47,9 @@ async fn main() -> Result<(), Box> { } // Receive responses - let mut rx = client.receive_json(|_| false).await?; + let mut rx = client + .receive_json(|_| false, &ReceiveOptions::default()) + .await?; while let Some(result) = rx.recv().await { match result { Ok(response) => info!("Received response: {:?}", response), diff --git a/rust/lib/srpc/client/src/chunk_limiter.rs b/rust/lib/srpc/client/src/chunk_limiter.rs new file mode 100644 index 00000000..d03a6527 --- /dev/null +++ b/rust/lib/srpc/client/src/chunk_limiter.rs @@ -0,0 +1,39 @@ +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::io::{AsyncRead, ReadBuf}; +use tracing::trace; + +pub struct ChunkLimiter { + inner: R, + max_chunk_size: usize, +} + +impl ChunkLimiter { + pub fn new(inner: R, max_chunk_size: usize) -> Self { + Self { + inner, + max_chunk_size, + } + } +} + +impl AsyncRead for ChunkLimiter { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + let limit = self.max_chunk_size.min(buf.remaining()); + + let available = buf.initialize_unfilled_to(limit); + let mut limited_buf = ReadBuf::new(available); + + let poll_result = Pin::new(&mut self.inner).poll_read(cx, &mut limited_buf); + + let filled_len = limited_buf.filled().len(); + trace!("Read {} bytes", filled_len); + buf.advance(filled_len); + + poll_result + } +} diff --git a/rust/lib/srpc/client/src/lib.rs b/rust/lib/srpc/client/src/lib.rs index 39acb02f..3e516e43 100644 --- a/rust/lib/srpc/client/src/lib.rs +++ b/rust/lib/srpc/client/src/lib.rs @@ -1,16 +1,21 @@ +use chunk_limiter::ChunkLimiter; +use futures::StreamExt; use openssl::ssl::{Ssl, SslConnector, SslMethod, SslVerifyMode}; use serde_json::Value; use std::error::Error; use std::fmt; use std::pin::Pin; use std::sync::Arc; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader}; use tokio::net::TcpStream; use tokio::sync::{mpsc, Mutex}; use tokio::time::{timeout, Duration}; use tokio_openssl::SslStream; +use tokio_util::codec::{FramedRead, LinesCodec}; use tracing::debug; +mod chunk_limiter; + // Custom error type #[derive(Debug)] struct CustomError(String); @@ -32,6 +37,29 @@ pub struct ClientConfig { key: String, } +pub struct ReceiveOptions { + channel_buffer_size: usize, + max_chunk_size: usize, +} + +impl ReceiveOptions { + pub fn new(channel_buffer_size: usize, max_chunk_size: usize) -> Self { + ReceiveOptions { + channel_buffer_size, + max_chunk_size, + } + } +} + +impl Default for ReceiveOptions { + fn default() -> Self { + ReceiveOptions { + channel_buffer_size: 100, + max_chunk_size: 16384, + } + } +} + pub struct ConnectedClient { pub connection_params: ClientConfig, stream: Connected, @@ -147,7 +175,7 @@ impl ConnectedClient { Ok(()) } - pub async fn receive_message( + pub async fn receive_message_old( &self, expect_empty: bool, mut should_continue: F, @@ -202,6 +230,58 @@ impl ConnectedClient { Ok(rx) } + pub async fn receive_message( + &self, + expect_empty: bool, + mut should_continue: F, + opts: &ReceiveOptions, + ) -> Result>>, Box> + where + F: FnMut(&str) -> bool + Send + 'static, + { + let stream = Arc::clone(&self.stream); + let (tx, rx) = mpsc::channel(opts.channel_buffer_size); + let max_chunk_size = opts.max_chunk_size; + + tokio::spawn(async move { + let mut guard = stream.lock().await; + let limited_reader = ChunkLimiter::new(&mut *guard, max_chunk_size); + let buf_reader = BufReader::new(limited_reader); + let mut framed = FramedRead::new(buf_reader, LinesCodec::new()); + + while let Some(line_res) = framed.next().await { + let line_res = line_res.map_err(|e| Box::new(e) as Box); + + match line_res { + Ok(line) => { + if expect_empty && !line.is_empty() { + let _ = tx + .send(Err(Box::new(CustomError(format!( + "Expected empty line, got: {:?}", + line + ))) + as Box)) + .await; + break; + } + + let _ = tx.send(Ok(line.clone())).await; + + if !should_continue(&line) { + break; + } + } + Err(err) => { + let _ = tx.send(Err(err)).await; + break; + } + } + } + }); + + Ok(rx) + } + pub async fn send_json(&self, payload: &Value) -> Result<(), Box> { let json_string = payload.to_string() + "\n"; self.send_message(&json_string).await @@ -210,12 +290,13 @@ impl ConnectedClient { pub async fn receive_json( &self, should_continue: F, + opts: &ReceiveOptions, ) -> Result>>, Box> where F: FnMut(&str) -> bool + Send + 'static, { - let mut rx = self.receive_message(false, should_continue).await?; - let (tx, new_rx) = mpsc::channel(100); + let mut rx = self.receive_message(false, should_continue, opts).await?; + let (tx, new_rx) = mpsc::channel(opts.channel_buffer_size); tokio::spawn(async move { while let Some(result) = rx.recv().await { diff --git a/rust/lib/srpc/client/src/python_bindings.rs b/rust/lib/srpc/client/src/python_bindings.rs index c8b612fb..c03e78a5 100644 --- a/rust/lib/srpc/client/src/python_bindings.rs +++ b/rust/lib/srpc/client/src/python_bindings.rs @@ -1,4 +1,4 @@ -use crate::{ClientConfig, ConnectedClient}; +use crate::{ClientConfig, ConnectedClient, ReceiveOptions}; use futures::{Stream, StreamExt}; use pyo3::exceptions::{PyRuntimeError, PyStopAsyncIteration}; use pyo3::prelude::*; @@ -164,7 +164,11 @@ impl ConnectedSrpcClient { let rx = client .lock() .await - .receive_message(expect_empty, move |_| should_continue) + .receive_message( + expect_empty, + move |_| should_continue, + &ReceiveOptions::default(), + ) .await .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; @@ -186,13 +190,15 @@ impl ConnectedSrpcClient { Python::with_gil(|py| { let func = should_continue.as_ref(py); - func.call1((response,)).and_then(|v| v.extract::()).unwrap_or(false) + func.call1((response,)) + .and_then(|v| v.extract::()) + .unwrap_or(false) }) }; let rx = client .lock() .await - .receive_message(expect_empty, should_continue) + .receive_message(expect_empty, should_continue, &ReceiveOptions::default()) .await .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; @@ -220,7 +226,7 @@ impl ConnectedSrpcClient { let rx = client .lock() .await - .receive_json(move |_| should_continue) + .receive_json(move |_| should_continue, &ReceiveOptions::default()) .await .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; @@ -230,7 +236,11 @@ impl ConnectedSrpcClient { }) } - pub fn receive_json_cb<'p>(&self, py: Python<'p>, should_continue: &PyAny) -> PyResult<&'p PyAny> { + pub fn receive_json_cb<'p>( + &self, + py: Python<'p>, + should_continue: &PyAny, + ) -> PyResult<&'p PyAny> { let client = self.0.clone(); let should_continue = should_continue.to_object(py); @@ -239,13 +249,15 @@ impl ConnectedSrpcClient { Python::with_gil(|py| { let func = should_continue.as_ref(py); - func.call1((response,)).and_then(|v| v.extract::()).unwrap_or(false) + func.call1((response,)) + .and_then(|v| v.extract::()) + .unwrap_or(false) }) }; let rx = client .lock() .await - .receive_json(should_continue) + .receive_json(should_continue, &ReceiveOptions::default()) .await .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; From 02560b175a75212bfd41f700a6b8d96c83202f51 Mon Sep 17 00:00:00 2001 From: Leonidas Loucas Date: Fri, 13 Dec 2024 17:20:02 -0800 Subject: [PATCH 15/27] chore: Make examples use env vars for flexibility --- .../lib/srpc/client/examples/python_client_example.py | 11 ++++++----- .../srpc/client/examples/python_client_example2.py | 11 ++++++----- rust/lib/srpc/client/examples/rust_client_example.rs | 10 +++++----- rust/lib/srpc/client/examples/rust_client_example2.rs | 10 +++++----- rust/lib/srpc/client/srpc_client.pyi | 2 +- 5 files changed, 23 insertions(+), 21 deletions(-) diff --git a/rust/lib/srpc/client/examples/python_client_example.py b/rust/lib/srpc/client/examples/python_client_example.py index 145b83b4..1aded322 100644 --- a/rust/lib/srpc/client/examples/python_client_example.py +++ b/rust/lib/srpc/client/examples/python_client_example.py @@ -9,16 +9,17 @@ import asyncio import json +import os from srpc_client import SrpcClientConfig async def main(): client = SrpcClientConfig( - "", - 6976, - "/_SRPC_/TLS/JSON", - "", - "", + os.environ["EXAMPLE_1_SRPC_SERVER_HOST"], + int(os.environ["EXAMPLE_1_SRPC_SERVER_PORT"]), + os.environ["EXAMPLE_1_SRPC_SERVER_ENPOINT"], + os.environ["EXAMPLE_1_SRPC_SERVER_CERT"], + os.environ["EXAMPLE_1_SRPC_SERVER_KEY"], ) client = await client.connect() diff --git a/rust/lib/srpc/client/examples/python_client_example2.py b/rust/lib/srpc/client/examples/python_client_example2.py index 480bebe6..594ba011 100644 --- a/rust/lib/srpc/client/examples/python_client_example2.py +++ b/rust/lib/srpc/client/examples/python_client_example2.py @@ -9,6 +9,7 @@ import asyncio import json +import os from srpc_client import SrpcClientConfig @@ -17,11 +18,11 @@ async def main(): # Create a new ClientConfig instance client = SrpcClientConfig( - "", - 6976, - "/_SRPC_/TLS/JSON", - "", - "", + os.environ["EXAMPLE_2_SRPC_SERVER_HOST"], + int(os.environ["EXAMPLE_2_SRPC_SERVER_PORT"]), + os.environ["EXAMPLE_2_SRPC_SERVER_ENPOINT"], + os.environ["EXAMPLE_2_SRPC_SERVER_CERT"], + os.environ["EXAMPLE_2_SRPC_SERVER_KEY"], ) # Connect to the server diff --git a/rust/lib/srpc/client/examples/rust_client_example.rs b/rust/lib/srpc/client/examples/rust_client_example.rs index f176c15c..88736e85 100644 --- a/rust/lib/srpc/client/examples/rust_client_example.rs +++ b/rust/lib/srpc/client/examples/rust_client_example.rs @@ -18,11 +18,11 @@ async fn main() -> Result<(), Box> { // Create a new Client instance let client = ClientConfig::new( - "", - 6976, - "/_SRPC_/TLS/JSON", - "", - "", + &std::env::var("EXAMPLE_1_SRPC_SERVER_HOST")?, + std::env::var("EXAMPLE_1_SRPC_SERVER_PORT")?.parse()?, + &std::env::var("EXAMPLE_1_SRPC_SERVER_ENPOINT")?, + &std::env::var("EXAMPLE_1_SRPC_SERVER_CERT")?, + &std::env::var("EXAMPLE_1_SRPC_SERVER_KEY")?, ); // Connect to the server diff --git a/rust/lib/srpc/client/examples/rust_client_example2.rs b/rust/lib/srpc/client/examples/rust_client_example2.rs index 85d287f9..e4b18ca7 100644 --- a/rust/lib/srpc/client/examples/rust_client_example2.rs +++ b/rust/lib/srpc/client/examples/rust_client_example2.rs @@ -17,11 +17,11 @@ async fn main() -> Result<(), Box> { // Create a new ClientConfig instance let config = ClientConfig::new( - "", - 6976, - "/_SRPC_/TLS/JSON", - "", - "", + &std::env::var("EXAMPLE_2_SRPC_SERVER_HOST")?, + std::env::var("EXAMPLE_2_SRPC_SERVER_PORT")?.parse()?, + &std::env::var("EXAMPLE_2_SRPC_SERVER_ENPOINT")?, + &std::env::var("EXAMPLE_2_SRPC_SERVER_CERT")?, + &std::env::var("EXAMPLE_2_SRPC_SERVER_KEY")?, ); // Connect to the server diff --git a/rust/lib/srpc/client/srpc_client.pyi b/rust/lib/srpc/client/srpc_client.pyi index 28635435..f9b6911e 100644 --- a/rust/lib/srpc/client/srpc_client.pyi +++ b/rust/lib/srpc/client/srpc_client.pyi @@ -2,7 +2,7 @@ from typing import Callable, List class SrpcClientConfig: - def __init__(self, host: str, port: int, cert: str, key: str) -> None: ... + def __init__(self, host: str, port: int, path: str, cert: str, key: str) -> None: ... async def connect(self) -> "ConnectedSrpcClient": ... class ConnectedSrpcClient: From 7f0e6bc1e63d302e4613b895f5a6b9f650ab615f Mon Sep 17 00:00:00 2001 From: Leonidas Loucas Date: Fri, 13 Dec 2024 20:25:13 -0800 Subject: [PATCH 16/27] chore: loosen bounds and bump version for api change --- rust/lib/srpc/client/Cargo.toml | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/rust/lib/srpc/client/Cargo.toml b/rust/lib/srpc/client/Cargo.toml index 9be16f08..0a3fe314 100644 --- a/rust/lib/srpc/client/Cargo.toml +++ b/rust/lib/srpc/client/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "srpc_client" -version = "0.1.1" +version = "0.2.0" edition = "2021" [lib] @@ -8,13 +8,13 @@ name = "srpc_client" crate-type = ["cdylib", "rlib"] [dependencies] -bytes = "1.0" -futures = "0.3.28" -tokio = { version = "1.0", features = ["full"] } +bytes = "1" +futures = "0.3" +tokio = { version = "1", features = ["full"] } openssl = "0.10" -serde_json = "1.0" +serde_json = "1" tokio-openssl = "0.6" -tokio-util = { version = "0.7.13", features = ["codec"] } +tokio-util = { version = "0.7", features = ["codec"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter", "fmt"] } @@ -43,4 +43,4 @@ path = "examples/rust_client_example2.rs" required-features = [] [dev-dependencies] -tokio = { version = "1.0", features = ["full", "macros"] } +tokio = { version = "1", features = ["full", "macros"] } From 4d2f7da058f84d913ab5fa995d3769861b425469 Mon Sep 17 00:00:00 2001 From: Leonidas Loucas Date: Sat, 14 Dec 2024 12:29:31 -0800 Subject: [PATCH 17/27] chore: Abstract to try to add test seam --- rust/lib/srpc/client/src/lib.rs | 29 ++++++++++++++------- rust/lib/srpc/client/src/python_bindings.rs | 4 ++- 2 files changed, 22 insertions(+), 11 deletions(-) diff --git a/rust/lib/srpc/client/src/lib.rs b/rust/lib/srpc/client/src/lib.rs index 3e516e43..a20f95e7 100644 --- a/rust/lib/srpc/client/src/lib.rs +++ b/rust/lib/srpc/client/src/lib.rs @@ -6,7 +6,7 @@ use std::error::Error; use std::fmt; use std::pin::Pin; use std::sync::Arc; -use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader}; use tokio::net::TcpStream; use tokio::sync::{mpsc, Mutex}; use tokio::time::{timeout, Duration}; @@ -60,12 +60,14 @@ impl Default for ReceiveOptions { } } -pub struct ConnectedClient { +pub struct ConnectedClient +where + T: AsyncRead + AsyncWrite + Unpin + Send + 'static, +{ pub connection_params: ClientConfig, - stream: Connected, + stream: Arc>, } -type Connected = Arc>>; impl ClientConfig { pub fn new(host: &str, port: u16, path: &str, cert: &str, key: &str) -> Self { ClientConfig { @@ -77,7 +79,7 @@ impl ClientConfig { } } - pub async fn connect(self) -> Result> { + pub async fn connect(self) -> Result>, Box> { debug!("Attempting to connect to {}:{}...", self.host, self.port); let connect_timeout = Duration::from_secs(10); @@ -115,10 +117,7 @@ impl ClientConfig { debug!("Connection fully established"); - Ok(ConnectedClient { - connection_params: self, - stream: Arc::new(Mutex::new(stream)), - }) + Ok(ConnectedClient::new(self, stream)) } async fn do_http_connect(&self, stream: &TcpStream) -> Result<(), Box> { @@ -166,7 +165,17 @@ impl ClientConfig { } } -impl ConnectedClient { +impl ConnectedClient +where + T: AsyncRead + AsyncWrite + Unpin + Send + 'static, +{ + pub fn new(connection_params: ClientConfig, stream: T) -> Self { + ConnectedClient { + connection_params, + stream: Arc::new(Mutex::new(stream)), + } + } + pub async fn send_message(&self, message: &str) -> Result<(), Box> { let stream = self.stream.lock().await; let mut pinned = Pin::new(stream); diff --git a/rust/lib/srpc/client/src/python_bindings.rs b/rust/lib/srpc/client/src/python_bindings.rs index c03e78a5..666e4ff1 100644 --- a/rust/lib/srpc/client/src/python_bindings.rs +++ b/rust/lib/srpc/client/src/python_bindings.rs @@ -8,13 +8,15 @@ use std::{ sync::Arc, task::{Context, Poll}, }; +use tokio::net::TcpStream; use tokio::sync::{mpsc, Mutex}; +use tokio_openssl::SslStream; #[pyclass] pub struct SrpcClientConfig(ClientConfig); #[pyclass] -pub struct ConnectedSrpcClient(Arc>); +pub struct ConnectedSrpcClient(Arc>>>); #[pymethods] impl SrpcClientConfig { From b189c34e57a1c354db6f3e61550e116e25221e43 Mon Sep 17 00:00:00 2001 From: Leonidas Loucas Date: Sat, 14 Dec 2024 13:12:51 -0800 Subject: [PATCH 18/27] test: Add first test, and timeout so we dont hang on bad test --- rust/lib/srpc/client/Cargo.toml | 4 +- rust/lib/srpc/client/src/lib.rs | 16 +++- rust/lib/srpc/client/src/tests.rs | 1 + rust/lib/srpc/client/src/tests/lib.rs | 113 ++++++++++++++++++++++++++ 4 files changed, 130 insertions(+), 4 deletions(-) create mode 100644 rust/lib/srpc/client/src/tests.rs create mode 100644 rust/lib/srpc/client/src/tests/lib.rs diff --git a/rust/lib/srpc/client/Cargo.toml b/rust/lib/srpc/client/Cargo.toml index 0a3fe314..7d4a2e53 100644 --- a/rust/lib/srpc/client/Cargo.toml +++ b/rust/lib/srpc/client/Cargo.toml @@ -43,4 +43,6 @@ path = "examples/rust_client_example2.rs" required-features = [] [dev-dependencies] -tokio = { version = "1", features = ["full", "macros"] } +rstest = "0.23.0" +test-log = { version = "0.2.16", features = ["trace", "color"] } +tokio = { version = "1", features = ["full", "macros", "test-util"] } diff --git a/rust/lib/srpc/client/src/lib.rs b/rust/lib/srpc/client/src/lib.rs index a20f95e7..717259a3 100644 --- a/rust/lib/srpc/client/src/lib.rs +++ b/rust/lib/srpc/client/src/lib.rs @@ -6,7 +6,7 @@ use std::error::Error; use std::fmt; use std::pin::Pin; use std::sync::Arc; -use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader}; +use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader}; use tokio::net::TcpStream; use tokio::sync::{mpsc, Mutex}; use tokio::time::{timeout, Duration}; @@ -15,6 +15,8 @@ use tokio_util::codec::{FramedRead, LinesCodec}; use tracing::debug; mod chunk_limiter; +#[cfg(test)] +mod tests; // Custom error type #[derive(Debug)] @@ -40,13 +42,19 @@ pub struct ClientConfig { pub struct ReceiveOptions { channel_buffer_size: usize, max_chunk_size: usize, + read_next_line_duration: Duration, } impl ReceiveOptions { - pub fn new(channel_buffer_size: usize, max_chunk_size: usize) -> Self { + pub fn new( + channel_buffer_size: usize, + max_chunk_size: usize, + read_next_line_duration: Duration, + ) -> Self { ReceiveOptions { channel_buffer_size, max_chunk_size, + read_next_line_duration, } } } @@ -56,6 +64,7 @@ impl Default for ReceiveOptions { ReceiveOptions { channel_buffer_size: 100, max_chunk_size: 16384, + read_next_line_duration: Duration::from_secs(10), } } } @@ -251,6 +260,7 @@ where let stream = Arc::clone(&self.stream); let (tx, rx) = mpsc::channel(opts.channel_buffer_size); let max_chunk_size = opts.max_chunk_size; + let read_next_line_duration = opts.read_next_line_duration; tokio::spawn(async move { let mut guard = stream.lock().await; @@ -258,7 +268,7 @@ where let buf_reader = BufReader::new(limited_reader); let mut framed = FramedRead::new(buf_reader, LinesCodec::new()); - while let Some(line_res) = framed.next().await { + while let Ok(Some(line_res)) = timeout(read_next_line_duration, framed.next()).await { let line_res = line_res.map_err(|e| Box::new(e) as Box); match line_res { diff --git a/rust/lib/srpc/client/src/tests.rs b/rust/lib/srpc/client/src/tests.rs new file mode 100644 index 00000000..e629be4c --- /dev/null +++ b/rust/lib/srpc/client/src/tests.rs @@ -0,0 +1 @@ +mod lib; diff --git a/rust/lib/srpc/client/src/tests/lib.rs b/rust/lib/srpc/client/src/tests/lib.rs new file mode 100644 index 00000000..5be614ec --- /dev/null +++ b/rust/lib/srpc/client/src/tests/lib.rs @@ -0,0 +1,113 @@ +use std::{error::Error, num::NonZeroU8}; + +use crate::{ClientConfig, ConnectedClient, ReceiveOptions}; + +use rstest::rstest; +use tokio::{ + io::{duplex, AsyncReadExt, AsyncWriteExt, DuplexStream}, + sync::mpsc, +}; + +fn setup_test_client() -> (ConnectedClient, DuplexStream) { + let (client_stream, server_stream) = duplex(1024); + + let config = ClientConfig::new("example.com", 443, "/", "", ""); + (ConnectedClient::new(config, client_stream), server_stream) +} + +fn n_message(num: NonZeroU8) -> impl FnMut(&str) -> bool { + let mut seen = 0; + move |_msg: &str| { + if seen + 1 == num.get() { + false + } else { + seen += 1; + true + } + } +} + +fn one_message() -> impl FnMut(&str) -> bool { + n_message(NonZeroU8::new(1).unwrap()) +} + +async fn check_message( + server_message: &str, + rx: &mut mpsc::Receiver>>, +) { + if let Some(Ok(received_msg)) = rx.recv().await { + assert_eq!(received_msg, server_message.trim()); + } else { + panic!("Did not receive expected message from server"); + } +} + +async fn check_server( + client_message: &str, + server_stream: &mut DuplexStream, +) -> Result<(), Box> { + let mut server_buf = vec![0u8; client_message.len()]; + server_stream.read_exact(&mut server_buf).await?; + assert_eq!(&server_buf, client_message.as_bytes()); + Ok(()) +} + +#[test_log::test(rstest)] +#[tokio::test(start_paused = true)] +async fn test_connected_client_send_and_receive() -> Result<(), Box> { + let (connected_client, mut server_stream) = setup_test_client(); + + let client_message = "Hello from client\n"; + connected_client.send_message(client_message).await?; + + check_server(client_message, &mut server_stream).await?; + + let server_message = "Hello from server\n"; + server_stream.write_all(server_message.as_bytes()).await?; + + let should_continue = one_message(); + + let opts = ReceiveOptions::default(); + let mut rx = connected_client + .receive_message(false, should_continue, &opts) + .await?; + + check_message(server_message, &mut rx).await; + + Ok(()) +} + +#[test_log::test(rstest)] +#[tokio::test(start_paused = true)] +async fn test_connected_client_send_and_receive_stream() -> Result<(), Box> { + let (connected_client, mut server_stream) = setup_test_client(); + + let client_message = "Hello from client\n"; + connected_client.send_message(client_message).await?; + + check_server(client_message, &mut server_stream).await?; + + server_stream.write_all("\n".as_bytes()).await?; + + let should_continue = one_message(); + + let opts = ReceiveOptions::default(); + let mut rx = connected_client + .receive_message(true, should_continue, &opts) + .await?; + + check_message("", &mut rx).await; + + server_stream.write_all("first\n".as_bytes()).await?; + + server_stream.write_all("second\n".as_bytes()).await?; + + let should_continue = n_message(NonZeroU8::new(2).unwrap()); + let mut rx = connected_client + .receive_message(false, should_continue, &opts) + .await?; + + check_message("first", &mut rx).await; + check_message("second", &mut rx).await; + Ok(()) +} From d6df3890c06cf326dc394400bd6ca83e9a61bcb2 Mon Sep 17 00:00:00 2001 From: Leonidas Loucas Date: Sat, 14 Dec 2024 13:38:11 -0800 Subject: [PATCH 19/27] chore: Remove unused old receive_message --- rust/lib/srpc/client/src/lib.rs | 55 --------------------------------- 1 file changed, 55 deletions(-) diff --git a/rust/lib/srpc/client/src/lib.rs b/rust/lib/srpc/client/src/lib.rs index 717259a3..5f558ba2 100644 --- a/rust/lib/srpc/client/src/lib.rs +++ b/rust/lib/srpc/client/src/lib.rs @@ -193,61 +193,6 @@ where Ok(()) } - pub async fn receive_message_old( - &self, - expect_empty: bool, - mut should_continue: F, - ) -> Result>>, Box> - where - F: FnMut(&str) -> bool + Send + 'static, - { - let stream_clone = self.stream.clone(); - let (tx, rx) = mpsc::channel(100); - - tokio::spawn(async move { - loop { - let mut stream = stream_clone.lock().await; - let mut response = String::new(); - loop { - let mut buf = [0; 1024]; - match stream.read(&mut buf).await { - Ok(n) => { - let res = String::from_utf8_lossy(&buf[..n]); - response.push_str(&res); - debug!("ResponseT: {:?}", res); - if response.ends_with('\n') { - break; - } - } - Err(e) => { - let _ = tx.send(Err(Box::new(e) as Box)).await; - return; - } - } - } - let response = response.trim().to_string(); - - if expect_empty && !response.is_empty() { - let _ = tx - .send(Err(Box::new(CustomError(format!( - "Expected empty string, got: {:?}", - response - ))) as Box)) - .await; - return; - } - - let _ = tx.send(Ok(response.clone())).await; - - if !should_continue(&response) { - break; - } - } - }); - - Ok(rx) - } - pub async fn receive_message( &self, expect_empty: bool, From a8706bc42db6d2046e9a36959f0c3093ea1218fb Mon Sep 17 00:00:00 2001 From: Leonidas Loucas Date: Sat, 14 Dec 2024 15:49:15 -0800 Subject: [PATCH 20/27] chore: update pyo3 --- rust/lib/srpc/client/Cargo.toml | 10 ++-- rust/lib/srpc/client/src/lib.rs | 2 +- rust/lib/srpc/client/src/python_bindings.rs | 63 ++++++++++++--------- 3 files changed, 41 insertions(+), 34 deletions(-) diff --git a/rust/lib/srpc/client/Cargo.toml b/rust/lib/srpc/client/Cargo.toml index 7d4a2e53..8fef76e2 100644 --- a/rust/lib/srpc/client/Cargo.toml +++ b/rust/lib/srpc/client/Cargo.toml @@ -19,18 +19,18 @@ tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter", "fmt"] } [dependencies.pyo3] -version = "0.18" +version = "0.23" features = ["extension-module"] optional = true -[dependencies.pyo3-asyncio] -version = "0.18" -features = ["tokio-runtime"] +[dependencies.pyo3-async-runtimes] +version = "0.23" +features = ["attributes", "tokio-runtime"] optional = true [features] default = [] -python = ["pyo3", "pyo3-asyncio"] +python = ["pyo3", "pyo3-async-runtimes"] [[example]] name = "rust_client_example" diff --git a/rust/lib/srpc/client/src/lib.rs b/rust/lib/srpc/client/src/lib.rs index 5f558ba2..ca6f5a6d 100644 --- a/rust/lib/srpc/client/src/lib.rs +++ b/rust/lib/srpc/client/src/lib.rs @@ -292,7 +292,7 @@ use pyo3::prelude::*; #[cfg(feature = "python")] #[pymodule] -fn srpc_client(_py: Python, m: &PyModule) -> PyResult<()> { +fn srpc_client(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> { use tracing::level_filters::LevelFilter; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; diff --git a/rust/lib/srpc/client/src/python_bindings.rs b/rust/lib/srpc/client/src/python_bindings.rs index 666e4ff1..273d7b06 100644 --- a/rust/lib/srpc/client/src/python_bindings.rs +++ b/rust/lib/srpc/client/src/python_bindings.rs @@ -2,6 +2,7 @@ use crate::{ClientConfig, ConnectedClient, ReceiveOptions}; use futures::{Stream, StreamExt}; use pyo3::exceptions::{PyRuntimeError, PyStopAsyncIteration}; use pyo3::prelude::*; +use pyo3::types::PyFunction; use serde_json::Value; use std::{ pin::Pin, @@ -25,9 +26,9 @@ impl SrpcClientConfig { SrpcClientConfig(ClientConfig::new(host, port, path, cert, key)) } - pub fn connect<'p>(&self, py: Python<'p>) -> PyResult<&'p PyAny> { + pub fn connect<'p>(&self, py: Python<'p>) -> PyResult> { let client = self.0.clone(); - pyo3_asyncio::tokio::future_into_py(py, async { + pyo3_async_runtimes::tokio::future_into_py(py, async move { client .connect() .await @@ -76,7 +77,7 @@ impl PyStream { fn __anext__(&self, py: Python) -> PyResult> { let streamer = self.streamer.clone(); - let future = pyo3_asyncio::tokio::future_into_py(py, async move { + let future = pyo3_async_runtimes::tokio::future_into_py(py, async move { let val = streamer.lock().await.next().await; match val { Some(Ok(val)) => Ok(val), @@ -129,7 +130,7 @@ impl PyValueStream { fn __anext__(&self, py: Python) -> PyResult> { let streamer = self.streamer.clone(); - let future = pyo3_asyncio::tokio::future_into_py(py, async move { + let future = pyo3_async_runtimes::tokio::future_into_py(py, async move { let val = streamer.lock().await.next().await; match val { Some(Ok(val)) => Ok(val.to_string()), @@ -143,9 +144,13 @@ impl PyValueStream { #[pymethods] impl ConnectedSrpcClient { - pub fn send_message<'p>(&self, py: Python<'p>, message: String) -> PyResult<&'p PyAny> { + pub fn send_message<'p>( + &'p self, + py: Python<'p>, + message: String, + ) -> PyResult> { let client = self.0.clone(); - pyo3_asyncio::tokio::future_into_py(py, async move { + pyo3_async_runtimes::tokio::future_into_py(py, async move { client .lock() .await @@ -160,9 +165,9 @@ impl ConnectedSrpcClient { py: Python<'p>, expect_empty: bool, should_continue: bool, - ) -> PyResult<&'p PyAny> { + ) -> PyResult> { let client = self.0.clone(); - pyo3_asyncio::tokio::future_into_py(py, async move { + pyo3_async_runtimes::tokio::future_into_py(py, async move { let rx = client .lock() .await @@ -182,18 +187,17 @@ impl ConnectedSrpcClient { &self, py: Python<'p>, expect_empty: bool, - should_continue: &PyAny, - ) -> PyResult<&'p PyAny> { + should_continue: Py, + ) -> PyResult> { let client = self.0.clone(); - let should_continue = should_continue.to_object(py); + let should_continue = should_continue.clone_ref(py); - pyo3_asyncio::tokio::future_into_py(py, async move { + pyo3_async_runtimes::tokio::future_into_py(py, async move { let should_continue = move |response: &str| -> bool { Python::with_gil(|py| { - let func = should_continue.as_ref(py); - - func.call1((response,)) - .and_then(|v| v.extract::()) + should_continue + .call1(py, (response,)) + .and_then(|v| v.extract::(py)) .unwrap_or(false) }) }; @@ -208,9 +212,9 @@ impl ConnectedSrpcClient { }) } - pub fn send_json<'p>(&self, py: Python<'p>, payload: String) -> PyResult<&'p PyAny> { + pub fn send_json<'p>(&self, py: Python<'p>, payload: String) -> PyResult> { let client = self.0.clone(); - pyo3_asyncio::tokio::future_into_py(py, async move { + pyo3_async_runtimes::tokio::future_into_py(py, async move { let value: Value = serde_json::from_str(&payload) .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; client @@ -222,9 +226,13 @@ impl ConnectedSrpcClient { }) } - pub fn receive_json<'p>(&self, py: Python<'p>, should_continue: bool) -> PyResult<&'p PyAny> { + pub fn receive_json<'p>( + &self, + py: Python<'p>, + should_continue: bool, + ) -> PyResult> { let client = self.0.clone(); - pyo3_asyncio::tokio::future_into_py(py, async move { + pyo3_async_runtimes::tokio::future_into_py(py, async move { let rx = client .lock() .await @@ -241,18 +249,17 @@ impl ConnectedSrpcClient { pub fn receive_json_cb<'p>( &self, py: Python<'p>, - should_continue: &PyAny, - ) -> PyResult<&'p PyAny> { + should_continue: Py, + ) -> PyResult> { let client = self.0.clone(); - let should_continue = should_continue.to_object(py); + let should_continue = should_continue.clone_ref(py); - pyo3_asyncio::tokio::future_into_py(py, async move { + pyo3_async_runtimes::tokio::future_into_py(py, async move { let should_continue = move |response: &str| -> bool { Python::with_gil(|py| { - let func = should_continue.as_ref(py); - - func.call1((response,)) - .and_then(|v| v.extract::()) + should_continue + .call1(py, (response,)) + .and_then(|v| v.extract::(py)) .unwrap_or(false) }) }; From 380927f7af82f3f612609841d99ab75a05519fac Mon Sep 17 00:00:00 2001 From: Leonidas Loucas Date: Sun, 15 Dec 2024 13:53:54 -0800 Subject: [PATCH 21/27] fix: restore streaming with timeout --- rust/lib/srpc/client/src/lib.rs | 63 ++++++++++++++------- rust/lib/srpc/client/src/python_bindings.rs | 4 +- 2 files changed, 44 insertions(+), 23 deletions(-) diff --git a/rust/lib/srpc/client/src/lib.rs b/rust/lib/srpc/client/src/lib.rs index ca6f5a6d..6ba3ce68 100644 --- a/rust/lib/srpc/client/src/lib.rs +++ b/rust/lib/srpc/client/src/lib.rs @@ -39,10 +39,12 @@ pub struct ClientConfig { key: String, } +#[cfg_attr(feature = "python", derive(FromPyObject))] pub struct ReceiveOptions { channel_buffer_size: usize, max_chunk_size: usize, read_next_line_duration: Duration, + should_continue_on_timeout: bool, } impl ReceiveOptions { @@ -50,11 +52,13 @@ impl ReceiveOptions { channel_buffer_size: usize, max_chunk_size: usize, read_next_line_duration: Duration, + should_continue_on_timeout: bool, ) -> Self { ReceiveOptions { channel_buffer_size, max_chunk_size, read_next_line_duration, + should_continue_on_timeout, } } } @@ -65,6 +69,7 @@ impl Default for ReceiveOptions { channel_buffer_size: 100, max_chunk_size: 16384, read_next_line_duration: Duration::from_secs(10), + should_continue_on_timeout: true, } } } @@ -206,6 +211,7 @@ where let (tx, rx) = mpsc::channel(opts.channel_buffer_size); let max_chunk_size = opts.max_chunk_size; let read_next_line_duration = opts.read_next_line_duration; + let should_continue_on_timeout = opts.should_continue_on_timeout; tokio::spawn(async move { let mut guard = stream.lock().await; @@ -213,32 +219,45 @@ where let buf_reader = BufReader::new(limited_reader); let mut framed = FramedRead::new(buf_reader, LinesCodec::new()); - while let Ok(Some(line_res)) = timeout(read_next_line_duration, framed.next()).await { - let line_res = line_res.map_err(|e| Box::new(e) as Box); - - match line_res { - Ok(line) => { - if expect_empty && !line.is_empty() { - let _ = tx - .send(Err(Box::new(CustomError(format!( - "Expected empty line, got: {:?}", - line - ))) - as Box)) - .await; - break; - } - - let _ = tx.send(Ok(line.clone())).await; - - if !should_continue(&line) { - break; + loop { + let result = timeout(read_next_line_duration, framed.next()).await; + match result { + Ok(Some(line_res)) => { + let line_res = line_res.map_err(|e| Box::new(e) as Box); + + match line_res { + Ok(line) => { + if expect_empty && !line.is_empty() { + let _ = tx + .send(Err(Box::new(CustomError(format!( + "Expected empty line, got: {:?}", + line + ))) + as Box)) + .await; + break; + } + + let _ = tx.send(Ok(line.clone())).await; + + if !should_continue(&line) { + break; + } + } + Err(err) => { + let _ = tx.send(Err(err)).await; + break; + } } } - Err(err) => { - let _ = tx.send(Err(err)).await; + Ok(None) => { break; } + Err(_) => { + if !should_continue_on_timeout { + break; + } + } } } }); diff --git a/rust/lib/srpc/client/src/python_bindings.rs b/rust/lib/srpc/client/src/python_bindings.rs index 273d7b06..8eb389b1 100644 --- a/rust/lib/srpc/client/src/python_bindings.rs +++ b/rust/lib/srpc/client/src/python_bindings.rs @@ -246,10 +246,12 @@ impl ConnectedSrpcClient { }) } + #[pyo3(signature = (should_continue, opts=None))] pub fn receive_json_cb<'p>( &self, py: Python<'p>, should_continue: Py, + opts: Option, ) -> PyResult> { let client = self.0.clone(); let should_continue = should_continue.clone_ref(py); @@ -266,7 +268,7 @@ impl ConnectedSrpcClient { let rx = client .lock() .await - .receive_json(should_continue, &ReceiveOptions::default()) + .receive_json(should_continue, &opts.unwrap_or_default()) .await .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; From 3830d54866d19f5d680af87ee0db43a00af84459 Mon Sep 17 00:00:00 2001 From: Leonidas Loucas Date: Sun, 15 Dec 2024 17:19:37 -0800 Subject: [PATCH 22/27] fix: Dont overread the buffer when expecting empty --- rust/lib/srpc/client/src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rust/lib/srpc/client/src/lib.rs b/rust/lib/srpc/client/src/lib.rs index 6ba3ce68..6ae5a5bc 100644 --- a/rust/lib/srpc/client/src/lib.rs +++ b/rust/lib/srpc/client/src/lib.rs @@ -209,7 +209,7 @@ where { let stream = Arc::clone(&self.stream); let (tx, rx) = mpsc::channel(opts.channel_buffer_size); - let max_chunk_size = opts.max_chunk_size; + let max_chunk_size = if expect_empty { 1 } else { opts.max_chunk_size }; let read_next_line_duration = opts.read_next_line_duration; let should_continue_on_timeout = opts.should_continue_on_timeout; From 9416beb004f212542d6cc1ee58ce60a0eecdcc5d Mon Sep 17 00:00:00 2001 From: Leonidas Loucas Date: Sun, 15 Dec 2024 17:21:06 -0800 Subject: [PATCH 23/27] feat: Add helpers for reading back empty ack --- rust/lib/srpc/client/src/lib.rs | 46 +++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/rust/lib/srpc/client/src/lib.rs b/rust/lib/srpc/client/src/lib.rs index 6ae5a5bc..1f34ef1e 100644 --- a/rust/lib/srpc/client/src/lib.rs +++ b/rust/lib/srpc/client/src/lib.rs @@ -198,6 +198,29 @@ where Ok(()) } + pub async fn send_message_and_check(&self, message: &str) -> Result<(), Box> { + self.send_message(message) + .await + .map_err(|e| Box::new(CustomError(e.to_string())) as Box)?; + let mut rx = self + .receive_message(true, |_| false, &ReceiveOptions::default()) + .await + .map_err(|e| Box::new(CustomError(e.to_string())) as Box)?; + if rx + .recv() + .await + .ok_or_else(|| { + Box::new(CustomError("Expected response".to_string())) as Box + })?? + .is_empty() + { + } else { + return Err(Box::new(CustomError("Expected empty line".to_string()))); + } + + Ok(()) + } + pub async fn receive_message( &self, expect_empty: bool, @@ -270,6 +293,29 @@ where self.send_message(&json_string).await } + pub async fn send_json_and_check(&self, payload: &Value) -> Result<(), Box> { + self.send_json(payload) + .await + .map_err(|e| Box::new(CustomError(e.to_string())) as Box)?; + let mut rx = self + .receive_message(true, |_| false, &ReceiveOptions::default()) + .await + .map_err(|e| Box::new(CustomError(e.to_string())) as Box)?; + if rx + .recv() + .await + .ok_or_else(|| { + Box::new(CustomError("Expected response".to_string())) as Box + })?? + .is_empty() + { + } else { + return Err(Box::new(CustomError("Expected empty line".to_string()))); + } + + Ok(()) + } + pub async fn receive_json( &self, should_continue: F, From d85cc5a6b01066cc92db53e9052c84a9f7e44769 Mon Sep 17 00:00:00 2001 From: Leonidas Loucas Date: Sun, 15 Dec 2024 17:44:03 -0800 Subject: [PATCH 24/27] feat: Add rudimentary request_reply api and example --- rust/lib/srpc/client/Cargo.toml | 2 + .../client/examples/python_client_example3.py | 89 +++++++++++++ .../client/examples/rust_client_example3.rs | 120 ++++++++++++++++++ rust/lib/srpc/client/src/lib.rs | 97 +++++++++++++- rust/lib/srpc/client/src/python_bindings.rs | 23 +++- rust/lib/srpc/client/srpc_client.pyi | 2 + 6 files changed, 331 insertions(+), 2 deletions(-) create mode 100644 rust/lib/srpc/client/examples/python_client_example3.py create mode 100644 rust/lib/srpc/client/examples/rust_client_example3.rs diff --git a/rust/lib/srpc/client/Cargo.toml b/rust/lib/srpc/client/Cargo.toml index 8fef76e2..37a61b41 100644 --- a/rust/lib/srpc/client/Cargo.toml +++ b/rust/lib/srpc/client/Cargo.toml @@ -8,10 +8,12 @@ name = "srpc_client" crate-type = ["cdylib", "rlib"] [dependencies] +async-trait = "0.1" bytes = "1" futures = "0.3" tokio = { version = "1", features = ["full"] } openssl = "0.10" +serde = {version = "1", features = ["derive"]} serde_json = "1" tokio-openssl = "0.6" tokio-util = { version = "0.7", features = ["codec"] } diff --git a/rust/lib/srpc/client/examples/python_client_example3.py b/rust/lib/srpc/client/examples/python_client_example3.py new file mode 100644 index 00000000..8d771901 --- /dev/null +++ b/rust/lib/srpc/client/examples/python_client_example3.py @@ -0,0 +1,89 @@ +""" +This example demonstrates how to use the srpc_client Python bindings. + +To run this example: +1. Build the Rust library: maturin build --features python +2. Install the wheel: pip install target/wheels/srpc_client-*.whl +3. Run this script: python examples/python_client_example.py +""" + +import asyncio +import json +import os +from srpc_client import SrpcClientConfig + + +async def main(): + print("Starting client..") + + # Create a new ClientConfig instance + client = SrpcClientConfig( + os.environ["EXAMPLE_3_SRPC_SERVER_HOST"], + int(os.environ["EXAMPLE_3_SRPC_SERVER_PORT"]), + os.environ["EXAMPLE_3_SRPC_SERVER_ENPOINT"], + os.environ["EXAMPLE_3_SRPC_SERVER_CERT"], + os.environ["EXAMPLE_3_SRPC_SERVER_KEY"], + ) + + # Connect to the server + client = await client.connect() + print("Connected to server") + + message = "Hypervisor.ListVMs\n" + + # Send a message + print(f"Sending message: {message}") + await client.send_message(message) + print(f"Sent message: {message}") + + # Receive an empty response + print("Waiting for empty string response...") + responses = await client.receive_message(expect_empty=True, should_continue=False) + async for response in responses: + print(f"Received response: {response}") + + # Send a JSON message + payload = json.dumps( + { + "IgnoreStateMask": 0, + "OwnerGroups": [], + "OwnerUsers": [], + "Sort": True, + "VmTagsToMatch": {}, + } + ) + print(f"Sending payload: {payload}") + await client.send_json(payload) + print(f"Sent payload: {payload}") + + # Receive an empty response + print("Waiting for empty string response for payload...") + responses = await client.receive_message(expect_empty=True, should_continue=False) + async for response in responses: + print(f"Received response: {response}") + + # Receive responses + print("Waiting for response...") + responses = await client.receive_json_cb(should_continue=lambda _: False) + async for response in responses: + print(f"Received response: {json.loads(response)}") + + # Use RequestReply + print(f"Sending request_reply: {message}") + res = await client.request_reply( + message, + json.dumps( + { + "IgnoreStateMask": 0, + "OwnerGroups": [], + "OwnerUsers": [], + "Sort": True, + "VmTagsToMatch": {}, + } + ), + ) + print(f"Sent request_reply: {message}, got reply: {res}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/rust/lib/srpc/client/examples/rust_client_example3.rs b/rust/lib/srpc/client/examples/rust_client_example3.rs new file mode 100644 index 00000000..5e6dc594 --- /dev/null +++ b/rust/lib/srpc/client/examples/rust_client_example3.rs @@ -0,0 +1,120 @@ +use std::{collections::HashMap, error::Error}; + +use srpc_client::{ClientConfig, CustomError, ReceiveOptions, SimpleValue}; +use tracing::{error, info, level_filters::LevelFilter}; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; + +#[tokio::main] +async fn main() -> Result<(), Box> { + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::builder() + .with_default_directive(LevelFilter::INFO.into()) + .from_env_lossy(), + ) + .with(tracing_subscriber::fmt::Layer::default().compact()) + .init(); + + info!("Starting client..."); + + // Create a new ClientConfig instance + let config = ClientConfig::new( + &std::env::var("EXAMPLE_3_SRPC_SERVER_HOST")?, + std::env::var("EXAMPLE_3_SRPC_SERVER_PORT")?.parse()?, + &std::env::var("EXAMPLE_3_SRPC_SERVER_ENPOINT")?, + &std::env::var("EXAMPLE_3_SRPC_SERVER_CERT")?, + &std::env::var("EXAMPLE_3_SRPC_SERVER_KEY")?, + ); + + // Connect to the server + let client = config.connect().await?; + info!("Connected to server"); + + let message = "Hypervisor.ListVMs\n"; + + // Send a message + info!("Sending message: {:?}", message); + client.send_message(message).await?; + info!("Sent message: {:?}", message); + + // Receive an empty response + info!("Waiting for empty string response..."); + let mut rx = client + .receive_message(true, |_| false, &ReceiveOptions::default()) + .await?; + while let Some(result) = rx.recv().await { + match result { + Ok(response) => info!("Received response: {:?}", response), + Err(e) => error!("Error receiving message: {:?}", e), + } + } + + #[derive(Debug, serde::Serialize)] + struct ListVMsRequest { + ignore_state_mask: u32, + owner_groups: Vec, + owner_users: Vec, + sort: bool, + vm_tags_to_match: HashMap, + } + + #[derive(Debug, serde::Deserialize)] + struct ListVMsResponse { + ip_addresses: Vec, + } + + let request = ListVMsRequest { + ignore_state_mask: 0, + owner_groups: vec![], + owner_users: vec![], + sort: false, + vm_tags_to_match: HashMap::new(), + }; + + // Send a JSON message + info!("Sending payload: {:?}", request); + client.send_json(&serde_json::to_value(&request)?).await?; + info!("Sent payload: {:?}", request); + + // Receive an empty response + info!("Waiting for empty string response for payload..."); + let mut rx = client + .receive_message(true, |_| false, &ReceiveOptions::default()) + .await?; + while let Some(result) = rx.recv().await { + match result { + Ok(response) => info!("Received response: {:?}", response), + Err(e) => error!("Error receiving message: {:?}", e), + } + } + + // Receive responses + let mut rx = client + .receive_json(|_| false, &ReceiveOptions::default()) + .await?; + while let Some(result) = rx.recv().await { + match result + .and_then(|response| { + serde_json::from_value::(response) + .map_err(|e| Box::new(CustomError(e.to_string())) as Box) + }) + .map_err(|e| Box::new(CustomError(e.to_string())) as Box) + { + Ok(response) => info!("Received response: {:?}", response), + Err(e) => error!("Error receiving message: {:?}", e), + } + } + + info!("Sending request_reply: {}", message); + let res = client + .request_reply::(message, serde_json::to_value(&request)?) + .await + .map_err(|e| Box::new(CustomError(e.to_string())) as Box)?; + info!( + "Sent request_reply: {}, got reply: {:?}", + message, + serde_json::to_string(&res)? + ); + + Ok(()) +} diff --git a/rust/lib/srpc/client/src/lib.rs b/rust/lib/srpc/client/src/lib.rs index 1f34ef1e..37f52a08 100644 --- a/rust/lib/srpc/client/src/lib.rs +++ b/rust/lib/srpc/client/src/lib.rs @@ -1,3 +1,4 @@ +use async_trait::async_trait; use chunk_limiter::ChunkLimiter; use futures::StreamExt; use openssl::ssl::{Ssl, SslConnector, SslMethod, SslVerifyMode}; @@ -20,7 +21,7 @@ mod tests; // Custom error type #[derive(Debug)] -struct CustomError(String); +pub struct CustomError(pub String); impl fmt::Display for CustomError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { @@ -82,6 +83,70 @@ where stream: Arc>, } +#[async_trait] +pub trait RequestReply { + type Request; + type Reply; + + async fn request_reply( + client: &ConnectedClient, + payload: Self::Request, + ) -> Result> + where + T: AsyncRead + AsyncWrite + Unpin + Send + 'static; +} + +pub struct SimpleValue; + +#[async_trait] +impl RequestReply for SimpleValue { + type Request = serde_json::Value; + type Reply = serde_json::Value; + + async fn request_reply( + client: &ConnectedClient, + payload: Self::Request, + ) -> Result> + where + T: AsyncRead + AsyncWrite + Unpin + Send + 'static, + { + client.send_json_and_check(&payload).await?; + + let mut rx = client + .receive_json(|_| false, &ReceiveOptions::default()) + .await + .map_err(|e| Box::new(CustomError(e.to_string())) as Box)?; + let json_value = rx.recv().await.ok_or_else(|| { + Box::new(CustomError("Expected JSON value".to_string())) as Box + })??; + Ok(json_value) + } +} + +pub struct StreamValue; + +#[async_trait] +impl RequestReply for StreamValue { + type Request = serde_json::Value; + type Reply = mpsc::Receiver>>; + + async fn request_reply( + client: &ConnectedClient, + payload: Self::Request, + ) -> Result> + where + T: AsyncRead + AsyncWrite + Unpin + Send + 'static, + { + client.send_json_and_check(&payload).await?; + + let rx = client + .receive_json(|_| true, &ReceiveOptions::default()) + .await + .map_err(|e| Box::new(CustomError(e.to_string())) as Box)?; + Ok(rx) + } +} + impl ClientConfig { pub fn new(host: &str, port: u16, path: &str, cert: &str, key: &str) -> Self { ClientConfig { @@ -190,6 +255,36 @@ where } } + pub async fn request_reply( + &self, + method: &str, + payload: R::Request, + ) -> Result> + where + R: RequestReply, + { + self.send_message(method) + .await + .map_err(|e| Box::new(CustomError(e.to_string())) as Box)?; + let mut rx = self + .receive_message(true, |_| false, &ReceiveOptions::default()) + .await + .map_err(|e| Box::new(CustomError(e.to_string())) as Box)?; + if rx + .recv() + .await + .ok_or_else(|| { + Box::new(CustomError("Expected response".to_string())) as Box + })?? + .is_empty() + { + } else { + return Err(Box::new(CustomError("Expected empty line".to_string()))); + } + + R::request_reply(self, payload).await + } + pub async fn send_message(&self, message: &str) -> Result<(), Box> { let stream = self.stream.lock().await; let mut pinned = Pin::new(stream); diff --git a/rust/lib/srpc/client/src/python_bindings.rs b/rust/lib/srpc/client/src/python_bindings.rs index 8eb389b1..6235f99d 100644 --- a/rust/lib/srpc/client/src/python_bindings.rs +++ b/rust/lib/srpc/client/src/python_bindings.rs @@ -1,4 +1,4 @@ -use crate::{ClientConfig, ConnectedClient, ReceiveOptions}; +use crate::{ClientConfig, ConnectedClient, ReceiveOptions, SimpleValue}; use futures::{Stream, StreamExt}; use pyo3::exceptions::{PyRuntimeError, PyStopAsyncIteration}; use pyo3::prelude::*; @@ -277,4 +277,25 @@ impl ConnectedSrpcClient { })) }) } + + pub fn request_reply<'p>( + &self, + py: Python<'p>, + method: String, + payload: String, + ) -> PyResult> { + let client = self.0.clone(); + + pyo3_async_runtimes::tokio::future_into_py(py, async move { + let value: Value = serde_json::from_str(&payload) + .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + let response = client + .lock() + .await + .request_reply::(&method, value) + .await + .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + Ok(response.to_string()) + }) + } } diff --git a/rust/lib/srpc/client/srpc_client.pyi b/rust/lib/srpc/client/srpc_client.pyi index f9b6911e..c6913ba6 100644 --- a/rust/lib/srpc/client/srpc_client.pyi +++ b/rust/lib/srpc/client/srpc_client.pyi @@ -1,5 +1,6 @@ from typing import Callable, List +type JsonStr = str class SrpcClientConfig: def __init__(self, host: str, port: int, path: str, cert: str, key: str) -> None: ... @@ -12,3 +13,4 @@ class ConnectedSrpcClient: async def send_json(self, payload: str) -> None: ... async def receive_json(self, should_continue: bool) -> List[str]: ... async def receive_json_cb(self, should_continue: Callable[[str], bool]) -> List[str]: ... + async def request_reply(self, message: str, payload: JsonStr) -> JsonStr: ... From 882a1529a9c4169be02cfa168601dbee9b0ed436 Mon Sep 17 00:00:00 2001 From: Leonidas Loucas Date: Sun, 15 Dec 2024 17:46:25 -0800 Subject: [PATCH 25/27] feat: Expose *_and_check methods to python --- rust/lib/srpc/client/src/python_bindings.rs | 30 +++++++++++++++++++++ rust/lib/srpc/client/srpc_client.pyi | 2 ++ 2 files changed, 32 insertions(+) diff --git a/rust/lib/srpc/client/src/python_bindings.rs b/rust/lib/srpc/client/src/python_bindings.rs index 6235f99d..17a3a214 100644 --- a/rust/lib/srpc/client/src/python_bindings.rs +++ b/rust/lib/srpc/client/src/python_bindings.rs @@ -160,6 +160,22 @@ impl ConnectedSrpcClient { }) } + pub fn send_message_and_check<'p>( + &'p self, + py: Python<'p>, + message: String, + ) -> PyResult> { + let client = self.0.clone(); + pyo3_async_runtimes::tokio::future_into_py(py, async move { + client + .lock() + .await + .send_message_and_check(&message) + .await + .map_err(|e| PyRuntimeError::new_err(e.to_string())) + }) + } + pub fn receive_message<'p>( &self, py: Python<'p>, @@ -226,6 +242,20 @@ impl ConnectedSrpcClient { }) } + pub fn send_json_and_check<'p>(&self, py: Python<'p>, payload: String) -> PyResult> { + let client = self.0.clone(); + pyo3_async_runtimes::tokio::future_into_py(py, async move { + let value: Value = serde_json::from_str(&payload) + .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + client + .lock() + .await + .send_json_and_check(&value) + .await + .map_err(|e| PyRuntimeError::new_err(e.to_string())) + }) + } + pub fn receive_json<'p>( &self, py: Python<'p>, diff --git a/rust/lib/srpc/client/srpc_client.pyi b/rust/lib/srpc/client/srpc_client.pyi index c6913ba6..33b714dd 100644 --- a/rust/lib/srpc/client/srpc_client.pyi +++ b/rust/lib/srpc/client/srpc_client.pyi @@ -8,9 +8,11 @@ class SrpcClientConfig: class ConnectedSrpcClient: async def send_message(self, message: str) -> None: ... + async def send_message_and_check(self, message: str) -> None: ... async def receive_message(self, expect_empty: bool, should_continue: bool) -> List[str]: ... async def receive_message_cb(self, expect_empty: bool, should_continue: Callable[[str], bool]) -> List[str]: ... async def send_json(self, payload: str) -> None: ... + async def send_json_and_check(self, payload: str) -> None: ... async def receive_json(self, should_continue: bool) -> List[str]: ... async def receive_json_cb(self, should_continue: Callable[[str], bool]) -> List[str]: ... async def request_reply(self, message: str, payload: JsonStr) -> JsonStr: ... From 11eef659b6410ff3b0e4a310c12ec10415352daa Mon Sep 17 00:00:00 2001 From: Leonidas Loucas Date: Sun, 15 Dec 2024 17:54:40 -0800 Subject: [PATCH 26/27] docs: Better preamble for examples --- rust/lib/srpc/client/Cargo.toml | 5 +++++ .../srpc/client/examples/python_client_example.py | 12 +++++++++--- .../srpc/client/examples/python_client_example2.py | 12 +++++++++--- .../srpc/client/examples/python_client_example3.py | 12 +++++++++--- rust/lib/srpc/client/examples/rust_client_example.rs | 9 +++++++++ .../lib/srpc/client/examples/rust_client_example2.rs | 9 +++++++++ .../lib/srpc/client/examples/rust_client_example3.rs | 9 +++++++++ 7 files changed, 59 insertions(+), 9 deletions(-) diff --git a/rust/lib/srpc/client/Cargo.toml b/rust/lib/srpc/client/Cargo.toml index 37a61b41..baddae33 100644 --- a/rust/lib/srpc/client/Cargo.toml +++ b/rust/lib/srpc/client/Cargo.toml @@ -44,6 +44,11 @@ name = "rust_client_example2" path = "examples/rust_client_example2.rs" required-features = [] +[[example]] +name = "rust_client_example3" +path = "examples/rust_client_example3.rs" +required-features = [] + [dev-dependencies] rstest = "0.23.0" test-log = { version = "0.2.16", features = ["trace", "color"] } diff --git a/rust/lib/srpc/client/examples/python_client_example.py b/rust/lib/srpc/client/examples/python_client_example.py index 1aded322..52f1fa52 100644 --- a/rust/lib/srpc/client/examples/python_client_example.py +++ b/rust/lib/srpc/client/examples/python_client_example.py @@ -2,9 +2,15 @@ This example demonstrates how to use the srpc_client Python bindings. To run this example: -1. Build the Rust library: maturin build --features python -2. Install the wheel: pip install target/wheels/srpc_client-*.whl -3. Run this script: python examples/python_client_example.py +1. Build and install the Rust python library: maturin develop --features python +3. Run this script: + RUST_LOG=trace \ + EXAMPLE_1_SRPC_SERVER_HOST= \ + EXAMPLE_1_SRPC_SERVER_PORT= \ + EXAMPLE_1_SRPC_SERVER_ENPOINT= \ + EXAMPLE_1_SRPC_SERVER_CERT= \ + EXAMPLE_1_SRPC_SERVER_KEY= \ + python examples/python_client_example.py """ import asyncio diff --git a/rust/lib/srpc/client/examples/python_client_example2.py b/rust/lib/srpc/client/examples/python_client_example2.py index 594ba011..ea271912 100644 --- a/rust/lib/srpc/client/examples/python_client_example2.py +++ b/rust/lib/srpc/client/examples/python_client_example2.py @@ -2,9 +2,15 @@ This example demonstrates how to use the srpc_client Python bindings. To run this example: -1. Build the Rust library: maturin build --features python -2. Install the wheel: pip install target/wheels/srpc_client-*.whl -3. Run this script: python examples/python_client_example.py +1. Build and install the Rust python library: maturin develop --features python +3. Run this script: + RUST_LOG=trace \ + EXAMPLE_2_SRPC_SERVER_HOST= \ + EXAMPLE_2_SRPC_SERVER_PORT= \ + EXAMPLE_2_SRPC_SERVER_ENPOINT= \ + EXAMPLE_2_SRPC_SERVER_CERT= \ + EXAMPLE_2_SRPC_SERVER_KEY= \ + python examples/python_client_example2.py """ import asyncio diff --git a/rust/lib/srpc/client/examples/python_client_example3.py b/rust/lib/srpc/client/examples/python_client_example3.py index 8d771901..0c80bf5a 100644 --- a/rust/lib/srpc/client/examples/python_client_example3.py +++ b/rust/lib/srpc/client/examples/python_client_example3.py @@ -2,9 +2,15 @@ This example demonstrates how to use the srpc_client Python bindings. To run this example: -1. Build the Rust library: maturin build --features python -2. Install the wheel: pip install target/wheels/srpc_client-*.whl -3. Run this script: python examples/python_client_example.py +1. Build and install the Rust python library: maturin develop --features python +3. Run this script: + RUST_LOG=trace \ + EXAMPLE_3_SRPC_SERVER_HOST= \ + EXAMPLE_3_SRPC_SERVER_PORT= \ + EXAMPLE_3_SRPC_SERVER_ENPOINT= \ + EXAMPLE_3_SRPC_SERVER_CERT= \ + EXAMPLE_3_SRPC_SERVER_KEY= \ + python examples/python_client_example3.py """ import asyncio diff --git a/rust/lib/srpc/client/examples/rust_client_example.rs b/rust/lib/srpc/client/examples/rust_client_example.rs index 88736e85..f0499bc2 100644 --- a/rust/lib/srpc/client/examples/rust_client_example.rs +++ b/rust/lib/srpc/client/examples/rust_client_example.rs @@ -1,3 +1,12 @@ +/** This example demonstrates how to use the srpc_client Rust bindings. + RUST_LOG=trace \ + EXAMPLE_1_SRPC_SERVER_HOST= \ + EXAMPLE_1_SRPC_SERVER_PORT= \ + EXAMPLE_1_SRPC_SERVER_ENPOINT= \ + EXAMPLE_1_SRPC_SERVER_CERT= \ + EXAMPLE_1_SRPC_SERVER_KEY= \ + cargo run --example rust_client_example +**/ use serde_json::json; use srpc_client::{ClientConfig, ReceiveOptions}; use tracing::{error, info, level_filters::LevelFilter}; diff --git a/rust/lib/srpc/client/examples/rust_client_example2.rs b/rust/lib/srpc/client/examples/rust_client_example2.rs index e4b18ca7..cfc6e788 100644 --- a/rust/lib/srpc/client/examples/rust_client_example2.rs +++ b/rust/lib/srpc/client/examples/rust_client_example2.rs @@ -1,3 +1,12 @@ +/** This example demonstrates how to use the srpc_client Rust bindings. + RUST_LOG=trace \ + EXAMPLE_2_SRPC_SERVER_HOST= \ + EXAMPLE_2_SRPC_SERVER_PORT= \ + EXAMPLE_2_SRPC_SERVER_ENPOINT= \ + EXAMPLE_2_SRPC_SERVER_CERT= \ + EXAMPLE_2_SRPC_SERVER_KEY= \ + cargo run --example rust_client_example2 +**/ use srpc_client::{ClientConfig, ReceiveOptions}; use tracing::{error, info, level_filters::LevelFilter}; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; diff --git a/rust/lib/srpc/client/examples/rust_client_example3.rs b/rust/lib/srpc/client/examples/rust_client_example3.rs index 5e6dc594..afbeba9e 100644 --- a/rust/lib/srpc/client/examples/rust_client_example3.rs +++ b/rust/lib/srpc/client/examples/rust_client_example3.rs @@ -1,3 +1,12 @@ +/** This example demonstrates how to use the srpc_client Rust bindings. + RUST_LOG=trace \ + EXAMPLE_3_SRPC_SERVER_HOST= \ + EXAMPLE_3_SRPC_SERVER_PORT= \ + EXAMPLE_3_SRPC_SERVER_ENPOINT= \ + EXAMPLE_3_SRPC_SERVER_CERT= \ + EXAMPLE_3_SRPC_SERVER_KEY= \ + cargo run --example rust_client_example3 +**/ use std::{collections::HashMap, error::Error}; use srpc_client::{ClientConfig, CustomError, ReceiveOptions, SimpleValue}; From fda0805b00e16bdb09ee6dbd12d8382f6c046d75 Mon Sep 17 00:00:00 2001 From: Leonidas Loucas Date: Mon, 16 Dec 2024 20:42:22 -0800 Subject: [PATCH 27/27] wip: try to make Conn api --- .../client/examples/python_client_example4.py | 54 ++++++++++++ .../client/examples/rust_client_example4.rs | 71 ++++++++++++++++ rust/lib/srpc/client/src/lib.rs | 72 +++++++++++++++- rust/lib/srpc/client/src/python_bindings.rs | 83 ++++++++++++++----- rust/lib/srpc/client/srpc_client.pyi | 5 ++ 5 files changed, 263 insertions(+), 22 deletions(-) create mode 100644 rust/lib/srpc/client/examples/python_client_example4.py create mode 100644 rust/lib/srpc/client/examples/rust_client_example4.rs diff --git a/rust/lib/srpc/client/examples/python_client_example4.py b/rust/lib/srpc/client/examples/python_client_example4.py new file mode 100644 index 00000000..a516c69a --- /dev/null +++ b/rust/lib/srpc/client/examples/python_client_example4.py @@ -0,0 +1,54 @@ +""" +This example demonstrates how to use the srpc_client Python bindings. + +To run this example: +1. Build and install the Rust python library: maturin develop --features python +3. Run this script: + RUST_LOG=trace \ + EXAMPLE_4_SRPC_SERVER_HOST= \ + EXAMPLE_4_SRPC_SERVER_PORT= \ + EXAMPLE_4_SRPC_SERVER_ENPOINT= \ + EXAMPLE_4_SRPC_SERVER_CERT= \ + EXAMPLE_4_SRPC_SERVER_KEY= \ + python examples/python_client_example4.py +""" + +import asyncio +import json +import os +from srpc_client import SrpcClientConfig + + +async def main(): + print("Starting client..") + + # Create a new ClientConfig instance + client = SrpcClientConfig( + os.environ["EXAMPLE_4_SRPC_SERVER_HOST"], + int(os.environ["EXAMPLE_4_SRPC_SERVER_PORT"]), + os.environ["EXAMPLE_4_SRPC_SERVER_ENPOINT"], + os.environ["EXAMPLE_4_SRPC_SERVER_CERT"], + os.environ["EXAMPLE_4_SRPC_SERVER_KEY"], + ) + + # Connect to the server + client = await client.connect() + print("Connected to server") + + # Send a message + message = "Hypervisor.GetUpdates\n" + print(f"Calling server with message: {message}") + conn = await client.call(message) + response = await conn.decode() + print(f"Received response: {json.loads(response)}") + await conn.close() + + print(f"Calling server with message again: {message}") + conn2 = await client.call(message) + response = await conn2.decode() + print(f"Received response: {json.loads(response)}") + await conn2.close() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/rust/lib/srpc/client/examples/rust_client_example4.rs b/rust/lib/srpc/client/examples/rust_client_example4.rs new file mode 100644 index 00000000..cc50a463 --- /dev/null +++ b/rust/lib/srpc/client/examples/rust_client_example4.rs @@ -0,0 +1,71 @@ +/** This example demonstrates how to use the srpc_client Rust bindings. + RUST_LOG=trace \ + EXAMPLE_4_SRPC_SERVER_HOST= \ + EXAMPLE_4_SRPC_SERVER_PORT= \ + EXAMPLE_4_SRPC_SERVER_ENPOINT= \ + EXAMPLE_4_SRPC_SERVER_CERT= \ + EXAMPLE_4_SRPC_SERVER_KEY= \ + cargo run --example rust_client_example4 +**/ +use std::{error::Error, sync::Arc}; + +use srpc_client::{ClientConfig, CustomError}; +use tokio::sync::Mutex; +use tracing::{info, level_filters::LevelFilter}; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; + +#[tokio::main] +async fn main() -> Result<(), Box> { + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::builder() + .with_default_directive(LevelFilter::INFO.into()) + .from_env_lossy(), + ) + .with(tracing_subscriber::fmt::Layer::default().compact()) + .init(); + + info!("Starting client..."); + + // Create a new ClientConfig instance + let config = ClientConfig::new( + &std::env::var("EXAMPLE_4_SRPC_SERVER_HOST")?, + std::env::var("EXAMPLE_4_SRPC_SERVER_PORT")?.parse()?, + &std::env::var("EXAMPLE_4_SRPC_SERVER_ENPOINT")?, + &std::env::var("EXAMPLE_4_SRPC_SERVER_CERT")?, + &std::env::var("EXAMPLE_4_SRPC_SERVER_KEY")?, + ); + + // Connect to the server + let client = config.connect().await?; + info!("Connected to server"); + + let message = "Hypervisor.GetUpdates\n"; + + let safe_client = Arc::new(Mutex::new(client)); + let guard = safe_client.lock_owned().await; + info!("Calling server with message: {:?}", message); + let mut conn = srpc_client::call(guard, message) + .await + .map_err(|e| Box::new(CustomError(e.to_string())) as Box)?; + let val = conn + .decode() + .await + .map_err(|e| Box::new(CustomError(e.to_string())) as Box)?; + info!("Received response: {:?}", val); + + let guard = conn.close(); + + info!("Calling server with message again: {:?}", message); + let mut conn2 = srpc_client::call(guard, message) + .await + .map_err(|e| Box::new(CustomError(e.to_string())) as Box)?; + let val = conn2 + .decode() + .await + .map_err(|e| Box::new(CustomError(e.to_string())) as Box)?; + info!("Received response: {:?}", val); + let _guard = conn2.close(); + + Ok(()) +} diff --git a/rust/lib/srpc/client/src/lib.rs b/rust/lib/srpc/client/src/lib.rs index 37f52a08..36b7f1a7 100644 --- a/rust/lib/srpc/client/src/lib.rs +++ b/rust/lib/srpc/client/src/lib.rs @@ -3,13 +3,14 @@ use chunk_limiter::ChunkLimiter; use futures::StreamExt; use openssl::ssl::{Ssl, SslConnector, SslMethod, SslVerifyMode}; use serde_json::Value; +use std::borrow::BorrowMut; use std::error::Error; use std::fmt; use std::pin::Pin; use std::sync::Arc; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader}; use tokio::net::TcpStream; -use tokio::sync::{mpsc, Mutex}; +use tokio::sync::{mpsc, Mutex, OwnedMutexGuard}; use tokio::time::{timeout, Duration}; use tokio_openssl::SslStream; use tokio_util::codec::{FramedRead, LinesCodec}; @@ -147,6 +148,65 @@ impl RequestReply for StreamValue { } } +pub struct Conn +where + T: AsyncRead + AsyncWrite + Unpin + Send + 'static, +{ + guard: Option>>, + rx: mpsc::Receiver>>, +} + +impl Conn +where + T: AsyncRead + AsyncWrite + Unpin + Send + 'static, +{ + pub async fn new( + guard: tokio::sync::OwnedMutexGuard>, + ) -> Result> { + let rx = guard + .receive_json(|_| true, &ReceiveOptions::default()) + .await + .map_err(|e| Box::new(CustomError(e.to_string())) as Box)?; + + Ok(Conn { + guard: Some(guard), + rx, + }) + } + + pub async fn encode(&self, message: Value) -> Result<(), Box> { + self.guard + .as_ref() + .unwrap() + .send_json(&message) + .await + .map_err(|e| Box::new(CustomError(e.to_string())) as Box) + } + + pub async fn decode(&mut self) -> Result> { + let rx = self.rx.borrow_mut(); + let json_value = rx.recv().await.ok_or_else(|| { + Box::new(CustomError("Expected JSON value".to_string())) as Box + })??; + Ok(json_value) + } + + pub fn close(&mut self) -> OwnedMutexGuard> { + self.guard.take().unwrap() + } +} + +pub async fn call( + client: OwnedMutexGuard>, + method: &str, +) -> Result, Box> +where + T: AsyncRead + AsyncWrite + Unpin + Send + 'static, +{ + client.send_message_and_check(method).await?; + Conn::new(client).await +} + impl ClientConfig { pub fn new(host: &str, port: u16, path: &str, cert: &str, key: &str) -> Self { ClientConfig { @@ -285,6 +345,15 @@ where R::request_reply(self, payload).await } + // pub async fn call( + // self, + // method: &str, + // ) -> Result, Box> + // { + // self.send_message_and_check(method).await?; + // Conn::new(self.stream.lock_owned().await).await + // } + pub async fn send_message(&self, message: &str) -> Result<(), Box> { let stream = self.stream.lock().await; let mut pinned = Pin::new(stream); @@ -467,5 +536,6 @@ fn srpc_client(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; + m.add_class::()?; Ok(()) } diff --git a/rust/lib/srpc/client/src/python_bindings.rs b/rust/lib/srpc/client/src/python_bindings.rs index 17a3a214..cfffadd2 100644 --- a/rust/lib/srpc/client/src/python_bindings.rs +++ b/rust/lib/srpc/client/src/python_bindings.rs @@ -1,4 +1,4 @@ -use crate::{ClientConfig, ConnectedClient, ReceiveOptions, SimpleValue}; +use crate::{ClientConfig, Conn, ConnectedClient, ReceiveOptions, SimpleValue}; use futures::{Stream, StreamExt}; use pyo3::exceptions::{PyRuntimeError, PyStopAsyncIteration}; use pyo3::prelude::*; @@ -19,6 +19,9 @@ pub struct SrpcClientConfig(ClientConfig); #[pyclass] pub struct ConnectedSrpcClient(Arc>>>); +#[pyclass] +pub struct SrpcMethodCallConn(Arc>>>); + #[pymethods] impl SrpcClientConfig { #[new] @@ -33,7 +36,9 @@ impl SrpcClientConfig { .connect() .await .map_err(|e| PyRuntimeError::new_err(e.to_string())) - .map(|c| ConnectedSrpcClient(Arc::new(Mutex::new(c)))) + .map(|c| { + ConnectedSrpcClient(Arc::new(Mutex::new(c))) + }) }) } } @@ -151,9 +156,8 @@ impl ConnectedSrpcClient { ) -> PyResult> { let client = self.0.clone(); pyo3_async_runtimes::tokio::future_into_py(py, async move { + let client = client.lock().await; client - .lock() - .await .send_message(&message) .await .map_err(|e| PyRuntimeError::new_err(e.to_string())) @@ -167,9 +171,8 @@ impl ConnectedSrpcClient { ) -> PyResult> { let client = self.0.clone(); pyo3_async_runtimes::tokio::future_into_py(py, async move { + let client = client.lock().await; client - .lock() - .await .send_message_and_check(&message) .await .map_err(|e| PyRuntimeError::new_err(e.to_string())) @@ -184,9 +187,8 @@ impl ConnectedSrpcClient { ) -> PyResult> { let client = self.0.clone(); pyo3_async_runtimes::tokio::future_into_py(py, async move { + let client = client.lock().await; let rx = client - .lock() - .await .receive_message( expect_empty, move |_| should_continue, @@ -217,9 +219,8 @@ impl ConnectedSrpcClient { .unwrap_or(false) }) }; + let client = client.lock().await; let rx = client - .lock() - .await .receive_message(expect_empty, should_continue, &ReceiveOptions::default()) .await .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; @@ -233,23 +234,25 @@ impl ConnectedSrpcClient { pyo3_async_runtimes::tokio::future_into_py(py, async move { let value: Value = serde_json::from_str(&payload) .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + let client = client.lock().await; client - .lock() - .await .send_json(&value) .await .map_err(|e| PyRuntimeError::new_err(e.to_string())) }) } - pub fn send_json_and_check<'p>(&self, py: Python<'p>, payload: String) -> PyResult> { + pub fn send_json_and_check<'p>( + &self, + py: Python<'p>, + payload: String, + ) -> PyResult> { let client = self.0.clone(); pyo3_async_runtimes::tokio::future_into_py(py, async move { let value: Value = serde_json::from_str(&payload) .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + let client = client.lock().await; client - .lock() - .await .send_json_and_check(&value) .await .map_err(|e| PyRuntimeError::new_err(e.to_string())) @@ -263,9 +266,8 @@ impl ConnectedSrpcClient { ) -> PyResult> { let client = self.0.clone(); pyo3_async_runtimes::tokio::future_into_py(py, async move { + let client = client.lock().await; let rx = client - .lock() - .await .receive_json(move |_| should_continue, &ReceiveOptions::default()) .await .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; @@ -295,9 +297,8 @@ impl ConnectedSrpcClient { .unwrap_or(false) }) }; + let client = client.lock().await; let rx = client - .lock() - .await .receive_json(should_continue, &opts.unwrap_or_default()) .await .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; @@ -319,13 +320,53 @@ impl ConnectedSrpcClient { pyo3_async_runtimes::tokio::future_into_py(py, async move { let value: Value = serde_json::from_str(&payload) .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + let client = client.lock().await; let response = client - .lock() - .await .request_reply::(&method, value) .await .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + // TODO: Figure out how to marshall this as a python dict Ok(response.to_string()) }) } + + pub fn call<'p>(&self, py: Python<'p>, method: String) -> PyResult> { + let client = self.0.clone(); + + pyo3_async_runtimes::tokio::future_into_py(py, async move { + let guard = client.lock_owned().await; + let conn = crate::call(guard, &method) + .await + .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + Ok(SrpcMethodCallConn(Arc::new( + Mutex::new(conn), + ))) + }) + } +} + +#[pymethods] +impl SrpcMethodCallConn { + pub fn decode<'p>(&self, py: Python<'p>) -> PyResult> { + let client = self.0.clone(); + + pyo3_async_runtimes::tokio::future_into_py(py, async move { + let mut guard = client.lock_owned().await; + let response = guard + .decode() + .await + .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + Ok(response.to_string()) + }) + } + + pub fn close<'p>(&mut self, py: Python<'p>) -> PyResult> { + let client = self.0.clone(); + + pyo3_async_runtimes::tokio::future_into_py(py, async move { + let mut guard = client.lock_owned().await; + guard.close(); + Ok(()) + }) + } } diff --git a/rust/lib/srpc/client/srpc_client.pyi b/rust/lib/srpc/client/srpc_client.pyi index 33b714dd..abed9fe1 100644 --- a/rust/lib/srpc/client/srpc_client.pyi +++ b/rust/lib/srpc/client/srpc_client.pyi @@ -16,3 +16,8 @@ class ConnectedSrpcClient: async def receive_json(self, should_continue: bool) -> List[str]: ... async def receive_json_cb(self, should_continue: Callable[[str], bool]) -> List[str]: ... async def request_reply(self, message: str, payload: JsonStr) -> JsonStr: ... + async def call(self, message: str) -> "SrpcMethodCallConn": ... + +class SrpcMethodCallConn: + async def decode(self) -> JsonStr: ... + async def close(self) -> None: ...