From 398f66756eaf6d1ac8937bc07fa02fd50315923c Mon Sep 17 00:00:00 2001 From: Joel Wurtz Date: Sun, 2 Feb 2025 20:29:14 +0100 Subject: [PATCH] feat(client): make body reusable in some cases --- client/src/body.rs | 31 +++++++++++++++++++++++++++++ client/src/middleware/redirect.rs | 4 ++-- client/src/request.rs | 18 +++++++---------- client/src/service.rs | 33 +++++++++++++++++++------------ 4 files changed, 60 insertions(+), 26 deletions(-) diff --git a/client/src/body.rs b/client/src/body.rs index c629cafee..341e63cd6 100644 --- a/client/src/body.rs +++ b/client/src/body.rs @@ -137,3 +137,34 @@ where self.body.size_hint() } } + +impl Default for RequestBody { + fn default() -> Self { + Self::None + } +} + +pub enum RequestBody { + Reusable(Bytes), + Stream(Option), + None, +} + +impl RequestBody { + pub fn as_stream(&mut self) -> BoxBody { + match self { + Self::Reusable(bytes) => BoxBody::new(Once::new(Bytes::clone(bytes))), + Self::None => BoxBody::new(NoneBody::default()), + Self::Stream(stream) => { + let stream = stream.take(); + + stream.unwrap_or_else(|| BoxBody::new(NoneBody::default())) + } + } + } + + pub fn into_reusable(self) -> Self { + // @TODO ? + unimplemented!() + } +} diff --git a/client/src/middleware/redirect.rs b/client/src/middleware/redirect.rs index f0dd7b921..487eb5ac9 100644 --- a/client/src/middleware/redirect.rs +++ b/client/src/middleware/redirect.rs @@ -1,5 +1,5 @@ use crate::{ - body::BoxBody, + body::RequestBody, error::{Error, InvalidUri}, http::{ header::{CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_TYPE, LOCATION, TRANSFER_ENCODING}, @@ -41,7 +41,7 @@ where method = Method::GET; } - *req.body_mut() = BoxBody::default(); + *req.body_mut() = RequestBody::default(); for header in &[TRANSFER_ENCODING, CONTENT_ENCODING, CONTENT_TYPE, CONTENT_LENGTH] { headers.remove(header); diff --git a/client/src/request.rs b/client/src/request.rs index c45df7d6c..b434919f4 100644 --- a/client/src/request.rs +++ b/client/src/request.rs @@ -3,7 +3,7 @@ use core::{marker::PhantomData, time::Duration}; use futures_core::Stream; use crate::{ - body::{BodyError, BoxBody, Once}, + body::{BodyError, BoxBody, RequestBody}, bytes::Bytes, client::Client, error::Error, @@ -18,7 +18,7 @@ use crate::{ /// builder type for [http::Request] with extended functionalities. pub struct RequestBuilder<'a, M = marker::Http> { - pub(crate) req: http::Request, + pub(crate) req: http::Request, err: Vec, client: &'a Client, timeout: Duration, @@ -74,7 +74,7 @@ impl RequestBuilder<'_, marker::Http> { let bytes = Bytes::from(body); let val = HeaderValue::from(bytes.len()); self.headers_mut().insert(CONTENT_LENGTH, val); - self.map_body(Once::new(bytes)) + self.map_body(RequestBody::Reusable(bytes)) } /// Use streaming type as request body. @@ -84,7 +84,7 @@ impl RequestBuilder<'_, marker::Http> { B: Stream> + Send + 'static, E: Into, { - self.map_body(body) + self.map_body(RequestBody::Stream(Some(BoxBody::new(body)))) } /// Finish request builder and send it to server. @@ -100,7 +100,7 @@ impl<'a, M> RequestBuilder<'a, M> { E: Into, { Self { - req: req.map(BoxBody::new), + req: req.map(|_| RequestBody::default()), err: Vec::new(), client, timeout: client.timeout_config.request_timeout, @@ -210,12 +210,8 @@ impl<'a, M> RequestBuilder<'a, M> { self } - fn map_body(mut self, b: B) -> RequestBuilder<'a, M> - where - B: Stream> + Send + 'static, - E: Into, - { - self.req = self.req.map(|_| BoxBody::new(b)); + fn map_body(mut self, b: RequestBody) -> RequestBuilder<'a, M> { + self.req = self.req.map(|_| b); self } } diff --git a/client/src/service.rs b/client/src/service.rs index cd56adcc4..9546a00b8 100644 --- a/client/src/service.rs +++ b/client/src/service.rs @@ -1,7 +1,7 @@ use core::{future::Future, pin::Pin, time::Duration}; use crate::{ - body::BoxBody, + body::RequestBody, client::Client, connect::Connect, error::Error, @@ -65,7 +65,7 @@ where /// /// [RequestBuilder]: crate::request::RequestBuilder pub struct ServiceRequest<'r, 'c> { - pub req: &'r mut Request, + pub req: &'r mut Request, pub client: &'c Client, pub timeout: Duration, } @@ -86,13 +86,20 @@ pub(crate) fn base_service() -> HttpService { use crate::{error::TimeoutError, timeout::Timeout}; let ServiceRequest { req, client, timeout } = req; + let cloned_body = match req.body() { + RequestBody::Reusable(body) => RequestBody::Reusable(body.clone()), + _ => RequestBody::default(), + }; - let uri = Uri::try_parse(req.uri())?; + let mut send_req = core::mem::take(req).map(|mut b| b.as_stream()); + *req.body_mut() = cloned_body; + + let uri = Uri::try_parse(send_req.uri())?; // temporary version to record possible version downgrade/upgrade happens when making connections. // alpn protocol and alt-svc header are possible source of version change. #[allow(unused_mut)] - let mut version = req.version(); + let mut version = send_req.version(); let mut connect = Connect::new(uri); @@ -103,12 +110,12 @@ pub(crate) fn base_service() -> HttpService { Version::HTTP_2 | Version::HTTP_3 => match client.shared_pool.acquire(&connect.uri).await { shared::AcquireOutput::Conn(mut _conn) => { let mut _timer = Box::pin(tokio::time::sleep(timeout)); - *req.version_mut() = version; + *send_req.version_mut() = version; #[allow(unreachable_code)] return match _conn.conn { #[cfg(feature = "http2")] crate::connection::ConnectionShared::H2(ref mut conn) => { - match crate::h2::proto::send(conn, _date, core::mem::take(req)) + match crate::h2::proto::send(conn, _date, send_req) .timeout(_timer.as_mut()) .await { @@ -128,7 +135,7 @@ pub(crate) fn base_service() -> HttpService { } #[cfg(feature = "http3")] crate::connection::ConnectionShared::H3(ref mut conn) => { - let res = crate::h3::proto::send(conn, _date, core::mem::take(req)) + let res = crate::h3::proto::send(conn, _date, send_req) .timeout(_timer.as_mut()) .await .map_err(|_| TimeoutError::Request)??; @@ -214,12 +221,12 @@ pub(crate) fn base_service() -> HttpService { }, version => match client.exclusive_pool.acquire(&connect.uri).await { exclusive::AcquireOutput::Conn(mut _conn) => { - *req.version_mut() = version; + *send_req.version_mut() = version; #[cfg(feature = "http1")] { let mut timer = Box::pin(tokio::time::sleep(timeout)); - let res = crate::h1::proto::send(&mut *_conn, _date, req) + let res = crate::h1::proto::send(&mut *_conn, _date, &mut send_req) .timeout(timer.as_mut()) .await; @@ -273,7 +280,7 @@ mod test { use std::sync::Arc; use crate::{ - body::{BoxBody, ResponseBody}, + body::{RequestBody, ResponseBody}, client::Client, error::Error, http::{self, Request}, @@ -293,14 +300,14 @@ mod test { pub(crate) struct HttpServiceMockHandle(Client); - type HandlerFn = Arc) -> Result, Error> + Send + Sync>; + type HandlerFn = Arc) -> Result, Error> + Send + Sync>; impl HttpServiceMockHandle { /// compose a service request with given http request and it's mocked server side handler function pub(crate) fn mock<'r, 'c>( &'c self, - req: &'r mut Request, - handler: impl Fn(Request) -> Result, Error> + Send + Sync + 'static, + req: &'r mut Request, + handler: impl Fn(Request) -> Result, Error> + Send + Sync + 'static, ) -> ServiceRequest<'r, 'c> { req.extensions_mut().insert(Arc::new(handler) as HandlerFn); ServiceRequest {