diff --git a/postgres/src/client.rs b/postgres/src/client.rs index 39bf6fd96..f73c7a360 100644 --- a/postgres/src/client.rs +++ b/postgres/src/client.rs @@ -5,16 +5,14 @@ use std::{ sync::{Arc, Mutex}, }; +use xitca_io::bytes::BytesMut; use xitca_unsafe_collection::no_hash::NoHashBuilder; use super::{ copy::{CopyIn, CopyOut}, driver::{ DriverTx, - codec::{ - Response, - encode::{self, Encode}, - }, + codec::{Response, encode::Encode}, }, error::Error, query::Query, @@ -131,6 +129,7 @@ where /// [`Driver`]: crate::driver::Driver pub struct Client { pub(crate) tx: DriverTx, + pub(crate) buf: Mutex, pub(crate) cache: Box, } @@ -260,6 +259,7 @@ impl Client { pub(crate) fn new(tx: DriverTx, session: Session) -> Self { Self { tx, + buf: Mutex::new(BytesMut::with_capacity(4096)), cache: Box::new(ClientCache { session, type_info: Mutex::new(CachedTypeInfo { @@ -286,7 +286,15 @@ impl Query for Client { where S: Encode, { - encode::send_encode_query(&self.tx, stmt) + let (res1, buf) = { + let mut buf = self.buf.lock().unwrap(); + let len = buf.len(); + let res1 = stmt.encode(&mut buf).inspect_err(|_| buf.truncate(len))?; + (res1, buf.split()) + }; + + let res2 = self.tx.send(buf)?; + Ok((res1, res2)) } } diff --git a/postgres/src/driver/codec.rs b/postgres/src/driver/codec.rs index 9198c68a0..ad47723d4 100644 --- a/postgres/src/driver/codec.rs +++ b/postgres/src/driver/codec.rs @@ -15,12 +15,15 @@ use crate::{ types::BorrowToSql, }; -use super::DriverTx; +pub struct Request { + pub(super) tx: ResponseSender, + pub(super) buf: BytesMut, +} -pub(super) fn request_pair() -> (ResponseSender, Response) { +pub(super) fn request_pair(buf: BytesMut) -> (Request, Response) { let (tx, rx) = unbounded_channel(); ( - tx, + Request { tx, buf }, Response { rx, buf: BytesMut::new(), diff --git a/postgres/src/driver/codec/encode.rs b/postgres/src/driver/codec/encode.rs index fc1a39cab..5c330ac90 100644 --- a/postgres/src/driver/codec/encode.rs +++ b/postgres/src/driver/codec/encode.rs @@ -15,7 +15,7 @@ use crate::{ }; use super::{ - AsParams, DriverTx, Response, + AsParams, response::{ IntoResponse, IntoRowStreamGuard, NoOpIntoRowStream, StatementCreateResponse, StatementCreateResponseBlocking, }, @@ -236,13 +236,6 @@ impl<'s> Encode for PortalQuery<'s> { } } -pub(crate) fn send_encode_query(tx: &DriverTx, stmt: S) -> Result<(S::Output, Response), Error> -where - S: Encode, -{ - tx.send(|buf| stmt.encode(buf)) -} - fn encode_bind

(stmt: &str, types: &[Type], params: P, portal: &str, buf: &mut BytesMut) -> Result<(), Error> where P: AsParams, diff --git a/postgres/src/driver/generic.rs b/postgres/src/driver/generic.rs index 6b7cb7869..91785c660 100644 --- a/postgres/src/driver/generic.rs +++ b/postgres/src/driver/generic.rs @@ -1,17 +1,12 @@ use core::{ future::{Future, poll_fn}, - ops::Deref, pin::Pin, - task::{Poll, Waker}, }; -use std::{ - collections::VecDeque, - io, - sync::{Arc, Mutex}, -}; +use std::{collections::VecDeque, io}; use postgres_protocol::message::{backend, frontend}; +use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel}; use xitca_io::{ bytes::{Buf, BytesMut}, io::{AsyncIo, Interest}, @@ -23,134 +18,48 @@ use crate::{ iter::AsyncLendingIterator, }; -use super::codec::{Response, ResponseMessage, ResponseSender}; +use super::codec::{Request, Response, ResponseMessage, ResponseSender}; type PagedBytesMut = xitca_unsafe_collection::bytes::PagedBytesMut<4096>; const INTEREST_READ_WRITE: Interest = Interest::READABLE.add(Interest::WRITABLE); -pub(crate) struct DriverTx(Arc); +pub(crate) struct DriverTx(UnboundedSender); impl Drop for DriverTx { fn drop(&mut self) { - let mut state = self.0.guarded.lock().unwrap(); - frontend::terminate(&mut state.buf); - state.closed = true; - state.wake(); + let mut buf = BytesMut::new(); + frontend::terminate(&mut buf); + let (tx, _) = super::codec::request_pair(buf); + let _ = self.0.send(tx); } } impl DriverTx { pub(crate) fn is_closed(&self) -> bool { - Arc::strong_count(&self.0) == 1 + self.0.is_closed() } pub(crate) fn send_one_way(&self, func: F) -> Result<(), Error> where F: FnOnce(&mut BytesMut) -> Result<(), Error>, { - self._send(func, |_| {})?; - Ok(()) + todo!() } - pub(crate) fn send(&self, func: F) -> Result<(O, Response), Error> - where - F: FnOnce(&mut BytesMut) -> Result, - { - self._send(func, |inner| { - let (tx, rx) = super::codec::request_pair(); - inner.res.push_back(tx); - rx - }) - } - - fn _send(&self, func: F, on_send: F2) -> Result<(O, T), Error> - where - F: FnOnce(&mut BytesMut) -> Result, - F2: FnOnce(&mut State) -> T, - { - let mut inner = self.0.guarded.lock().unwrap(); - - if inner.closed { - return Err(DriverDown.into()); - } - - let len = inner.buf.len(); - - let o = func(&mut inner.buf).inspect_err(|_| inner.buf.truncate(len))?; - let t = on_send(&mut inner); - - inner.wake(); - - Ok((o, t)) - } -} - -pub(crate) struct DriverRx(Arc); - -impl Deref for DriverRx { - type Target = SharedState; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -// in case driver is dropped without closing the shared state -impl Drop for DriverRx { - fn drop(&mut self) { - self.guarded.lock().unwrap().closed = true; - } -} - -pub(crate) struct SharedState { - pub(super) guarded: Mutex, -} - -impl SharedState { - pub(super) fn wait(&self) -> impl Future + use<'_> { - poll_fn(|cx| { - let mut inner = self.guarded.lock().unwrap(); - if !inner.buf.is_empty() { - Poll::Ready(WaitState::WantWrite) - } else if inner.closed { - Poll::Ready(WaitState::WantClose) - } else { - inner.register(cx.waker()); - Poll::Pending - } - }) - } -} - -pub(super) enum WaitState { - WantWrite, - WantClose, -} - -pub(super) struct State { - pub(super) closed: bool, - pub(super) buf: BytesMut, - pub(super) res: VecDeque, - pub(super) waker: Option, -} - -impl State { - fn register(&mut self, waker: &Waker) { - self.waker = Some(waker.clone()); - } - - fn wake(&mut self) { - if let Some(waker) = self.waker.take() { - waker.wake(); - } + pub(crate) fn send(&self, buf: BytesMut) -> Result { + let (req, res) = super::codec::request_pair(buf); + self.0.send(req).map_err(|_| DriverDown)?; + Ok(res) } } pub struct GenericDriver { pub(super) io: Io, + pub(super) req: VecDeque, pub(super) read_buf: PagedBytesMut, - pub(super) rx: DriverRx, + pub(super) write_buf: BytesMut, + pub(super) rx: UnboundedReceiver, read_state: ReadState, write_state: WriteState, } @@ -172,29 +81,24 @@ where Io: AsyncIo + Send, { pub(crate) fn new(io: Io) -> (Self, DriverTx) { - let state = Arc::new(SharedState { - guarded: Mutex::new(State { - closed: false, - buf: BytesMut::new(), - res: VecDeque::new(), - waker: None, - }), - }); + let (tx, rx) = unbounded_channel(); ( Self { io, - rx: DriverRx(state.clone()), + rx, + req: VecDeque::new(), read_buf: PagedBytesMut::new(), + write_buf: BytesMut::new(), read_state: ReadState::WantRead, write_state: WriteState::Waiting, }, - DriverTx(state), + DriverTx(tx), ) } pub(crate) async fn send(&mut self, msg: BytesMut) -> Result<(), Error> { - self.rx.guarded.lock().unwrap().buf.extend_from_slice(&msg); + self.write_buf.unsplit(msg); self.write_state = WriteState::WantWrite; loop { self.try_write()?; @@ -211,16 +115,16 @@ where async fn _try_next(&mut self) -> Result, Error> { loop { - if let Some(msg) = self.rx.try_decode(self.read_buf.get_mut())? { + if let Some(msg) = try_decode(&mut self.req, &mut self.read_buf.get_mut())? { return Ok(Some(msg)); } let res = match (&mut self.read_state, &mut self.write_state) { (ReadState::WantRead, WriteState::Waiting) => { - self.io.ready(Interest::READABLE).select(self.rx.wait()).await + self.io.ready(Interest::READABLE).select(self.rx.recv()).await } (ReadState::WantRead, WriteState::WantWrite | WriteState::WantFlush) => { - SelectOutput::A(self.io.ready(INTEREST_READ_WRITE).await) + self.io.ready(INTEREST_READ_WRITE).select(self.rx.recv()).await } (ReadState::WantRead, WriteState::Closed(_)) => { SelectOutput::A(self.io.ready(Interest::READABLE).await) @@ -228,7 +132,7 @@ where (ReadState::Closed(_), WriteState::WantFlush | WriteState::WantWrite) => { SelectOutput::A(self.io.ready(Interest::WRITABLE).await) } - (ReadState::Closed(_), WriteState::Waiting) => SelectOutput::B(self.rx.wait().await), + (ReadState::Closed(_), WriteState::Waiting) => SelectOutput::B(self.rx.recv().await), (ReadState::Closed(None), WriteState::Closed(None)) => { poll_fn(|cx| Pin::new(&mut self.io).poll_shutdown(cx)).await?; return Ok(None); @@ -255,8 +159,12 @@ where } } } - SelectOutput::B(WaitState::WantWrite) => self.write_state = WriteState::WantWrite, - SelectOutput::B(WaitState::WantClose) => self.write_state = WriteState::Closed(None), + SelectOutput::B(Some(msg)) => { + self.write_buf.unsplit(msg.buf); + self.write_state = WriteState::WantWrite; + self.req.push_back(msg.tx); + } + SelectOutput::B(None) => self.write_state = WriteState::Closed(None), } } } @@ -297,26 +205,23 @@ where ); if matches!(self.write_state, WriteState::WantWrite) { - let mut inner = self.rx.guarded.lock().unwrap(); - let mut written = 0; - loop { - match io::Write::write(&mut self.io, &inner.buf[written..]) { + match io::Write::write(&mut self.io, &self.write_buf[written..]) { Ok(0) => { - inner.buf.advance(written); + self.write_buf.advance(written); return Err(io::ErrorKind::WriteZero.into()); } Ok(n) => { written += n; - if written == inner.buf.len() { - inner.buf.clear(); + if written == self.write_buf.len() { + self.write_buf.clear(); self.write_state = WriteState::WantFlush; break; } } Err(e) => { - inner.buf.advance(written); + self.write_buf.advance(written); return if matches!(e.kind(), io::ErrorKind::WouldBlock) { Ok(()) } else { @@ -345,40 +250,31 @@ where #[cold] #[inline(never)] fn on_write_err(&mut self, e: io::Error) { - { - // when write error occur the driver would go into half close state(read only). - // clearing write_buf would drop all pending requests in it and hint the driver - // no future Interest::WRITABLE should be passed to AsyncIo::ready method. - let mut inner = self.rx.guarded.lock().unwrap(); - inner.buf.clear(); - // close shared state early so driver tx can observe the shutdown in first hand - inner.closed = true; - } + // when write error occur the driver would go into half close state(read only). + // clearing all pending requests in channel to notify task waiting for response + self.rx.close(); self.write_state = WriteState::Closed(Some(e)); } } -impl DriverRx { - pub(super) fn try_decode(&self, read_buf: &mut BytesMut) -> Result, Error> { - let mut guard = None; - - while let Some(res) = ResponseMessage::try_from_buf(read_buf)? { - match res { - ResponseMessage::Normal(mut msg) => { - // lock the shared state only when needed and keep the lock around a bit for possible multiple messages - let inner = guard.get_or_insert_with(|| self.guarded.lock().unwrap()); - let res = inner.res.pop_front().ok_or_else(|| msg.parse_error())?; - let _ = res.send(msg.buf); - if !msg.complete { - inner.res.push_front(res); - } +pub(super) fn try_decode( + queue: &mut VecDeque, + read_buf: &mut BytesMut, +) -> Result, Error> { + while let Some(res) = ResponseMessage::try_from_buf(read_buf)? { + match res { + ResponseMessage::Normal(mut msg) => { + let req = queue.pop_front().ok_or_else(|| msg.parse_error())?; + let _ = req.send(msg.buf); + if !msg.complete { + queue.push_front(req); } - ResponseMessage::Async(msg) => return Ok(Some(msg)), } + ResponseMessage::Async(msg) => return Ok(Some(msg)), } - - Ok(None) } + + Ok(None) } impl AsyncLendingIterator for GenericDriver