Skip to content

Commit

Permalink
Merge pull request #564 from benesch/startup-notices
Browse files Browse the repository at this point in the history
Don't suppress notices during startup flow
  • Loading branch information
sfackler authored Jan 31, 2020
2 parents 4bf40cd + 7ea1b2d commit 2ce4f08
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 7 deletions.
10 changes: 7 additions & 3 deletions tokio-postgres/src/connect_raw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use postgres_protocol::authentication::sasl;
use postgres_protocol::authentication::sasl::ScramSha256;
use postgres_protocol::message::backend::{AuthenticationSaslBody, Message};
use postgres_protocol::message::frontend;
use std::collections::HashMap;
use std::collections::{HashMap, VecDeque};
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
Expand All @@ -23,6 +23,7 @@ use tokio_util::codec::Framed;
pub struct StartupStream<S, T> {
inner: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
buf: BackendMessages,
delayed: VecDeque<BackendMessage>,
}

impl<S, T> Sink<FrontendMessage> for StartupStream<S, T>
Expand Down Expand Up @@ -91,6 +92,7 @@ where
let mut stream = StartupStream {
inner: Framed::new(stream, PostgresCodec),
buf: BackendMessages::empty(),
delayed: VecDeque::new(),
};

startup(&mut stream, config).await?;
Expand All @@ -99,7 +101,7 @@ where

let (sender, receiver) = mpsc::unbounded();
let client = Client::new(sender, config.ssl_mode, process_id, secret_key);
let connection = Connection::new(stream.inner, parameters, receiver);
let connection = Connection::new(stream.inner, stream.delayed, parameters, receiver);

Ok((client, connection))
}
Expand Down Expand Up @@ -332,7 +334,9 @@ where
body.value().map_err(Error::parse)?.to_string(),
);
}
Some(Message::NoticeResponse(_)) => {}
Some(msg @ Message::NoticeResponse(_)) => {
stream.delayed.push_back(BackendMessage::Async(msg))
}
Some(Message::ReadyForQuery(_)) => return Ok((process_id, secret_key, parameters)),
Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
Some(_) => return Err(Error::unexpected_message()),
Expand Down
9 changes: 5 additions & 4 deletions tokio-postgres/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ pub struct Connection<S, T> {
parameters: HashMap<String, String>,
receiver: mpsc::UnboundedReceiver<Request>,
pending_request: Option<RequestMessages>,
pending_response: Option<BackendMessage>,
pending_responses: VecDeque<BackendMessage>,
responses: VecDeque<Response>,
state: State,
}
Expand All @@ -64,6 +64,7 @@ where
{
pub(crate) fn new(
stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
pending_responses: VecDeque<BackendMessage>,
parameters: HashMap<String, String>,
receiver: mpsc::UnboundedReceiver<Request>,
) -> Connection<S, T> {
Expand All @@ -72,7 +73,7 @@ where
parameters,
receiver,
pending_request: None,
pending_response: None,
pending_responses,
responses: VecDeque::new(),
state: State::Active,
}
Expand All @@ -82,7 +83,7 @@ where
&mut self,
cx: &mut Context<'_>,
) -> Poll<Option<Result<BackendMessage, Error>>> {
if let Some(message) = self.pending_response.take() {
if let Some(message) = self.pending_responses.pop_front() {
trace!("retrying pending response");
return Poll::Ready(Some(Ok(message)));
}
Expand Down Expand Up @@ -158,7 +159,7 @@ where
}
Poll::Pending => {
self.responses.push_front(response);
self.pending_response = Some(BackendMessage::Normal {
self.pending_responses.push_back(BackendMessage::Normal {
messages,
request_complete,
});
Expand Down
39 changes: 39 additions & 0 deletions tokio-postgres/tests/test/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,45 @@ async fn copy_out() {
assert_eq!(&data[..], b"1\tjim\n2\tjoe\n");
}

#[tokio::test]
async fn notices() {
let long_name = "x".repeat(65);
let (client, mut connection) =
connect_raw(&format!("user=postgres application_name={}", long_name,))
.await
.unwrap();

let (tx, rx) = mpsc::unbounded();
let stream = stream::poll_fn(move |cx| connection.poll_message(cx)).map_err(|e| panic!(e));
let connection = stream.forward(tx).map(|r| r.unwrap());
tokio::spawn(connection);

client
.batch_execute("DROP DATABASE IF EXISTS noexistdb")
.await
.unwrap();

drop(client);

let notices = rx
.filter_map(|m| match m {
AsyncMessage::Notice(n) => future::ready(Some(n)),
_ => future::ready(None),
})
.collect::<Vec<_>>()
.await;
assert_eq!(notices.len(), 2);
assert_eq!(
notices[0].message(),
"identifier \"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx\" \
will be truncated to \"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx\""
);
assert_eq!(
notices[1].message(),
"database \"noexistdb\" does not exist, skipping"
);
}

#[tokio::test]
async fn notifications() {
let (client, mut connection) = connect_raw("user=postgres").await.unwrap();
Expand Down

0 comments on commit 2ce4f08

Please sign in to comment.