Skip to content

Commit 2ce4f08

Browse files
authored
Merge pull request #564 from benesch/startup-notices
Don't suppress notices during startup flow
2 parents 4bf40cd + 7ea1b2d commit 2ce4f08

File tree

3 files changed

+51
-7
lines changed

3 files changed

+51
-7
lines changed

tokio-postgres/src/connect_raw.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ use postgres_protocol::authentication::sasl;
1313
use postgres_protocol::authentication::sasl::ScramSha256;
1414
use postgres_protocol::message::backend::{AuthenticationSaslBody, Message};
1515
use postgres_protocol::message::frontend;
16-
use std::collections::HashMap;
16+
use std::collections::{HashMap, VecDeque};
1717
use std::io;
1818
use std::pin::Pin;
1919
use std::task::{Context, Poll};
@@ -23,6 +23,7 @@ use tokio_util::codec::Framed;
2323
pub struct StartupStream<S, T> {
2424
inner: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
2525
buf: BackendMessages,
26+
delayed: VecDeque<BackendMessage>,
2627
}
2728

2829
impl<S, T> Sink<FrontendMessage> for StartupStream<S, T>
@@ -91,6 +92,7 @@ where
9192
let mut stream = StartupStream {
9293
inner: Framed::new(stream, PostgresCodec),
9394
buf: BackendMessages::empty(),
95+
delayed: VecDeque::new(),
9496
};
9597

9698
startup(&mut stream, config).await?;
@@ -99,7 +101,7 @@ where
99101

100102
let (sender, receiver) = mpsc::unbounded();
101103
let client = Client::new(sender, config.ssl_mode, process_id, secret_key);
102-
let connection = Connection::new(stream.inner, parameters, receiver);
104+
let connection = Connection::new(stream.inner, stream.delayed, parameters, receiver);
103105

104106
Ok((client, connection))
105107
}
@@ -332,7 +334,9 @@ where
332334
body.value().map_err(Error::parse)?.to_string(),
333335
);
334336
}
335-
Some(Message::NoticeResponse(_)) => {}
337+
Some(msg @ Message::NoticeResponse(_)) => {
338+
stream.delayed.push_back(BackendMessage::Async(msg))
339+
}
336340
Some(Message::ReadyForQuery(_)) => return Ok((process_id, secret_key, parameters)),
337341
Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
338342
Some(_) => return Err(Error::unexpected_message()),

tokio-postgres/src/connection.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ pub struct Connection<S, T> {
5252
parameters: HashMap<String, String>,
5353
receiver: mpsc::UnboundedReceiver<Request>,
5454
pending_request: Option<RequestMessages>,
55-
pending_response: Option<BackendMessage>,
55+
pending_responses: VecDeque<BackendMessage>,
5656
responses: VecDeque<Response>,
5757
state: State,
5858
}
@@ -64,6 +64,7 @@ where
6464
{
6565
pub(crate) fn new(
6666
stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
67+
pending_responses: VecDeque<BackendMessage>,
6768
parameters: HashMap<String, String>,
6869
receiver: mpsc::UnboundedReceiver<Request>,
6970
) -> Connection<S, T> {
@@ -72,7 +73,7 @@ where
7273
parameters,
7374
receiver,
7475
pending_request: None,
75-
pending_response: None,
76+
pending_responses,
7677
responses: VecDeque::new(),
7778
state: State::Active,
7879
}
@@ -82,7 +83,7 @@ where
8283
&mut self,
8384
cx: &mut Context<'_>,
8485
) -> Poll<Option<Result<BackendMessage, Error>>> {
85-
if let Some(message) = self.pending_response.take() {
86+
if let Some(message) = self.pending_responses.pop_front() {
8687
trace!("retrying pending response");
8788
return Poll::Ready(Some(Ok(message)));
8889
}
@@ -158,7 +159,7 @@ where
158159
}
159160
Poll::Pending => {
160161
self.responses.push_front(response);
161-
self.pending_response = Some(BackendMessage::Normal {
162+
self.pending_responses.push_back(BackendMessage::Normal {
162163
messages,
163164
request_complete,
164165
});

tokio-postgres/tests/test/main.rs

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -570,6 +570,45 @@ async fn copy_out() {
570570
assert_eq!(&data[..], b"1\tjim\n2\tjoe\n");
571571
}
572572

573+
#[tokio::test]
574+
async fn notices() {
575+
let long_name = "x".repeat(65);
576+
let (client, mut connection) =
577+
connect_raw(&format!("user=postgres application_name={}", long_name,))
578+
.await
579+
.unwrap();
580+
581+
let (tx, rx) = mpsc::unbounded();
582+
let stream = stream::poll_fn(move |cx| connection.poll_message(cx)).map_err(|e| panic!(e));
583+
let connection = stream.forward(tx).map(|r| r.unwrap());
584+
tokio::spawn(connection);
585+
586+
client
587+
.batch_execute("DROP DATABASE IF EXISTS noexistdb")
588+
.await
589+
.unwrap();
590+
591+
drop(client);
592+
593+
let notices = rx
594+
.filter_map(|m| match m {
595+
AsyncMessage::Notice(n) => future::ready(Some(n)),
596+
_ => future::ready(None),
597+
})
598+
.collect::<Vec<_>>()
599+
.await;
600+
assert_eq!(notices.len(), 2);
601+
assert_eq!(
602+
notices[0].message(),
603+
"identifier \"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx\" \
604+
will be truncated to \"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx\""
605+
);
606+
assert_eq!(
607+
notices[1].message(),
608+
"database \"noexistdb\" does not exist, skipping"
609+
);
610+
}
611+
573612
#[tokio::test]
574613
async fn notifications() {
575614
let (client, mut connection) = connect_raw("user=postgres").await.unwrap();

0 commit comments

Comments
 (0)