Skip to content

Commit 7ea1b2d

Browse files
committed
Don't suppress notices during startup flow
NoticeResponses received during the startup flow were previously being dropped on the floor. Instead stash them away so they can be delivered to the user after the startup flow is complete.
1 parent 5429a79 commit 7ea1b2d

File tree

3 files changed

+51
-7
lines changed

3 files changed

+51
-7
lines changed

tokio-postgres/src/connect_raw.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ use postgres_protocol::authentication::sasl;
1313
use postgres_protocol::authentication::sasl::ScramSha256;
1414
use postgres_protocol::message::backend::{AuthenticationSaslBody, Message};
1515
use postgres_protocol::message::frontend;
16-
use std::collections::HashMap;
16+
use std::collections::{HashMap, VecDeque};
1717
use std::io;
1818
use std::pin::Pin;
1919
use std::task::{Context, Poll};
@@ -23,6 +23,7 @@ use tokio_util::codec::Framed;
2323
pub struct StartupStream<S, T> {
2424
inner: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
2525
buf: BackendMessages,
26+
delayed: VecDeque<BackendMessage>,
2627
}
2728

2829
impl<S, T> Sink<FrontendMessage> for StartupStream<S, T>
@@ -91,6 +92,7 @@ where
9192
let mut stream = StartupStream {
9293
inner: Framed::new(stream, PostgresCodec),
9394
buf: BackendMessages::empty(),
95+
delayed: VecDeque::new(),
9496
};
9597

9698
startup(&mut stream, config).await?;
@@ -99,7 +101,7 @@ where
99101

100102
let (sender, receiver) = mpsc::unbounded();
101103
let client = Client::new(sender, config.ssl_mode, process_id, secret_key);
102-
let connection = Connection::new(stream.inner, parameters, receiver);
104+
let connection = Connection::new(stream.inner, stream.delayed, parameters, receiver);
103105

104106
Ok((client, connection))
105107
}
@@ -332,7 +334,9 @@ where
332334
body.value().map_err(Error::parse)?.to_string(),
333335
);
334336
}
335-
Some(Message::NoticeResponse(_)) => {}
337+
Some(msg @ Message::NoticeResponse(_)) => {
338+
stream.delayed.push_back(BackendMessage::Async(msg))
339+
}
336340
Some(Message::ReadyForQuery(_)) => return Ok((process_id, secret_key, parameters)),
337341
Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
338342
Some(_) => return Err(Error::unexpected_message()),

tokio-postgres/src/connection.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ pub struct Connection<S, T> {
5252
parameters: HashMap<String, String>,
5353
receiver: mpsc::UnboundedReceiver<Request>,
5454
pending_request: Option<RequestMessages>,
55-
pending_response: Option<BackendMessage>,
55+
pending_responses: VecDeque<BackendMessage>,
5656
responses: VecDeque<Response>,
5757
state: State,
5858
}
@@ -64,6 +64,7 @@ where
6464
{
6565
pub(crate) fn new(
6666
stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
67+
pending_responses: VecDeque<BackendMessage>,
6768
parameters: HashMap<String, String>,
6869
receiver: mpsc::UnboundedReceiver<Request>,
6970
) -> Connection<S, T> {
@@ -72,7 +73,7 @@ where
7273
parameters,
7374
receiver,
7475
pending_request: None,
75-
pending_response: None,
76+
pending_responses,
7677
responses: VecDeque::new(),
7778
state: State::Active,
7879
}
@@ -82,7 +83,7 @@ where
8283
&mut self,
8384
cx: &mut Context<'_>,
8485
) -> Poll<Option<Result<BackendMessage, Error>>> {
85-
if let Some(message) = self.pending_response.take() {
86+
if let Some(message) = self.pending_responses.pop_front() {
8687
trace!("retrying pending response");
8788
return Poll::Ready(Some(Ok(message)));
8889
}
@@ -158,7 +159,7 @@ where
158159
}
159160
Poll::Pending => {
160161
self.responses.push_front(response);
161-
self.pending_response = Some(BackendMessage::Normal {
162+
self.pending_responses.push_back(BackendMessage::Normal {
162163
messages,
163164
request_complete,
164165
});

tokio-postgres/tests/test/main.rs

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -570,6 +570,45 @@ async fn copy_out() {
570570
assert_eq!(&data[..], b"1\tjim\n2\tjoe\n");
571571
}
572572

573+
#[tokio::test]
574+
async fn notices() {
575+
let long_name = "x".repeat(65);
576+
let (client, mut connection) =
577+
connect_raw(&format!("user=postgres application_name={}", long_name,))
578+
.await
579+
.unwrap();
580+
581+
let (tx, rx) = mpsc::unbounded();
582+
let stream = stream::poll_fn(move |cx| connection.poll_message(cx)).map_err(|e| panic!(e));
583+
let connection = stream.forward(tx).map(|r| r.unwrap());
584+
tokio::spawn(connection);
585+
586+
client
587+
.batch_execute("DROP DATABASE IF EXISTS noexistdb")
588+
.await
589+
.unwrap();
590+
591+
drop(client);
592+
593+
let notices = rx
594+
.filter_map(|m| match m {
595+
AsyncMessage::Notice(n) => future::ready(Some(n)),
596+
_ => future::ready(None),
597+
})
598+
.collect::<Vec<_>>()
599+
.await;
600+
assert_eq!(notices.len(), 2);
601+
assert_eq!(
602+
notices[0].message(),
603+
"identifier \"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx\" \
604+
will be truncated to \"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx\""
605+
);
606+
assert_eq!(
607+
notices[1].message(),
608+
"database \"noexistdb\" does not exist, skipping"
609+
);
610+
}
611+
573612
#[tokio::test]
574613
async fn notifications() {
575614
let (client, mut connection) = connect_raw("user=postgres").await.unwrap();

0 commit comments

Comments
 (0)