From a7e5269d655f8dfef46cf334d9710343d671e802 Mon Sep 17 00:00:00 2001 From: Alex Chi Date: Mon, 24 Jul 2023 14:49:03 -0400 Subject: [PATCH] add query_raw_txt for transaction Signed-off-by: Alex Chi --- tokio-postgres/src/client.rs | 85 ++----------------------- tokio-postgres/src/generic_client.rs | 25 ++++++++ tokio-postgres/src/query.rs | 94 +++++++++++++++++++++++++++- tokio-postgres/src/transaction.rs | 10 +++ 4 files changed, 131 insertions(+), 83 deletions(-) diff --git a/tokio-postgres/src/client.rs b/tokio-postgres/src/client.rs index 4db764bd9..44d57ab99 100644 --- a/tokio-postgres/src/client.rs +++ b/tokio-postgres/src/client.rs @@ -7,10 +7,8 @@ use crate::copy_both::CopyBothDuplex; use crate::copy_out::CopyOutStream; #[cfg(feature = "runtime")] use crate::keepalive::KeepaliveConfig; -use crate::prepare::get_type; use crate::query::RowStream; use crate::simple_query::SimpleQueryStream; -use crate::statement::Column; #[cfg(feature = "runtime")] use crate::tls::MakeTlsConnect; use crate::tls::TlsConnect; @@ -22,7 +20,7 @@ use crate::{ CopyInSink, Error, Row, SimpleQueryMessage, Statement, ToStatement, Transaction, TransactionBuilder, }; -use bytes::{Buf, BufMut, BytesMut}; +use bytes::{Buf, BytesMut}; use fallible_iterator::FallibleIterator; use futures_channel::mpsc; use futures_util::{future, pin_mut, ready, StreamExt, TryStreamExt}; @@ -380,86 +378,11 @@ impl Client { /// to save a roundtrip pub async fn query_raw_txt<'a, S, I>(&self, query: S, params: I) -> Result where - S: AsRef, + S: AsRef + Sync + Send, I: IntoIterator>, - I::IntoIter: ExactSizeIterator, + I::IntoIter: ExactSizeIterator + Sync + Send, { - let params = params.into_iter(); - let params_len = params.len(); - - let buf = self.inner.with_buf(|buf| { - // Parse, anonymous portal - frontend::parse("", query.as_ref(), std::iter::empty(), buf).map_err(Error::encode)?; - // Bind, pass params as text, retrieve as binary - match frontend::bind( - "", // empty string selects the unnamed portal - "", // empty string selects the unnamed prepared statement - std::iter::empty(), // all parameters use the default format (text) - params, - |param, buf| match param { - Some(param) => { - buf.put_slice(param.as_ref().as_bytes()); - Ok(postgres_protocol::IsNull::No) - } - None => Ok(postgres_protocol::IsNull::Yes), - }, - Some(0), // all text - buf, - ) { - Ok(()) => Ok(()), - Err(frontend::BindError::Conversion(e)) => Err(Error::to_sql(e, 0)), - Err(frontend::BindError::Serialization(e)) => Err(Error::encode(e)), - }?; - - // Describe portal to typecast results - frontend::describe(b'P', "", buf).map_err(Error::encode)?; - // Execute - frontend::execute("", 0, buf).map_err(Error::encode)?; - // Sync - frontend::sync(buf); - - Ok(buf.split().freeze()) - })?; - - let mut responses = self - .inner - .send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; - - // now read the responses - - match responses.next().await? { - Message::ParseComplete => {} - _ => return Err(Error::unexpected_message()), - } - match responses.next().await? { - Message::BindComplete => {} - _ => return Err(Error::unexpected_message()), - } - let row_description = match responses.next().await? { - Message::RowDescription(body) => Some(body), - Message::NoData => None, - _ => return Err(Error::unexpected_message()), - }; - - // construct statement object - - let parameters = vec![Type::UNKNOWN; params_len]; - - let mut columns = vec![]; - if let Some(row_description) = row_description { - let mut it = row_description.fields(); - while let Some(field) = it.next().map_err(Error::parse)? { - // NB: for some types that function may send a query to the server. At least in - // raw text mode we don't need that info and can skip this. - let type_ = get_type(&self.inner, field.type_oid()).await?; - let column = Column::new(field.name().to_string(), type_, field); - columns.push(column); - } - } - - let statement = Statement::new_text(&self.inner, "".to_owned(), parameters, columns); - - Ok(RowStream::new(statement, responses)) + query::query_txt(&self.inner, query, params).await } /// Executes a statement, returning the number of rows modified. diff --git a/tokio-postgres/src/generic_client.rs b/tokio-postgres/src/generic_client.rs index b2a907558..1d28f609d 100644 --- a/tokio-postgres/src/generic_client.rs +++ b/tokio-postgres/src/generic_client.rs @@ -56,6 +56,13 @@ pub trait GenericClient: private::Sealed { I: IntoIterator + Sync + Send, I::IntoIter: ExactSizeIterator; + /// Like `Client::query_raw_txt`. + async fn query_raw_txt<'a, S, I>(&self, query: S, params: I) -> Result + where + S: AsRef + Sync + Send, + I: IntoIterator> + Sync + Send, + I::IntoIter: ExactSizeIterator + Sync + Send; + /// Like `Client::prepare`. async fn prepare(&self, query: &str) -> Result; @@ -133,6 +140,15 @@ impl GenericClient for Client { self.query_raw(statement, params).await } + async fn query_raw_txt<'a, S, I>(&self, query: S, params: I) -> Result + where + S: AsRef + Sync + Send, + I: IntoIterator> + Sync + Send, + I::IntoIter: ExactSizeIterator + Sync + Send, + { + self.query_raw_txt(query, params).await + } + async fn prepare(&self, query: &str) -> Result { self.prepare(query).await } @@ -215,6 +231,15 @@ impl GenericClient for Transaction<'_> { self.query_raw(statement, params).await } + async fn query_raw_txt<'a, S, I>(&self, query: S, params: I) -> Result + where + S: AsRef + Sync + Send, + I: IntoIterator> + Sync + Send, + I::IntoIter: ExactSizeIterator + Sync + Send, + { + self.query_raw_txt(query, params).await + } + async fn prepare(&self, query: &str) -> Result { self.prepare(query).await } diff --git a/tokio-postgres/src/query.rs b/tokio-postgres/src/query.rs index a486b4f88..3dd180142 100644 --- a/tokio-postgres/src/query.rs +++ b/tokio-postgres/src/query.rs @@ -1,17 +1,21 @@ use crate::client::{InnerClient, Responses}; use crate::codec::FrontendMessage; use crate::connection::RequestMessages; +use crate::prepare::get_type; use crate::types::{BorrowToSql, IsNull}; -use crate::{Error, Portal, Row, Statement}; -use bytes::{Bytes, BytesMut}; +use crate::{Column, Error, Portal, Row, Statement}; +use bytes::{BufMut, Bytes, BytesMut}; +use fallible_iterator::FallibleIterator; use futures_util::{ready, Stream}; use log::{debug, log_enabled, Level}; use pin_project_lite::pin_project; use postgres_protocol::message::backend::Message; use postgres_protocol::message::frontend; +use postgres_types::Type; use std::fmt; use std::marker::PhantomPinned; use std::pin::Pin; +use std::sync::Arc; use std::task::{Context, Poll}; struct BorrowToSqlParamsDebug<'a, T>(&'a [T]); @@ -57,6 +61,92 @@ where }) } +pub async fn query_txt( + client: &Arc, + query: S, + params: I, +) -> Result +where + S: AsRef + Sync + Send, + I: IntoIterator>, + I::IntoIter: ExactSizeIterator, +{ + let params = params.into_iter(); + let params_len = params.len(); + + let buf = client.with_buf(|buf| { + // Parse, anonymous portal + frontend::parse("", query.as_ref(), std::iter::empty(), buf).map_err(Error::encode)?; + // Bind, pass params as text, retrieve as binary + match frontend::bind( + "", // empty string selects the unnamed portal + "", // empty string selects the unnamed prepared statement + std::iter::empty(), // all parameters use the default format (text) + params, + |param, buf| match param { + Some(param) => { + buf.put_slice(param.as_ref().as_bytes()); + Ok(postgres_protocol::IsNull::No) + } + None => Ok(postgres_protocol::IsNull::Yes), + }, + Some(0), // all text + buf, + ) { + Ok(()) => Ok(()), + Err(frontend::BindError::Conversion(e)) => Err(Error::to_sql(e, 0)), + Err(frontend::BindError::Serialization(e)) => Err(Error::encode(e)), + }?; + + // Describe portal to typecast results + frontend::describe(b'P', "", buf).map_err(Error::encode)?; + // Execute + frontend::execute("", 0, buf).map_err(Error::encode)?; + // Sync + frontend::sync(buf); + + Ok(buf.split().freeze()) + })?; + + let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; + + // now read the responses + + match responses.next().await? { + Message::ParseComplete => {} + _ => return Err(Error::unexpected_message()), + } + match responses.next().await? { + Message::BindComplete => {} + _ => return Err(Error::unexpected_message()), + } + let row_description = match responses.next().await? { + Message::RowDescription(body) => Some(body), + Message::NoData => None, + _ => return Err(Error::unexpected_message()), + }; + + // construct statement object + + let parameters = vec![Type::UNKNOWN; params_len]; + + let mut columns = vec![]; + if let Some(row_description) = row_description { + let mut it = row_description.fields(); + while let Some(field) = it.next().map_err(Error::parse)? { + // NB: for some types that function may send a query to the server. At least in + // raw text mode we don't need that info and can skip this. + let type_ = get_type(client, field.type_oid()).await?; + let column = Column::new(field.name().to_string(), type_, field); + columns.push(column); + } + } + + let statement = Statement::new_text(client, "".to_owned(), parameters, columns); + + Ok(RowStream::new(statement, responses)) +} + pub async fn query_portal( client: &InnerClient, portal: &Portal, diff --git a/tokio-postgres/src/transaction.rs b/tokio-postgres/src/transaction.rs index 96a324652..806196aa3 100644 --- a/tokio-postgres/src/transaction.rs +++ b/tokio-postgres/src/transaction.rs @@ -149,6 +149,16 @@ impl<'a> Transaction<'a> { self.client.query_raw(statement, params).await } + /// Like `Client::query_raw_txt`. + pub async fn query_raw_txt(&self, query: S, params: I) -> Result + where + S: AsRef + Sync + Send, + I: IntoIterator>, + I::IntoIter: ExactSizeIterator + Sync + Send, + { + self.client.query_raw_txt(query, params).await + } + /// Like `Client::execute`. pub async fn execute( &self,