Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions postgres/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,11 @@ nightly = []
xitca-io = { version = "0.5.1", features = ["runtime"] }
xitca-unsafe-collection = { version = "0.2.0", features = ["bytes"] }

byteorder = "1.5.0"
fallible-iterator = "0.2"
futures-core = { version = "0.3", default-features = false }
lru = { version = "0.16", default-features = false }
memchr = "2.7.1"
percent-encoding = "2"
postgres-protocol = "0.6.5"
postgres-types = "0.2"
Expand Down
88 changes: 25 additions & 63 deletions postgres/src/driver/codec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,38 +17,27 @@ use crate::{

pub(super) fn request_pair() -> (ResponseSender, Response) {
let (tx, rx) = unbounded_channel();
(
tx,
Response {
rx,
buf: BytesMut::new(),
},
)
(tx, Response { rx })
}

#[derive(Debug)]
pub struct Response {
rx: ResponseReceiver,
buf: BytesMut,
}

impl Response {
pub(crate) fn blocking_recv(&mut self) -> Result<backend::Message, Error> {
if self.buf.is_empty() {
self.buf = self.rx.blocking_recv().ok_or_else(|| Error::from(ClosedByDriver))?;
}
self.parse_message()
let msg = self.rx.blocking_recv().ok_or_else(|| Error::from(ClosedByDriver))?;
Self::parse_message(msg)
}

pub(crate) fn recv(&mut self) -> impl Future<Output = Result<backend::Message, Error>> + Send + '_ {
poll_fn(|cx| self.poll_recv(cx))
}

pub(crate) fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Result<backend::Message, Error>> {
if self.buf.is_empty() {
self.buf = ready!(self.rx.poll_recv(cx)).ok_or_else(|| Error::from(ClosedByDriver))?;
}
Poll::Ready(self.parse_message())
let msg = ready!(self.rx.poll_recv(cx)).ok_or_else(|| Error::from(ClosedByDriver))?;
Poll::Ready(Self::parse_message(msg))
}

pub(crate) fn try_into_row_affected(mut self) -> impl Future<Output = Result<u64, Error>> + Send {
Expand Down Expand Up @@ -100,39 +89,39 @@ impl Response {
}
}

fn parse_message(&mut self) -> Result<backend::Message, Error> {
match backend::Message::parse(&mut self.buf)?.expect("must not parse message from empty buffer.") {
fn parse_message(msg: backend::MessageRaw) -> Result<backend::Message, Error> {
match msg.try_into_message()? {
backend::Message::ErrorResponse(body) => Err(Error::db(body.fields())),
msg => Ok(msg),
}
}
}

// Extract the number of rows affected.
pub(crate) fn body_to_affected_rows(body: &backend::CommandCompleteBody) -> Result<u64, Error> {
fn body_to_affected_rows(body: &backend::CommandCompleteBody) -> Result<u64, Error> {
body.tag()
.map_err(|_| Error::todo())
.map(|r| r.rsplit(' ').next().unwrap().parse().unwrap_or(0))
}

pub(super) type ResponseSender = UnboundedSender<BytesMut>;
pub(super) type ResponseSender = UnboundedSender<backend::MessageRaw>;

// TODO: remove this lint.
#[allow(dead_code)]
pub(super) type ResponseReceiver = UnboundedReceiver<BytesMut>;
pub(super) type ResponseReceiver = UnboundedReceiver<backend::MessageRaw>;

pub(super) struct BytesMessage {
pub(super) buf: BytesMut,
pub(super) msg: backend::MessageRaw,
pub(super) complete: bool,
}

impl BytesMessage {
#[cold]
#[inline(never)]
pub(super) fn parse_error(&mut self) -> Error {
match backend::Message::parse(&mut self.buf) {
pub(super) fn into_error(self) -> Error {
match self.msg.try_into_message() {
Err(e) => Error::from(e),
Ok(Some(backend::Message::ErrorResponse(body))) => Error::db(body.fields()),
Ok(backend::Message::ErrorResponse(body)) => Error::db(body.fields()),
_ => Error::unexpected(),
}
}
Expand All @@ -145,47 +134,20 @@ pub(super) enum ResponseMessage {

impl ResponseMessage {
pub(crate) fn try_from_buf(buf: &mut BytesMut) -> Result<Option<Self>, Error> {
let mut tail = 0;
let mut complete = false;

loop {
let slice = &buf[tail..];
let Some(header) = backend::Header::parse(slice)? else {
break;
};
let len = header.len() as usize + 1;

if slice.len() < len {
break;
let Some(msg) = backend::MessageRaw::parse(buf)? else {
return Ok(None);
};

match msg.tag() {
backend::NOTICE_RESPONSE_TAG | backend::NOTIFICATION_RESPONSE_TAG | backend::PARAMETER_STATUS_TAG => {
let message = msg.try_into_message()?;
Ok(Some(ResponseMessage::Async(message)))
}

match header.tag() {
backend::NOTICE_RESPONSE_TAG | backend::NOTIFICATION_RESPONSE_TAG | backend::PARAMETER_STATUS_TAG => {
if tail > 0 {
break;
}
let message = backend::Message::parse(buf)?
.expect("buffer contains at least one Message. parser must produce Some");
return Ok(Some(ResponseMessage::Async(message)));
}
tag => {
tail += len;
if matches!(tag, backend::READY_FOR_QUERY_TAG) {
complete = true;
break;
}
}
tag => {
let complete = matches!(tag, backend::READY_FOR_QUERY_TAG);
Ok(Some(ResponseMessage::Normal(BytesMessage { msg, complete })))
}
}

if tail == 0 {
Ok(None)
} else {
Ok(Some(ResponseMessage::Normal(BytesMessage {
buf: buf.split_to(tail),
complete,
})))
}
}
}

Expand Down
26 changes: 13 additions & 13 deletions postgres/src/driver/codec/encode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ where
fn encode_statement_create(name: &str, stmt: &str, types: &[Type], buf: &mut BytesMut) -> Result<(), Error> {
frontend::parse(name, stmt, types.iter().map(Type::oid), buf)?;
frontend::describe(b'S', name, buf)?;
protocol::sync(buf);
frontend::sync(buf);
Ok(())
}

Expand All @@ -94,8 +94,8 @@ impl Encode for StatementPreparedCancel<'_> {
#[inline]
fn encode(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
let Self { name } = self;
protocol::close(b'S', name, buf)?;
protocol::sync(buf);
frontend::close(b'S', name, buf)?;
frontend::sync(buf);
Ok(NoOpIntoRowStream)
}
}
Expand Down Expand Up @@ -135,8 +135,8 @@ where
P: AsParams,
{
encode_bind(stmt.name(), stmt.params(), params, "", buf)?;
protocol::execute("", 0, buf)?;
protocol::sync(buf);
frontend::execute("", 0, buf)?;
frontend::sync(buf);
Ok(())
}

Expand All @@ -156,8 +156,8 @@ where
frontend::parse("", stmt, types.iter().map(Type::oid), buf)?;
encode_bind("", types, params, "", buf)?;
frontend::describe(b'S', "", buf)?;
protocol::execute("", 0, buf)?;
protocol::sync(buf);
frontend::execute("", 0, buf)?;
frontend::sync(buf);
Ok(IntoRowStreamGuard(cli))
}
}
Expand Down Expand Up @@ -186,7 +186,7 @@ where
params,
} = self;
encode_bind(stmt, types, params, name, buf)?;
protocol::sync(buf);
frontend::sync(buf);
Ok(NoOpIntoRowStream)
}
}
Expand All @@ -202,8 +202,8 @@ impl Encode for PortalCancel<'_> {

#[inline]
fn encode(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
protocol::close(b'P', self.name, buf)?;
protocol::sync(buf);
frontend::close(b'P', self.name, buf)?;
frontend::sync(buf);
Ok(NoOpIntoRowStream)
}
}
Expand All @@ -226,8 +226,8 @@ impl<'s> Encode for PortalQuery<'s> {
max_rows,
columns,
} = self;
protocol::execute(name, max_rows, buf)?;
protocol::sync(buf);
frontend::execute(name, max_rows, buf)?;
frontend::sync(buf);
Ok(columns)
}
}
Expand All @@ -252,7 +252,7 @@ where

let params = params.zip(types);

protocol::bind(
frontend::bind(
portal_name,
stmt_name,
params.clone().map(|(p, ty)| p.borrow_to_sql().encode_format(ty) as _),
Expand Down
20 changes: 14 additions & 6 deletions postgres/src/driver/generic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,10 @@ pub(super) struct State {

impl State {
fn register(&mut self, waker: &Waker) {
self.waker = Some(waker.clone());
match self.waker {
Some(ref w) if w.will_wake(waker) => {}
_ => self.waker = Some(waker.clone()),
};
}

fn wake(&mut self) {
Expand Down Expand Up @@ -351,13 +354,18 @@ impl DriverRx {

while let Some(res) = ResponseMessage::try_from_buf(read_buf)? {
match res {
ResponseMessage::Normal(mut msg) => {
ResponseMessage::Normal(msg) => {
// lock the shared state only when needed and keep the lock around a bit for possible multiple messages
let inner = guard.get_or_insert_with(|| self.guarded.lock().unwrap());
let res = inner.res.pop_front().ok_or_else(|| msg.parse_error())?;
let _ = res.send(msg.buf);
if !msg.complete {
inner.res.push_front(res);

match inner.res.front_mut() {
Some(tx) => {
let _ = tx.send(msg.msg);
if msg.complete {
inner.res.pop_front();
}
}
None => return Err(msg.into_error()),
}
}
ResponseMessage::Async(msg) => return Ok(Some(msg)),
Expand Down
Loading
Loading