diff --git a/http-encoding/src/encode.rs b/http-encoding/src/encode.rs index 2a8fd97b0..efa18a7c6 100644 --- a/http-encoding/src/encode.rs +++ b/http-encoding/src/encode.rs @@ -1,7 +1,7 @@ //! Stream encoders. use futures_core::Stream; -use http::{header, Response, StatusCode, Version}; +use http::{header, Response, StatusCode}; use super::{ coder::{Coder, FeaturedCode}, @@ -56,13 +56,13 @@ where } #[cfg(any(feature = "br", feature = "gz", feature = "de"))] -fn update_header(headers: &mut header::HeaderMap, value: &'static str, version: Version) { +fn update_header(headers: &mut header::HeaderMap, value: &'static str, version: http::Version) { headers.insert(header::CONTENT_ENCODING, header::HeaderValue::from_static(value)); headers.remove(header::CONTENT_LENGTH); // Connection specific headers are not allowed in HTTP/2 and later versions. // see https://datatracker.ietf.org/doc/html/rfc7540#section-8.1.2.2 - if version < Version::HTTP_2 { + if version < http::Version::HTTP_2 { headers.insert(header::TRANSFER_ENCODING, header::HeaderValue::from_static("chunked")); } } diff --git a/http/Cargo.toml b/http/Cargo.toml index 959cbd311..d4d824cc8 100644 --- a/http/Cargo.toml +++ b/http/Cargo.toml @@ -37,14 +37,14 @@ io-uring = ["xitca-io/runtime-uring", "tokio-uring"] router = ["xitca-router"] [dependencies] -xitca-io = "0.4.0" +xitca-io = "0.4.2" xitca-service = { version = "0.3.0", features = ["alloc"] } xitca-unsafe-collection = { version = "0.2.0", features = ["bytes"] } futures-core = "0.3.17" http = "1" httpdate = "1.0" -pin-project-lite = "0.2.10" +pin-project-lite = "0.2.16" tracing = { version = "0.1.40", default-features = false } # native tls support @@ -80,7 +80,7 @@ tokio-uring = { version = "0.5.0", features = ["bytes"], optional = true } socket2 = { version = "0.6.0", features = ["all"] } [dev-dependencies] -criterion = "0.5" +criterion = "0.8.0" xitca-server = "0.5" [[bench]] diff --git a/http/benches/h1_decode.rs b/http/benches/h1_decode.rs index f780f85a1..f4c1cc0ce 100644 --- a/http/benches/h1_decode.rs +++ b/http/benches/h1_decode.rs @@ -1,6 +1,6 @@ -use std::time::SystemTime; +use std::{hint::black_box, time::SystemTime}; -use criterion::{Criterion, black_box, criterion_group, criterion_main}; +use criterion::{Criterion, criterion_group, criterion_main}; use httpdate::HttpDate; use tokio::time::Instant; use xitca_http::{ diff --git a/http/src/h1/body.rs b/http/src/h1/body.rs index f0e5585c5..59372015a 100644 --- a/http/src/h1/body.rs +++ b/http/src/h1/body.rs @@ -1,59 +1,38 @@ use core::{ - cell::{RefCell, RefMut}, - future::poll_fn, - ops::DerefMut, + fmt, pin::Pin, - task::{Context, Poll, Waker}, + task::{Context, Poll}, }; -use std::{collections::VecDeque, io, rc::Rc}; +use std::io; use futures_core::stream::Stream; use crate::bytes::Bytes; -/// max buffer size 32k -pub(crate) const MAX_BUFFER_SIZE: usize = 32_768; - -#[derive(Clone, Debug)] -enum RequestBodyInner { - Some(Rc>), - #[cfg(feature = "io-uring")] - Completion(super::dispatcher_uring::Body), - None, -} - -impl RequestBodyInner { - fn new(eof: bool) -> Self { - match eof { - true => Self::None, - false => Self::Some(Default::default()), - } - } -} - /// Buffered stream of request body chunk. /// /// impl [Stream] trait to produce chunk as [Bytes] type in async manner. -#[derive(Debug)] -pub struct RequestBody(RequestBodyInner); +pub struct RequestBody(Option>>>>); -impl Default for RequestBody { - fn default() -> Self { - Self(RequestBodyInner::new(true)) +impl fmt::Debug for RequestBody { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("RequestBody") } } impl RequestBody { - // an async spsc channel where RequestBodySender used to push data and popped from RequestBody. - pub(super) fn channel(eof: bool) -> (RequestBodySender, Self) { - let inner = RequestBodyInner::new(eof); - (RequestBodySender(inner.clone()), RequestBody(inner)) + pub(super) fn new(body: S) -> Self + where + S: Stream> + 'static, + { + Self(Some(Box::pin(body))) } +} - #[cfg(feature = "io-uring")] - pub(super) fn io_uring(body: super::dispatcher_uring::Body) -> Self { - RequestBody(RequestBodyInner::Completion(body)) +impl Default for RequestBody { + fn default() -> Self { + Self(None) } } @@ -62,10 +41,8 @@ impl Stream for RequestBody { fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll>> { match self.get_mut().0 { - RequestBodyInner::Some(ref mut inner) => inner.borrow_mut().poll_next_unpin(cx), - RequestBodyInner::None => Poll::Ready(None), - #[cfg(feature = "io-uring")] - RequestBodyInner::Completion(ref mut body) => Pin::new(body).poll_next(cx), + None => Poll::Ready(None), + Some(ref mut body) => Pin::new(body).poll_next(cx), } } } @@ -75,182 +52,3 @@ impl From for crate::body::RequestBody { Self::H1(body) } } - -/// Sender part of the payload stream -pub struct RequestBodySender(RequestBodyInner); - -// TODO: rework early eof error handling. -impl Drop for RequestBodySender { - fn drop(&mut self) { - if let Some(mut inner) = self.try_inner() { - if !inner.eof { - inner.feed_error(io::ErrorKind::UnexpectedEof.into()); - } - } - } -} - -impl RequestBodySender { - // try to get a mutable reference of inner and ignore RequestBody::None variant. - fn try_inner(&mut self) -> Option> { - self.try_inner_on_none_with(|| {}) - } - - // try to get a mutable reference of inner and panic on RequestBody::None variant. - // this is a runtime check for internal optimization to avoid unnecessary operations. - // public api must not be able to trigger this panic. - fn try_inner_infallible(&mut self) -> Option> { - self.try_inner_on_none_with(|| panic!("No Request Body found. Do not waste operation on Sender.")) - } - - fn try_inner_on_none_with(&mut self, func: F) -> Option> - where - F: FnOnce(), - { - match self.0 { - RequestBodyInner::Some(ref inner) => { - // request body is a shared pointer between only two owners and no weak reference. - debug_assert!(Rc::strong_count(inner) <= 2); - debug_assert_eq!(Rc::weak_count(inner), 0); - (Rc::strong_count(inner) != 1).then_some(inner.borrow_mut()) - } - _ => { - func(); - None - } - } - } - - pub(super) fn feed_error(&mut self, e: io::Error) { - if let Some(mut inner) = self.try_inner_infallible() { - inner.feed_error(e); - } - } - - pub(super) fn feed_eof(&mut self) { - if let Some(mut inner) = self.try_inner_infallible() { - inner.feed_eof(); - } - } - - pub(super) fn feed_data(&mut self, data: Bytes) { - if let Some(mut inner) = self.try_inner_infallible() { - inner.feed_data(data); - } - } - - pub(super) fn ready(&mut self) -> impl Future> + '_ { - self.ready_with(|inner| !inner.backpressure()) - } - - // Lazily wait until RequestBody is already polled. - // For specific use case body must not be eagerly polled. - // For example: Request with Expect: Continue header. - pub(super) fn wait_for_poll(&mut self) -> impl Future> + '_ { - self.ready_with(|inner| inner.waiting()) - } - - async fn ready_with(&mut self, func: F) -> io::Result<()> - where - F: Fn(&mut Inner) -> bool, - { - poll_fn(|cx| { - // Check only if Payload (other side) is alive, Otherwise always return io error. - match self.try_inner_infallible() { - Some(mut inner) => { - if func(inner.deref_mut()) { - Poll::Ready(Ok(())) - } else { - // when payload is not ready register current task waker and wait. - inner.register_io(cx); - Poll::Pending - } - } - None => Poll::Ready(Err(io::ErrorKind::UnexpectedEof.into())), - } - }) - .await - } -} - -#[derive(Debug, Default)] -struct Inner { - eof: bool, - len: usize, - err: Option, - items: VecDeque, - task: Option, - io_task: Option, -} - -impl Inner { - /// Wake up future waiting for payload data to be available. - fn wake(&mut self) { - if let Some(waker) = self.task.take() { - waker.wake(); - } - } - - /// Wake up future feeding data to Payload. - fn wake_io(&mut self) { - if let Some(waker) = self.io_task.take() { - waker.wake(); - } - } - - /// true when a future is waiting for payload data. - fn waiting(&self) -> bool { - self.task.is_some() - } - - /// Register future waiting data from payload. - /// Waker would be used in `Inner::wake` - fn register(&mut self, cx: &Context<'_>) { - if self.task.as_ref().map(|w| !cx.waker().will_wake(w)).unwrap_or(true) { - self.task = Some(cx.waker().clone()); - } - } - - // Register future feeding data to payload. - /// Waker would be used in `Inner::wake_io` - fn register_io(&mut self, cx: &Context<'_>) { - if self.io_task.as_ref().map(|w| !cx.waker().will_wake(w)).unwrap_or(true) { - self.io_task = Some(cx.waker().clone()); - } - } - - fn feed_error(&mut self, err: io::Error) { - self.err = Some(err); - self.wake(); - } - - fn feed_eof(&mut self) { - self.eof = true; - self.wake(); - } - - fn feed_data(&mut self, data: Bytes) { - self.len += data.len(); - self.items.push_back(data); - self.wake(); - } - - fn backpressure(&self) -> bool { - self.len >= MAX_BUFFER_SIZE - } - - fn poll_next_unpin(&mut self, cx: &Context<'_>) -> Poll>> { - if let Some(data) = self.items.pop_front() { - self.len -= data.len(); - Poll::Ready(Some(Ok(data))) - } else if let Some(err) = self.err.take() { - Poll::Ready(Some(Err(err))) - } else if self.eof { - Poll::Ready(None) - } else { - self.register(cx); - self.wake_io(); - Poll::Pending - } - } -} diff --git a/http/src/h1/dispatcher.rs b/http/src/h1/dispatcher.rs index 76793f2d5..3e5fcfb8c 100644 --- a/http/src/h1/dispatcher.rs +++ b/http/src/h1/dispatcher.rs @@ -1,35 +1,32 @@ use core::{ - convert::Infallible, - future::{pending, poll_fn}, + cell::RefCell, + future::poll_fn, marker::PhantomData, + mem, net::SocketAddr, pin::{Pin, pin}, + task::{self, Poll, Waker, ready}, time::Duration, }; -use std::io; +use std::{io, rc::Rc}; use futures_core::stream::Stream; +use pin_project_lite::pin_project; use tracing::trace; -use xitca_io::io::{AsyncIo, Interest, Ready}; +use xitca_io::io::{AsyncIo, Interest}; use xitca_service::Service; -use xitca_unsafe_collection::futures::{Select as _, SelectOutput}; +use xitca_unsafe_collection::futures::SelectOutput; use crate::{ body::NoneBody, - bytes::{Bytes, EitherBuf}, + bytes::{Bytes, BytesMut, EitherBuf}, config::HttpServiceConfig, date::DateTime, - h1::{ - body::{RequestBody, RequestBodySender}, - error::Error, - }, - http::{ - StatusCode, - response::{Parts, Response}, - }, + h1::{body::RequestBody, error::Error, proto::encode::CONTINUE_BYTES}, + http::{StatusCode, response::Response}, util::{ - buffered::{BufferedIo, ListWriteBuf, ReadBuf, WriteBuf}, + buffered::{ListWriteBuf, WriteBuf}, timer::{KeepAlive, Timeout}, }, }; @@ -38,7 +35,6 @@ use super::proto::{ buf_write::H1BufWrite, codec::{ChunkResult, TransferCoding}, context::Context, - encode::CONTINUE, error::ProtoError, }; @@ -57,7 +53,7 @@ pub(crate) async fn run< const READ_BUF_LIMIT: usize, const WRITE_BUF_LIMIT: usize, >( - io: &'a mut St, + io: St, addr: SocketAddr, timer: Pin<&'a mut KeepAlive>, config: HttpServiceConfig, @@ -68,10 +64,11 @@ where S: Service, Response = Response>, ReqB: From, ResB: Stream>, - St: AsyncIo, + St: AsyncIo + 'static, + for<'i> &'i St: AsyncIo, D: DateTime, { - let write_buf = if config.vectored_write && io.is_vectored_write() { + let write_buf = if config.vectored_write && (&io).is_vectored_write() { EitherBuf::Left(ListWriteBuf::<_, WRITE_BUF_LIMIT>::default()) } else { EitherBuf::Right(WriteBuf::::default()) @@ -83,102 +80,44 @@ where } /// Http/1 dispatcher -struct Dispatcher<'a, St, S, ReqB, W, D, const HEADER_LIMIT: usize, const READ_BUF_LIMIT: usize> { - io: BufferedIo<'a, St, W, READ_BUF_LIMIT>, +struct Dispatcher<'a, Io, S, ReqB, W, D, const HEADER_LIMIT: usize, const READ_BUF_LIMIT: usize> { + io: Rc, + read_buf: BytesMut, + write_buf: W, timer: Timer<'a>, ctx: Context<'a, D, HEADER_LIMIT>, + notify: Notify, service: &'a S, _phantom: PhantomData, } -// timer state is transformed in following order: -// -// Idle (expecting keep-alive duration) <-- -// | | -// --> Wait (expecting request head duration) | -// | | -// --> Throttle (expecting manually set to Idle again) -enum TimerState { - Idle, - Wait, - Throttle, -} - -pub(super) struct Timer<'a> { - timer: Pin<&'a mut KeepAlive>, - state: TimerState, - ka_dur: Duration, - req_dur: Duration, -} - -impl<'a> Timer<'a> { - pub(super) fn new(timer: Pin<&'a mut KeepAlive>, ka_dur: Duration, req_dur: Duration) -> Self { - Self { - timer, - state: TimerState::Idle, - ka_dur, - req_dur, - } - } - - pub(super) fn reset_state(&mut self) { - self.state = TimerState::Idle; - } - - pub(super) fn get(&mut self) -> Pin<&mut KeepAlive> { - self.timer.as_mut() - } - - // update timer with a given base instant value. the final deadline is calculated base on it. - pub(super) fn update(&mut self, now: tokio::time::Instant) { - let dur = match self.state { - TimerState::Idle => { - self.state = TimerState::Wait; - self.ka_dur - } - TimerState::Wait => { - self.state = TimerState::Throttle; - self.req_dur - } - TimerState::Throttle => return, - }; - self.timer.as_mut().update(now + dur) - } - - #[cold] - #[inline(never)] - pub(super) fn map_to_err(&self) -> Error { - match self.state { - TimerState::Wait => Error::KeepAliveExpire, - TimerState::Throttle => Error::RequestTimeout, - TimerState::Idle => unreachable!(), - } - } -} - -impl<'a, St, S, ReqB, ResB, BE, W, D, const HEADER_LIMIT: usize, const READ_BUF_LIMIT: usize> - Dispatcher<'a, St, S, ReqB, W, D, HEADER_LIMIT, READ_BUF_LIMIT> +impl<'a, St, S, ReqB, ResB, BE, W, D, const H_LIMIT: usize, const R_LIMIT: usize> + Dispatcher<'a, St, S, ReqB, W, D, H_LIMIT, R_LIMIT> where S: Service, Response = Response>, ReqB: From, ResB: Stream>, - St: AsyncIo, + St: AsyncIo + 'static, + for<'i> &'i St: AsyncIo, W: H1BufWrite, D: DateTime, { - fn new( - io: &'a mut St, + fn new( + io: St, addr: SocketAddr, timer: Pin<&'a mut KeepAlive>, - config: HttpServiceConfig, + config: HttpServiceConfig, service: &'a S, date: &'a D, write_buf: W, ) -> Self { Self { - io: BufferedIo::new(io, write_buf), + io: Rc::new(io), + read_buf: BytesMut::new(), + write_buf, timer: Timer::new(timer, config.keep_alive_timeout, config.request_head_timeout), ctx: Context::with_addr(addr, date), + notify: Notify::new(), service, _phantom: PhantomData, } @@ -187,132 +126,147 @@ where async fn run(mut self) -> Result<(), Error> { loop { if let Err(err) = self._run().await { - handle_error(&mut self.ctx, &mut self.io.write_buf, err)?; + handle_error(&mut self.ctx, &mut self.write_buf, err)?; } // TODO: add timeout for drain write? - self.io.drain_write().await?; + write(&*self.io, &mut self.write_buf).await?; if self.ctx.is_connection_closed() { - return self.io.shutdown().await.map_err(Into::into); + return self.shutdown().await.map_err(Into::into); } } } async fn _run(&mut self) -> Result<(), Error> { self.timer.update(self.ctx.date().now()); - self.io - .read() + + let read = read(&*self.io, &mut self.read_buf) .timeout(self.timer.get()) .await .map_err(|_| self.timer.map_to_err())??; - while let Some((req, decoder)) = self.ctx.decode_head::(&mut self.io.read_buf)? { + if read == 0 { + self.ctx.set_close(); + return Ok(()); + } + + while let Some((req, decoder)) = self.ctx.decode_head::(&mut self.read_buf)? { self.timer.reset_state(); - let (mut body_reader, body) = BodyReader::from_coding(decoder); + let (waiter, body) = if decoder.is_eof() { + (None, RequestBody::default()) + } else { + let body = body( + self.io.clone(), + self.ctx.is_expect_header(), + R_LIMIT, + decoder, + mem::take(&mut self.read_buf), + self.notify.notifier(), + ); + + (Some(&mut self.notify), body) + }; + let req = req.map(|ext| ext.map_body(|_| ReqB::from(body))); - let (parts, body) = match self - .service - .call(req) - .select(self.request_body_handler(&mut body_reader)) - .await - { - SelectOutput::A(Ok(res)) => res.into_parts(), - SelectOutput::A(Err(e)) => return Err(Error::Service(e)), - SelectOutput::B(Err(e)) => return Err(e), - SelectOutput::B(Ok(i)) => match i {}, - }; + let (parts, body) = self.service.call(req).await.map_err(Error::Service)?.into_parts(); - let encoder = &mut self.encode_head(parts, &body)?; - let mut body = pin!(body); + let mut encoder = self.ctx.encode_head(parts, &body, &mut self.write_buf)?; - loop { - match self - .try_poll_body(body.as_mut()) - .select(self.io_ready(&mut body_reader)) - .await - { - SelectOutput::A(Some(Ok(bytes))) => encoder.encode(bytes, &mut self.io.write_buf), - SelectOutput::B(Ok(ready)) => { - if ready.is_readable() { - if let Err(e) = self.io.try_read() { - body_reader.feed_error(e); + // this block is necessary. ResB has to be dropped asap as it may hold ownership of + // Body type which if not dropped before Notifier::notify is called would prevent + // Notifier from waking up Notify. + { + let mut body = pin!(body); + + loop { + let buf = &mut self.write_buf; + + let res = poll_fn(|cx| match body.as_mut().poll_next(cx) { + Poll::Ready(res) => Poll::Ready(SelectOutput::A(res)), + Poll::Pending if buf.want_write_io() => Poll::Pending, + Poll::Pending => Poll::Ready(SelectOutput::B(())), + }) + .await; + + match res { + SelectOutput::A(Some(Ok(bytes))) => { + encoder.encode(bytes, buf); + if buf.want_write_buf() { + continue; } } - if ready.is_writable() { - self.io.try_write()?; - } - } - SelectOutput::A(None) => { - encoder.encode_eof(&mut self.io.write_buf); - break; + SelectOutput::A(Some(Err(e))) => return self.on_body_error(e).await, + SelectOutput::A(None) => break encoder.encode_eof(buf), + SelectOutput::B(_) => {} } - SelectOutput::B(Err(e)) => return Err(e.into()), - SelectOutput::A(Some(Err(e))) => return Err(Error::Body(e)), + + write(&*self.io, buf).await?; } } - if !body_reader.decoder.is_eof() { - self.ctx.set_close(); - break; + if let Some(waiter) = waiter { + match waiter.wait().await { + Some(read_buf) => self.read_buf = read_buf, + None => { + self.ctx.set_close(); + break; + } + } } } Ok(()) } - fn encode_head(&mut self, parts: Parts, body: &impl Stream) -> Result { - self.ctx.encode_head(parts, body, &mut self.io.write_buf) + #[cold] + #[inline(never)] + async fn shutdown(self) -> Result<(), Error> { + let mut io = Rc::try_unwrap(self.io) + .ok() + .expect("Dispatcher must have exclusive ownership to Io when closing connection"); + + poll_fn(|cx| Pin::new(&mut io).poll_shutdown(cx)) + .await + .map_err(Into::into) } - // an associated future of self.service that runs until service is resolved or error produced. - async fn request_body_handler(&mut self, body_reader: &mut BodyReader) -> Result> { - if self.ctx.is_expect_header() { - // wait for service future to start polling RequestBody. - if body_reader.wait_for_poll().await.is_ok() { - // encode continue as service future want a body. - self.io.write_buf.write_buf_static(CONTINUE); - // use drain write to make sure continue is sent to client. - self.io.drain_write().await?; - } - } + #[cold] + #[inline(never)] + async fn on_body_error(&mut self, e: BE) -> Result<(), Error> { + write(&*self.io, &mut self.write_buf).await?; + Err(Error::Body(e)) + } +} - loop { - body_reader.ready(&mut self.io.read_buf).await; - self.io.read().await?; +async fn read(mut io: impl AsyncIo, buf: &mut BytesMut) -> io::Result { + loop { + if buf.len() == buf.capacity() { + buf.reserve(4096); } - } - fn try_poll_body<'b>(&self, mut body: Pin<&'b mut ResB>) -> impl Future>> + 'b { - let want_buf = self.io.write_buf.want_write_buf(); - async move { - if want_buf { - poll_fn(|cx| body.as_mut().poll_next(cx)).await - } else { - pending().await - } + io.ready(Interest::READABLE).await?; + + match xitca_unsafe_collection::bytes::read_buf(&mut io, buf) { + Ok(n) => return Ok(n), + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {} + Err(e) => return Err(e), } } +} - // Check readable and writable state of BufferedIo and ready state of request body reader. - // return error when runtime is shutdown.(See AsyncIo::ready for reason). - async fn io_ready(&mut self, body_reader: &mut BodyReader) -> io::Result { - if !self.io.write_buf.want_write_io() { - body_reader.ready(&mut self.io.read_buf).await; - self.io.io.ready(Interest::READABLE).await - } else { - match body_reader - .ready(&mut self.io.read_buf) - .select(self.io.io.ready(Interest::WRITABLE)) - .await - { - SelectOutput::A(_) => self.io.io.ready(Interest::READABLE | Interest::WRITABLE).await, - SelectOutput::B(res) => res, - } - } +async fn write(mut io: impl AsyncIo, buf: &mut W) -> io::Result<()> +where + W: H1BufWrite, +{ + while buf.want_write_io() { + io.ready(Interest::WRITABLE).await?; + buf.do_io(&mut io)?; } + + Ok(()) } #[cold] @@ -350,50 +304,258 @@ where Ok(()) } -pub(super) struct BodyReader { - pub(super) decoder: TransferCoding, - tx: RequestBodySender, +fn body( + io: Rc, + is_expect: bool, + limit: usize, + decoder: TransferCoding, + read_buf: BytesMut, + notify: Notifier, +) -> RequestBody +where + Io: 'static, + for<'i> &'i Io: AsyncIo, +{ + let body = BodyInner { + io, + decoder: Decoder { + decoder, + limit, + read_buf, + notify, + }, + }; + + let state = if is_expect { + State::ExpectWrite { + fut: async { + write(&*body.io, &mut BytesMut::from(CONTINUE_BYTES)) + .await + .map(|_| body) + }, + } + } else { + State::Body { body } + }; + + RequestBody::new(BodyReader { state }) } -impl BodyReader { - pub(super) fn from_coding(decoder: TransferCoding) -> (Self, RequestBody) { - let (tx, body) = RequestBody::channel(decoder.is_eof()); - let body_reader = BodyReader { decoder, tx }; - (body_reader, body) +pin_project! { + #[project = StateProj] + #[project_replace = StateProjReplace] + enum State { + Body { + body: BodyInner + }, + ExpectWrite { + #[pin] + fut: FutE, + }, + None, } +} + +pin_project! { + struct BodyReader { + #[pin] + state: State + } +} - // dispatcher MUST call this method before do any io reading. - // a none ready state means the body consumer either is in backpressure or don't expect body. - pub(super) async fn ready(&mut self, read_buf: &mut ReadBuf) { +struct BodyInner { + io: Rc, + decoder: Decoder, +} + +impl Stream for BodyReader +where + for<'i> &'i Io: AsyncIo, + FutE: Future>>, +{ + type Item = io::Result; + + #[inline] + fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { + let mut this = self.project(); loop { - match self.decoder.decode(&mut *read_buf) { - ChunkResult::Ok(bytes) => self.tx.feed_data(bytes), - ChunkResult::InsufficientData => match self.tx.ready().await { - Ok(_) => return, - // service future drop RequestBody so marker decoder to corrupted. - Err(_) => self.decoder.set_corrupted(), - }, - ChunkResult::OnEof => self.tx.feed_eof(), - ChunkResult::AlreadyEof | ChunkResult::Corrupted => pending().await, - ChunkResult::Err(e) => self.feed_error(e), + match this.state.as_mut().project() { + StateProj::Body { body } => { + let mut io = &*body.io; + let decoder = &mut body.decoder; + + match decoder.decode() { + ChunkResult::Ok(bytes) => return Poll::Ready(Some(Ok(bytes))), + ChunkResult::Err(e) => return Poll::Ready(Some(Err(e))), + ChunkResult::InsufficientData => decoder.limit_check()?, + _ => return Poll::Ready(None), + } + + ready!(io.poll_ready(Interest::READABLE, cx))?; + + match xitca_unsafe_collection::bytes::read_buf(&mut io, &mut decoder.read_buf) { + Ok(n) => { + if n == 0 { + this.state.as_mut().project_replace(State::None); + return Poll::Ready(None); + } + } + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {} + Err(e) => return Poll::Ready(Some(Err(e))), + } + } + StateProj::ExpectWrite { fut } => { + let body = ready!(fut.poll(cx))?; + this.state.as_mut().project_replace(State::Body { body }); + } + StateProj::None => return Poll::Ready(None), } } } +} + +pub(super) struct Timer<'a> { + timer: Pin<&'a mut KeepAlive>, + state: TimerState, + ka_dur: Duration, + req_dur: Duration, +} + +// timer state is transformed in following order: +// +// Idle (expecting keep-alive duration) <-- +// | | +// --> Wait (expecting request head duration) | +// | | +// --> Throttle (expecting manually set to Idle again) +enum TimerState { + Idle, + Wait, + Throttle, +} + +impl<'a> Timer<'a> { + pub(super) fn new(timer: Pin<&'a mut KeepAlive>, ka_dur: Duration, req_dur: Duration) -> Self { + Self { + timer, + state: TimerState::Idle, + ka_dur, + req_dur, + } + } + + pub(super) fn reset_state(&mut self) { + self.state = TimerState::Idle; + } + + pub(super) fn get(&mut self) -> Pin<&mut KeepAlive> { + self.timer.as_mut() + } + + // update timer with a given base instant value. the final deadline is calculated base on it. + pub(super) fn update(&mut self, now: tokio::time::Instant) { + let dur = match self.state { + TimerState::Idle => { + self.state = TimerState::Wait; + self.ka_dur + } + TimerState::Wait => { + self.state = TimerState::Throttle; + self.req_dur + } + TimerState::Throttle => return, + }; + self.timer.as_mut().update(now + dur) + } - // feed error to body sender and prepare for close connection. #[cold] #[inline(never)] - pub(super) fn feed_error(&mut self, e: io::Error) { - self.tx.feed_error(e); - self.decoder.set_corrupted(); + pub(super) fn map_to_err(&self) -> Error { + match self.state { + TimerState::Wait => Error::KeepAliveExpire, + TimerState::Throttle => Error::RequestTimeout, + TimerState::Idle => unreachable!(), + } } +} - // wait for service start to consume RequestBody. - pub(super) async fn wait_for_poll(&mut self) -> io::Result<()> { - // IMPORTANT: service future drop RequestBody so marker decoder to corrupted. - self.tx - .wait_for_poll() - .await - .inspect_err(|_| self.decoder.set_corrupted()) +pub(super) struct Decoder { + pub(super) decoder: TransferCoding, + pub(super) limit: usize, + pub(super) read_buf: BytesMut, + pub(super) notify: Notifier, +} + +impl Decoder { + pub(super) fn decode(&mut self) -> ChunkResult { + self.decoder.decode(&mut self.read_buf) + } + + pub(super) fn limit_check(&self) -> io::Result<()> { + if self.read_buf.len() >= self.limit { + let msg = format!( + "READ_BUF_LIMIT reached: {{ limit: {}, length: {} }}", + self.limit, + self.read_buf.len() + ); + Err(io::Error::other(msg)) + } else { + Ok(()) + } + } +} + +impl Drop for Decoder { + fn drop(&mut self) { + if self.decoder.is_eof() { + let buf = mem::take(&mut self.read_buf); + self.notify.notify(buf); + } + } +} + +pub(super) struct Notify(Rc>>); + +impl Notify { + pub(super) fn new() -> Self { + Self(Rc::new(RefCell::new(Inner { waker: None, val: None }))) + } + + pub(super) fn notifier(&mut self) -> Notifier { + Notifier(self.0.clone()) + } + + pub(super) fn wait(&mut self) -> impl Future> + '_ { + poll_fn(|cx| { + let mut inner = self.0.borrow_mut(); + if let Some(val) = inner.val.take() { + return Poll::Ready(Some(val)); + } else if Rc::strong_count(&self.0) == 1 { + return Poll::Ready(None); + } + inner.waker = Some(cx.waker().clone()); + Poll::Pending + }) + } +} + +pub(super) struct Notifier(Rc>>); + +impl Drop for Notifier { + fn drop(&mut self) { + if let Some(waker) = self.0.borrow_mut().waker.take() { + waker.wake(); + } } } + +impl Notifier { + pub(super) fn notify(&mut self, val: T) { + self.0.borrow_mut().val = Some(val); + } +} + +struct Inner { + waker: Option, + val: Option, +} diff --git a/http/src/h1/dispatcher_uring.rs b/http/src/h1/dispatcher_uring.rs index 46694a35a..37b6ddead 100644 --- a/http/src/h1/dispatcher_uring.rs +++ b/http/src/h1/dispatcher_uring.rs @@ -1,13 +1,10 @@ use core::{ - cell::RefCell, - fmt, future::poll_fn, marker::PhantomData, mem, net::SocketAddr, - ops::{Deref, DerefMut}, pin::{Pin, pin}, - task::{self, Poll, Waker, ready}, + task::{self, Poll, ready}, }; use std::{io, net::Shutdown, rc::Rc}; @@ -25,13 +22,14 @@ use crate::{ bytes::Bytes, config::HttpServiceConfig, date::DateTime, - h1::{body::RequestBody, error::Error}, http::response::Response, util::timer::{KeepAlive, Timeout}, }; use super::{ - dispatcher::{Timer, handle_error}, + body::RequestBody, + dispatcher::{Decoder, Notifier, Notify, Timer, handle_error}, + error::Error, proto::{ codec::{ChunkResult, TransferCoding}, context::Context, @@ -47,48 +45,32 @@ pub(super) struct Dispatcher<'a, Io, S, ReqB, D, const H_LIMIT: usize, const R_L timer: Timer<'a>, ctx: Context<'a, D, H_LIMIT>, service: &'a S, - read_buf: BufOwned, - write_buf: BufOwned, - notify: Notify, + read_buf: BytesMut, + write_buf: BytesMut, + notify: Notify, _phantom: PhantomData, } -#[derive(Default)] -struct BufOwned { - buf: BytesMut, -} - -impl Deref for BufOwned { - type Target = BytesMut; - - fn deref(&self) -> &Self::Target { - &self.buf +trait BufIo: Default + Sized { + fn take(&mut self) -> Self { + mem::take(self) } -} -impl DerefMut for BufOwned { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.buf + fn set(&mut self, other: Self) { + *self = other; } -} -impl BufOwned { - fn new() -> Self { - Self { buf: BytesMut::new() } - } - - fn take(&mut self) -> BytesMut { - mem::replace(&mut self.buf, BytesMut::new()) - } + async fn read(&mut self, io: &impl AsyncBufRead) -> io::Result; - fn set(&mut self, buf: BytesMut) { - self.buf = buf; - } + async fn write(&mut self, io: &impl AsyncBufWrite) -> io::Result<()>; +} +impl BufIo for BytesMut { async fn read(&mut self, io: &impl AsyncBufRead) -> io::Result { let mut buf = self.take(); let len = buf.len(); + buf.reserve(4096); let (res, buf) = io.read(buf.slice(len..)).await; @@ -127,15 +109,15 @@ where timer: Timer::new(timer, config.keep_alive_timeout, config.request_head_timeout), ctx: Context::with_addr(addr, date), service, - read_buf: BufOwned::new(), - write_buf: BufOwned::new(), + read_buf: BytesMut::new(), + write_buf: BytesMut::new(), notify: Notify::new(), _phantom: PhantomData, }; loop { if let Err(err) = dispatcher._run().await { - handle_error(&mut dispatcher.ctx, &mut *dispatcher.write_buf, err)?; + handle_error(&mut dispatcher.ctx, &mut dispatcher.write_buf, err)?; } dispatcher.write_buf.write(&*dispatcher.io).await?; @@ -167,7 +149,7 @@ where let (waiter, body) = if decoder.is_eof() { (None, RequestBody::default()) } else { - let body = Body::new( + let body = body( self.io.clone(), self.ctx.is_expect_header(), R_LIMIT, @@ -176,14 +158,14 @@ where self.notify.notifier(), ); - (Some(&mut self.notify), RequestBody::io_uring(body)) + (Some(&mut self.notify), body) }; let req = req.map(|ext| ext.map_body(|_| ReqB::from(body))); let (parts, body) = self.service.call(req).await.map_err(Error::Service)?.into_parts(); - let mut encoder = self.ctx.encode_head(parts, &body, &mut *self.write_buf)?; + let mut encoder = self.ctx.encode_head(parts, &body, &mut self.write_buf)?; // this block is necessary. ResB has to be dropped asap as it may hold ownership of // Body type which if not dropped before Notifier::notify is called would prevent @@ -192,7 +174,7 @@ where let mut body = pin!(body); loop { - let buf = &mut *self.write_buf; + let buf = &mut self.write_buf; let res = poll_fn(|cx| match body.as_mut().poll_next(cx) { Poll::Ready(res) => Poll::Ready(SelectOutput::A(res)), @@ -249,64 +231,39 @@ where } } -pub(super) struct Body(Pin>>>); - -impl Body { - fn new( - io: Rc, - is_expect: bool, - limit: usize, - decoder: TransferCoding, - read_buf: BufOwned, - notify: Notifier, - ) -> Self - where - Io: AsyncBufRead + AsyncBufWrite + 'static, - { - let body = BodyInner { - io, - decoder: Decoder { - decoder, - limit, - read_buf, - notify, - }, - }; - - let state = if is_expect { - State::ExpectWrite { - fut: async { - let (res, _) = write_all(&*body.io, CONTINUE_BYTES).await; - res.map(|_| body) - }, - } - } else { - State::Body { body } - }; - - Self(Box::pin(BodyReader { chunk_read, state })) - } -} - -impl fmt::Debug for Body { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str("Body") - } -} - -impl Clone for Body { - fn clone(&self) -> Self { - unimplemented!("rework body module so it does not force Clone on Body type.") - } -} +fn body( + io: Rc, + is_expect: bool, + limit: usize, + decoder: TransferCoding, + read_buf: BytesMut, + notify: Notifier, +) -> RequestBody +where + Io: AsyncBufRead + AsyncBufWrite + 'static, +{ + let body = BodyInner { + io, + decoder: Decoder { + decoder, + limit, + read_buf, + notify, + }, + }; -impl Stream for Body { - type Item = io::Result; + let state = if is_expect { + State::ExpectWrite { + fut: async { + let (res, _) = write_all(&*body.io, CONTINUE_BYTES).await; + res.map(|_| body) + }, + } + } else { + State::Body { body } + }; - #[inline] - fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { - Pin::new(&mut self.get_mut().0).poll_next(cx) - } + RequestBody::new(BodyReader { chunk_read, state }) } pin_project! { @@ -364,22 +321,13 @@ where loop { match this.state.as_mut().project() { StateProj::Body { body } => { - match body.decoder.decoder.decode(&mut body.decoder.read_buf) { + match body.decoder.decode() { ChunkResult::Ok(bytes) => return Poll::Ready(Some(Ok(bytes))), ChunkResult::Err(e) => return Poll::Ready(Some(Err(e))), - ChunkResult::InsufficientData => {} + ChunkResult::InsufficientData => body.decoder.limit_check()?, _ => return Poll::Ready(None), } - if body.decoder.read_buf.len() >= body.decoder.limit { - let msg = format!( - "READ_BUF_LIMIT reached: {{ limit: {}, length: {} }}", - body.decoder.limit, - body.decoder.read_buf.len() - ); - return Poll::Ready(Some(Err(io::Error::other(msg)))); - } - let StateProjReplace::Body { body } = this.state.as_mut().project_replace(State::None) else { unreachable!() }; @@ -406,65 +354,3 @@ where } } } - -struct Decoder { - decoder: TransferCoding, - limit: usize, - read_buf: BufOwned, - notify: Notifier, -} - -impl Drop for Decoder { - fn drop(&mut self) { - if self.decoder.is_eof() { - let buf = mem::take(&mut self.read_buf); - self.notify.notify(buf); - } - } -} - -struct Notify(Rc>>); - -impl Notify { - fn new() -> Self { - Self(Rc::new(RefCell::new(Inner { waker: None, val: None }))) - } - - fn notifier(&mut self) -> Notifier { - Notifier(self.0.clone()) - } - - fn wait(&mut self) -> impl Future> + '_ { - poll_fn(|cx| { - let mut inner = self.0.borrow_mut(); - if let Some(val) = inner.val.take() { - return Poll::Ready(Some(val)); - } else if Rc::strong_count(&self.0) == 1 { - return Poll::Ready(None); - } - inner.waker = Some(cx.waker().clone()); - Poll::Pending - }) - } -} - -struct Notifier(Rc>>); - -impl Drop for Notifier { - fn drop(&mut self) { - if let Some(waker) = self.0.borrow_mut().waker.take() { - waker.wake(); - } - } -} - -impl Notifier { - fn notify(&mut self, val: T) { - self.0.borrow_mut().val = Some(val); - } -} - -struct Inner { - waker: Option, - val: Option, -} diff --git a/http/src/h1/service.rs b/http/src/h1/service.rs index 91c7cf644..a31d9f23a 100644 --- a/http/src/h1/service.rs +++ b/http/src/h1/service.rs @@ -23,7 +23,8 @@ where S: Service>, Response = Response>, A: Service, St: AsyncIo, - A::Response: AsyncIo, + A::Response: AsyncIo + 'static, + for<'i> &'i A::Response: AsyncIo, B: Stream>, HttpServiceError: From, { @@ -34,14 +35,14 @@ where // at this stage keep-alive timer is used to tracks tls accept timeout. let mut timer = pin!(self.keep_alive()); - let mut io = self + let io = self .tls_acceptor .call(io) .timeout(timer.as_mut()) .await .map_err(|_| HttpServiceError::Timeout(TimeoutError::TlsAccept))??; - super::dispatcher::run(&mut io, addr, timer, self.config, &self.service, self.date.get()) + super::dispatcher::run(io, addr, timer, self.config, &self.service, self.date.get()) .await .map_err(Into::into) } diff --git a/http/src/service.rs b/http/src/service.rs index 703a814bf..eded0deda 100644 --- a/http/src/service.rs +++ b/http/src/service.rs @@ -74,7 +74,8 @@ impl>, Response = Response>, A: Service, - A::Response: AsyncIo + AsVersion, + A::Response: AsyncIo + AsVersion + 'static, + for<'i> &'i A::Response: AsyncIo, HttpServiceError: From, S::Error: fmt::Debug, ResB: Stream>, @@ -114,7 +115,7 @@ where match version { #[cfg(feature = "http1")] super::http::Version::HTTP_11 | super::http::Version::HTTP_10 => super::h1::dispatcher::run( - &mut _tls_stream, + _tls_stream, _addr, timer.as_mut(), self.config, @@ -159,10 +160,11 @@ where #[cfg(feature = "http1")] { - let mut io = xitca_io::net::UnixStream::from_std(_io).expect("TODO: handle io error"); + let io = xitca_io::net::UnixStream::from_std(_io).expect("TODO: handle io error"); - super::h1::dispatcher::run( - &mut io, + // TODO: this is a rust compiler regression where the function fail to infer the stream type. remove type annotation when it's fixed + super::h1::dispatcher::run::( + io, crate::unspecified_socket_addr(), timer.as_mut(), self.config, diff --git a/io/CHANGES.md b/io/CHANGES.md index 47b16da1f..08ff9c429 100644 --- a/io/CHANGES.md +++ b/io/CHANGES.md @@ -1,4 +1,6 @@ # unreleased 0.4.2 +## Add +- add `AsyncIo` impl to `&TcpStream` and `&UnixStream` ## Fix - relax trait bound of `io_uring::write_all` diff --git a/io/src/net.rs b/io/src/net.rs index f7d335da5..d5bbe4296 100644 --- a/io/src/net.rs +++ b/io/src/net.rs @@ -51,6 +51,36 @@ macro_rules! default_aio_impl { } } + impl crate::io::AsyncIo for &$ty { + #[inline] + async fn ready(&mut self, interest: crate::io::Interest) -> ::std::io::Result { + self.0.ready(interest).await + } + + fn poll_ready( + &mut self, + interest: crate::io::Interest, + cx: &mut ::core::task::Context<'_>, + ) -> ::core::task::Poll<::std::io::Result> { + match interest { + crate::io::Interest::READABLE => self.0.poll_read_ready(cx).map_ok(|_| crate::io::Ready::READABLE), + crate::io::Interest::WRITABLE => self.0.poll_write_ready(cx).map_ok(|_| crate::io::Ready::WRITABLE), + _ => unimplemented!("tokio does not support poll_ready for BOTH read and write ready"), + } + } + + fn is_vectored_write(&self) -> bool { + crate::io::AsyncWrite::is_write_vectored(&self.0) + } + + fn poll_shutdown( + self: ::core::pin::Pin<&mut Self>, + _: &mut ::core::task::Context<'_>, + ) -> ::core::task::Poll<::std::io::Result<()>> { + unimplemented!("poll_shutdown can not be performed from non exclusive reference of IO type"); + } + } + impl ::std::io::Read for $ty { #[inline] fn read(&mut self, buf: &mut [u8]) -> ::std::io::Result {