From 0fdb8ba7ea71c36fbbebdb10f3e5425995578bca Mon Sep 17 00:00:00 2001 From: Julius de Bruijn Date: Fri, 17 Nov 2023 16:17:38 +0100 Subject: [PATCH] fix: handle description a bit higher in the execution --- tokio-postgres/src/query.rs | 46 ++++++++++++++++++++----------------- 1 file changed, 25 insertions(+), 21 deletions(-) diff --git a/tokio-postgres/src/query.rs b/tokio-postgres/src/query.rs index 91421da3b..0d1d7823c 100644 --- a/tokio-postgres/src/query.rs +++ b/tokio-postgres/src/query.rs @@ -53,9 +53,11 @@ where } else { encode(client, &statement, params)? }; - let (statement, responses) = start(client, buf).await?; + + let responses = start(client, buf).await?; + Ok(RowStream { - statement, + statement: None, responses, rows_affected: None, command_tag: None, @@ -116,11 +118,11 @@ where })?; // now read the responses - let (statement, responses) = start(client, buf).await?; + let responses = start(client, buf).await?; Ok(RowStream { parameter_description: None, - statement, + statement: None, responses, command_tag: None, status: None, @@ -189,7 +191,8 @@ where } else { encode(client, &statement, params)? }; - let (_statement, mut responses) = start(client, buf).await?; + + let mut responses = start(client, buf).await?; let mut rows = 0; loop { @@ -205,27 +208,13 @@ where } } -async fn start(client: &InnerClient, buf: Bytes) -> Result<(Option, Responses), Error> { - let mut parameter_description: Option = None; - let mut statement = None; +async fn start(client: &InnerClient, buf: Bytes) -> Result { let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; loop { match responses.next().await? { Message::ParseComplete => {} - Message::BindComplete => return Ok((statement, responses)), - Message::ParameterDescription(body) => { - parameter_description = Some(body); // tooo-o-ooo-o loooove - } - Message::NoData => { - statement = Some(make_statement(parameter_description.take().unwrap(), None)?); - } - Message::RowDescription(body) => { - statement = Some(make_statement( - parameter_description.take().unwrap(), - Some(body), - )?); - } + Message::BindComplete => return Ok(responses), m => return Err(Error::unexpected_message(m)), } } @@ -360,6 +349,21 @@ impl Stream for RowStream { *this.command_tag = Some(tag.to_string()); } } + Message::ParameterDescription(body) => { + *this.parameter_description = Some(body); + } + Message::NoData => { + *this.statement = Some(make_statement( + this.parameter_description.take().unwrap(), + None, + )?); + } + Message::RowDescription(body) => { + *this.statement = Some(make_statement( + this.parameter_description.take().unwrap(), + Some(body), + )?); + } Message::EmptyQueryResponse | Message::PortalSuspended => {} Message::ReadyForQuery(status) => { *this.status = Some(status.status());