From 2a1f6a4540d07f529a7cc712fefa880ea9c0447b Mon Sep 17 00:00:00 2001 From: fakeshadow <24548779@qq.com> Date: Mon, 29 Dec 2025 16:40:21 +0800 Subject: [PATCH 1/2] experiment async copy in with backpressure --- postgres/src/client.rs | 18 ++- postgres/src/driver/codec.rs | 9 +- postgres/src/driver/codec/encode.rs | 9 +- postgres/src/driver/generic.rs | 220 +++++++--------------------- 4 files changed, 76 insertions(+), 180 deletions(-) diff --git a/postgres/src/client.rs b/postgres/src/client.rs index 39bf6fd96..f73c7a360 100644 --- a/postgres/src/client.rs +++ b/postgres/src/client.rs @@ -5,16 +5,14 @@ use std::{ sync::{Arc, Mutex}, }; +use xitca_io::bytes::BytesMut; use xitca_unsafe_collection::no_hash::NoHashBuilder; use super::{ copy::{CopyIn, CopyOut}, driver::{ DriverTx, - codec::{ - Response, - encode::{self, Encode}, - }, + codec::{Response, encode::Encode}, }, error::Error, query::Query, @@ -131,6 +129,7 @@ where /// [`Driver`]: crate::driver::Driver pub struct Client { pub(crate) tx: DriverTx, + pub(crate) buf: Mutex, pub(crate) cache: Box, } @@ -260,6 +259,7 @@ impl Client { pub(crate) fn new(tx: DriverTx, session: Session) -> Self { Self { tx, + buf: Mutex::new(BytesMut::with_capacity(4096)), cache: Box::new(ClientCache { session, type_info: Mutex::new(CachedTypeInfo { @@ -286,7 +286,15 @@ impl Query for Client { where S: Encode, { - encode::send_encode_query(&self.tx, stmt) + let (res1, buf) = { + let mut buf = self.buf.lock().unwrap(); + let len = buf.len(); + let res1 = stmt.encode(&mut buf).inspect_err(|_| buf.truncate(len))?; + (res1, buf.split()) + }; + + let res2 = self.tx.send(buf)?; + Ok((res1, res2)) } } diff --git a/postgres/src/driver/codec.rs b/postgres/src/driver/codec.rs index 9198c68a0..ad47723d4 100644 --- a/postgres/src/driver/codec.rs +++ b/postgres/src/driver/codec.rs @@ -15,12 +15,15 @@ use crate::{ types::BorrowToSql, }; -use super::DriverTx; +pub struct Request { + pub(super) tx: ResponseSender, + pub(super) buf: BytesMut, +} -pub(super) fn request_pair() -> (ResponseSender, Response) { +pub(super) fn request_pair(buf: BytesMut) -> (Request, Response) { let (tx, rx) = unbounded_channel(); ( - tx, + Request { tx, buf }, Response { rx, buf: BytesMut::new(), diff --git a/postgres/src/driver/codec/encode.rs b/postgres/src/driver/codec/encode.rs index fc1a39cab..5c330ac90 100644 --- a/postgres/src/driver/codec/encode.rs +++ b/postgres/src/driver/codec/encode.rs @@ -15,7 +15,7 @@ use crate::{ }; use super::{ - AsParams, DriverTx, Response, + AsParams, response::{ IntoResponse, IntoRowStreamGuard, NoOpIntoRowStream, StatementCreateResponse, StatementCreateResponseBlocking, }, @@ -236,13 +236,6 @@ impl<'s> Encode for PortalQuery<'s> { } } -pub(crate) fn send_encode_query(tx: &DriverTx, stmt: S) -> Result<(S::Output, Response), Error> -where - S: Encode, -{ - tx.send(|buf| stmt.encode(buf)) -} - fn encode_bind

(stmt: &str, types: &[Type], params: P, portal: &str, buf: &mut BytesMut) -> Result<(), Error> where P: AsParams, diff --git a/postgres/src/driver/generic.rs b/postgres/src/driver/generic.rs index 6b7cb7869..c743a4a6a 100644 --- a/postgres/src/driver/generic.rs +++ b/postgres/src/driver/generic.rs @@ -1,17 +1,12 @@ use core::{ future::{Future, poll_fn}, - ops::Deref, pin::Pin, - task::{Poll, Waker}, }; -use std::{ - collections::VecDeque, - io, - sync::{Arc, Mutex}, -}; +use std::{collections::VecDeque, io}; use postgres_protocol::message::{backend, frontend}; +use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel}; use xitca_io::{ bytes::{Buf, BytesMut}, io::{AsyncIo, Interest}, @@ -23,141 +18,54 @@ use crate::{ iter::AsyncLendingIterator, }; -use super::codec::{Response, ResponseMessage, ResponseSender}; +use super::codec::{Request, Response, ResponseMessage, ResponseSender}; type PagedBytesMut = xitca_unsafe_collection::bytes::PagedBytesMut<4096>; const INTEREST_READ_WRITE: Interest = Interest::READABLE.add(Interest::WRITABLE); -pub(crate) struct DriverTx(Arc); +pub(crate) struct DriverTx(UnboundedSender); impl Drop for DriverTx { fn drop(&mut self) { - let mut state = self.0.guarded.lock().unwrap(); - frontend::terminate(&mut state.buf); - state.closed = true; - state.wake(); + let mut buf = BytesMut::new(); + frontend::terminate(&mut buf); + let (tx, _) = super::codec::request_pair(buf); + let _ = self.0.send(tx); } } impl DriverTx { pub(crate) fn is_closed(&self) -> bool { - Arc::strong_count(&self.0) == 1 + self.0.is_closed() } pub(crate) fn send_one_way(&self, func: F) -> Result<(), Error> where F: FnOnce(&mut BytesMut) -> Result<(), Error>, { - self._send(func, |_| {})?; - Ok(()) - } - - pub(crate) fn send(&self, func: F) -> Result<(O, Response), Error> - where - F: FnOnce(&mut BytesMut) -> Result, - { - self._send(func, |inner| { - let (tx, rx) = super::codec::request_pair(); - inner.res.push_back(tx); - rx - }) - } - - fn _send(&self, func: F, on_send: F2) -> Result<(O, T), Error> - where - F: FnOnce(&mut BytesMut) -> Result, - F2: FnOnce(&mut State) -> T, - { - let mut inner = self.0.guarded.lock().unwrap(); - - if inner.closed { - return Err(DriverDown.into()); - } - - let len = inner.buf.len(); - - let o = func(&mut inner.buf).inspect_err(|_| inner.buf.truncate(len))?; - let t = on_send(&mut inner); - - inner.wake(); - - Ok((o, t)) - } -} - -pub(crate) struct DriverRx(Arc); - -impl Deref for DriverRx { - type Target = SharedState; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -// in case driver is dropped without closing the shared state -impl Drop for DriverRx { - fn drop(&mut self) { - self.guarded.lock().unwrap().closed = true; - } -} - -pub(crate) struct SharedState { - pub(super) guarded: Mutex, -} - -impl SharedState { - pub(super) fn wait(&self) -> impl Future + use<'_> { - poll_fn(|cx| { - let mut inner = self.guarded.lock().unwrap(); - if !inner.buf.is_empty() { - Poll::Ready(WaitState::WantWrite) - } else if inner.closed { - Poll::Ready(WaitState::WantClose) - } else { - inner.register(cx.waker()); - Poll::Pending - } - }) - } -} - -pub(super) enum WaitState { - WantWrite, - WantClose, -} - -pub(super) struct State { - pub(super) closed: bool, - pub(super) buf: BytesMut, - pub(super) res: VecDeque, - pub(super) waker: Option, -} - -impl State { - fn register(&mut self, waker: &Waker) { - self.waker = Some(waker.clone()); + todo!() } - fn wake(&mut self) { - if let Some(waker) = self.waker.take() { - waker.wake(); - } + pub(crate) fn send(&self, buf: BytesMut) -> Result { + let (req, res) = super::codec::request_pair(buf); + self.0.send(req).map_err(|_| DriverDown)?; + Ok(res) } } pub struct GenericDriver { pub(super) io: Io, + pub(super) req: VecDeque, pub(super) read_buf: PagedBytesMut, - pub(super) rx: DriverRx, + pub(super) rx: UnboundedReceiver, read_state: ReadState, write_state: WriteState, } enum WriteState { Waiting, - WantWrite, + WantWrite(BytesMut), WantFlush, Closed(Option), } @@ -172,30 +80,23 @@ where Io: AsyncIo + Send, { pub(crate) fn new(io: Io) -> (Self, DriverTx) { - let state = Arc::new(SharedState { - guarded: Mutex::new(State { - closed: false, - buf: BytesMut::new(), - res: VecDeque::new(), - waker: None, - }), - }); + let (tx, rx) = unbounded_channel(); ( Self { io, - rx: DriverRx(state.clone()), + rx, + req: VecDeque::new(), read_buf: PagedBytesMut::new(), read_state: ReadState::WantRead, write_state: WriteState::Waiting, }, - DriverTx(state), + DriverTx(tx), ) } pub(crate) async fn send(&mut self, msg: BytesMut) -> Result<(), Error> { - self.rx.guarded.lock().unwrap().buf.extend_from_slice(&msg); - self.write_state = WriteState::WantWrite; + self.write_state = WriteState::WantWrite(msg); loop { self.try_write()?; if matches!(self.write_state, WriteState::Waiting) { @@ -211,24 +112,24 @@ where async fn _try_next(&mut self) -> Result, Error> { loop { - if let Some(msg) = self.rx.try_decode(self.read_buf.get_mut())? { + if let Some(msg) = try_decode(&mut self.req, &mut self.read_buf.get_mut())? { return Ok(Some(msg)); } let res = match (&mut self.read_state, &mut self.write_state) { (ReadState::WantRead, WriteState::Waiting) => { - self.io.ready(Interest::READABLE).select(self.rx.wait()).await + self.io.ready(Interest::READABLE).select(self.rx.recv()).await } - (ReadState::WantRead, WriteState::WantWrite | WriteState::WantFlush) => { + (ReadState::WantRead, WriteState::WantWrite(_) | WriteState::WantFlush) => { SelectOutput::A(self.io.ready(INTEREST_READ_WRITE).await) } (ReadState::WantRead, WriteState::Closed(_)) => { SelectOutput::A(self.io.ready(Interest::READABLE).await) } - (ReadState::Closed(_), WriteState::WantFlush | WriteState::WantWrite) => { + (ReadState::Closed(_), WriteState::WantFlush | WriteState::WantWrite(_)) => { SelectOutput::A(self.io.ready(Interest::WRITABLE).await) } - (ReadState::Closed(_), WriteState::Waiting) => SelectOutput::B(self.rx.wait().await), + (ReadState::Closed(_), WriteState::Waiting) => SelectOutput::B(self.rx.recv().await), (ReadState::Closed(None), WriteState::Closed(None)) => { poll_fn(|cx| Pin::new(&mut self.io).poll_shutdown(cx)).await?; return Ok(None); @@ -255,8 +156,11 @@ where } } } - SelectOutput::B(WaitState::WantWrite) => self.write_state = WriteState::WantWrite, - SelectOutput::B(WaitState::WantClose) => self.write_state = WriteState::Closed(None), + SelectOutput::B(Some(msg)) => { + self.write_state = WriteState::WantWrite(msg.buf); + self.req.push_back(msg.tx); + } + SelectOutput::B(None) => self.write_state = WriteState::Closed(None), } } } @@ -292,31 +196,28 @@ where fn try_write(&mut self) -> io::Result<()> { debug_assert!( - matches!(self.write_state, WriteState::WantWrite | WriteState::WantFlush), + matches!(self.write_state, WriteState::WantWrite(_) | WriteState::WantFlush), "try_write must not be called when WriteState is Wait or Closed" ); - if matches!(self.write_state, WriteState::WantWrite) { - let mut inner = self.rx.guarded.lock().unwrap(); - + if let WriteState::WantWrite(ref mut buf) = self.write_state { let mut written = 0; - loop { - match io::Write::write(&mut self.io, &inner.buf[written..]) { + match io::Write::write(&mut self.io, &buf[written..]) { Ok(0) => { - inner.buf.advance(written); + buf.advance(written); return Err(io::ErrorKind::WriteZero.into()); } Ok(n) => { written += n; - if written == inner.buf.len() { - inner.buf.clear(); + if written == buf.len() { + buf.clear(); self.write_state = WriteState::WantFlush; break; } } Err(e) => { - inner.buf.advance(written); + buf.advance(written); return if matches!(e.kind(), io::ErrorKind::WouldBlock) { Ok(()) } else { @@ -345,40 +246,31 @@ where #[cold] #[inline(never)] fn on_write_err(&mut self, e: io::Error) { - { - // when write error occur the driver would go into half close state(read only). - // clearing write_buf would drop all pending requests in it and hint the driver - // no future Interest::WRITABLE should be passed to AsyncIo::ready method. - let mut inner = self.rx.guarded.lock().unwrap(); - inner.buf.clear(); - // close shared state early so driver tx can observe the shutdown in first hand - inner.closed = true; - } + // when write error occur the driver would go into half close state(read only). + // clearing all pending requests in channel to notify task waiting for response + self.rx.close(); self.write_state = WriteState::Closed(Some(e)); } } -impl DriverRx { - pub(super) fn try_decode(&self, read_buf: &mut BytesMut) -> Result, Error> { - let mut guard = None; - - while let Some(res) = ResponseMessage::try_from_buf(read_buf)? { - match res { - ResponseMessage::Normal(mut msg) => { - // lock the shared state only when needed and keep the lock around a bit for possible multiple messages - let inner = guard.get_or_insert_with(|| self.guarded.lock().unwrap()); - let res = inner.res.pop_front().ok_or_else(|| msg.parse_error())?; - let _ = res.send(msg.buf); - if !msg.complete { - inner.res.push_front(res); - } +pub(super) fn try_decode( + queue: &mut VecDeque, + read_buf: &mut BytesMut, +) -> Result, Error> { + while let Some(res) = ResponseMessage::try_from_buf(read_buf)? { + match res { + ResponseMessage::Normal(mut msg) => { + let req = queue.pop_front().ok_or_else(|| msg.parse_error())?; + let _ = req.send(msg.buf); + if !msg.complete { + queue.push_front(req); } - ResponseMessage::Async(msg) => return Ok(Some(msg)), } + ResponseMessage::Async(msg) => return Ok(Some(msg)), } - - Ok(None) } + + Ok(None) } impl AsyncLendingIterator for GenericDriver From 229865a3642b01d3ac279b480013406772cb342c Mon Sep 17 00:00:00 2001 From: fakeshadow <24548779@qq.com> Date: Mon, 29 Dec 2025 16:53:59 +0800 Subject: [PATCH 2/2] maintain a separate write buffer in driver --- postgres/src/driver/generic.rs | 30 +++++++++++++++++------------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/postgres/src/driver/generic.rs b/postgres/src/driver/generic.rs index c743a4a6a..91785c660 100644 --- a/postgres/src/driver/generic.rs +++ b/postgres/src/driver/generic.rs @@ -58,6 +58,7 @@ pub struct GenericDriver { pub(super) io: Io, pub(super) req: VecDeque, pub(super) read_buf: PagedBytesMut, + pub(super) write_buf: BytesMut, pub(super) rx: UnboundedReceiver, read_state: ReadState, write_state: WriteState, @@ -65,7 +66,7 @@ pub struct GenericDriver { enum WriteState { Waiting, - WantWrite(BytesMut), + WantWrite, WantFlush, Closed(Option), } @@ -88,6 +89,7 @@ where rx, req: VecDeque::new(), read_buf: PagedBytesMut::new(), + write_buf: BytesMut::new(), read_state: ReadState::WantRead, write_state: WriteState::Waiting, }, @@ -96,7 +98,8 @@ where } pub(crate) async fn send(&mut self, msg: BytesMut) -> Result<(), Error> { - self.write_state = WriteState::WantWrite(msg); + self.write_buf.unsplit(msg); + self.write_state = WriteState::WantWrite; loop { self.try_write()?; if matches!(self.write_state, WriteState::Waiting) { @@ -120,13 +123,13 @@ where (ReadState::WantRead, WriteState::Waiting) => { self.io.ready(Interest::READABLE).select(self.rx.recv()).await } - (ReadState::WantRead, WriteState::WantWrite(_) | WriteState::WantFlush) => { - SelectOutput::A(self.io.ready(INTEREST_READ_WRITE).await) + (ReadState::WantRead, WriteState::WantWrite | WriteState::WantFlush) => { + self.io.ready(INTEREST_READ_WRITE).select(self.rx.recv()).await } (ReadState::WantRead, WriteState::Closed(_)) => { SelectOutput::A(self.io.ready(Interest::READABLE).await) } - (ReadState::Closed(_), WriteState::WantFlush | WriteState::WantWrite(_)) => { + (ReadState::Closed(_), WriteState::WantFlush | WriteState::WantWrite) => { SelectOutput::A(self.io.ready(Interest::WRITABLE).await) } (ReadState::Closed(_), WriteState::Waiting) => SelectOutput::B(self.rx.recv().await), @@ -157,7 +160,8 @@ where } } SelectOutput::B(Some(msg)) => { - self.write_state = WriteState::WantWrite(msg.buf); + self.write_buf.unsplit(msg.buf); + self.write_state = WriteState::WantWrite; self.req.push_back(msg.tx); } SelectOutput::B(None) => self.write_state = WriteState::Closed(None), @@ -196,28 +200,28 @@ where fn try_write(&mut self) -> io::Result<()> { debug_assert!( - matches!(self.write_state, WriteState::WantWrite(_) | WriteState::WantFlush), + matches!(self.write_state, WriteState::WantWrite | WriteState::WantFlush), "try_write must not be called when WriteState is Wait or Closed" ); - if let WriteState::WantWrite(ref mut buf) = self.write_state { + if matches!(self.write_state, WriteState::WantWrite) { let mut written = 0; loop { - match io::Write::write(&mut self.io, &buf[written..]) { + match io::Write::write(&mut self.io, &self.write_buf[written..]) { Ok(0) => { - buf.advance(written); + self.write_buf.advance(written); return Err(io::ErrorKind::WriteZero.into()); } Ok(n) => { written += n; - if written == buf.len() { - buf.clear(); + if written == self.write_buf.len() { + self.write_buf.clear(); self.write_state = WriteState::WantFlush; break; } } Err(e) => { - buf.advance(written); + self.write_buf.advance(written); return if matches!(e.kind(), io::ErrorKind::WouldBlock) { Ok(()) } else {