diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1ca030d26..c6e0f00de 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -57,7 +57,7 @@ jobs: - run: docker compose up -d - uses: sfackler/actions/rustup@master with: - version: 1.65.0 + version: 1.67.0 - run: echo "::set-output name=version::$(rustc --version)" id: rust-version - uses: actions/cache@v1 diff --git a/tokio-postgres/src/client.rs b/tokio-postgres/src/client.rs index 44d57ab99..38588f8f7 100644 --- a/tokio-postgres/src/client.rs +++ b/tokio-postgres/src/client.rs @@ -376,13 +376,19 @@ impl Client { /// Pass text directly to the Postgres backend to allow it to sort out typing itself and /// to save a roundtrip - pub async fn query_raw_txt<'a, S, I>(&self, query: S, params: I) -> Result + pub async fn query_raw_txt<'a, T, S, I>( + &self, + statement: &T, + params: I, + ) -> Result where - S: AsRef + Sync + Send, + T: ?Sized + ToStatement, + S: AsRef, I: IntoIterator>, - I::IntoIter: ExactSizeIterator + Sync + Send, + I::IntoIter: ExactSizeIterator, { - query::query_txt(&self.inner, query, params).await + let statement = statement.__convert().into_statement(self).await?; + query::query_txt(&self.inner, statement, params).await } /// Executes a statement, returning the number of rows modified. diff --git a/tokio-postgres/src/generic_client.rs b/tokio-postgres/src/generic_client.rs index 1d28f609d..e366e2449 100644 --- a/tokio-postgres/src/generic_client.rs +++ b/tokio-postgres/src/generic_client.rs @@ -57,8 +57,13 @@ pub trait GenericClient: private::Sealed { I::IntoIter: ExactSizeIterator; /// Like `Client::query_raw_txt`. - async fn query_raw_txt<'a, S, I>(&self, query: S, params: I) -> Result + async fn query_raw_txt<'a, T, S, I>( + &self, + statement: &T, + params: I, + ) -> Result where + T: ?Sized + ToStatement + Sync + Send, S: AsRef + Sync + Send, I: IntoIterator> + Sync + Send, I::IntoIter: ExactSizeIterator + Sync + Send; @@ -140,13 +145,14 @@ impl GenericClient for Client { self.query_raw(statement, params).await } - async fn query_raw_txt<'a, S, I>(&self, query: S, params: I) -> Result + async fn query_raw_txt<'a, T, S, I>(&self, statement: &T, params: I) -> Result where + T: ?Sized + ToStatement + Sync + Send, S: AsRef + Sync + Send, I: IntoIterator> + Sync + Send, I::IntoIter: ExactSizeIterator + Sync + Send, { - self.query_raw_txt(query, params).await + self.query_raw_txt(statement, params).await } async fn prepare(&self, query: &str) -> Result { @@ -231,13 +237,14 @@ impl GenericClient for Transaction<'_> { self.query_raw(statement, params).await } - async fn query_raw_txt<'a, S, I>(&self, query: S, params: I) -> Result + async fn query_raw_txt<'a, T, S, I>(&self, statement: &T, params: I) -> Result where + T: ?Sized + ToStatement + Sync + Send, S: AsRef + Sync + Send, I: IntoIterator> + Sync + Send, I::IntoIter: ExactSizeIterator + Sync + Send, { - self.query_raw_txt(query, params).await + self.query_raw_txt(statement, params).await } async fn prepare(&self, query: &str) -> Result { diff --git a/tokio-postgres/src/query.rs b/tokio-postgres/src/query.rs index 3dd180142..f4abd41a5 100644 --- a/tokio-postgres/src/query.rs +++ b/tokio-postgres/src/query.rs @@ -1,17 +1,15 @@ use crate::client::{InnerClient, Responses}; use crate::codec::FrontendMessage; use crate::connection::RequestMessages; -use crate::prepare::get_type; use crate::types::{BorrowToSql, IsNull}; -use crate::{Column, Error, Portal, Row, Statement}; +use crate::{Error, Portal, Row, Statement}; use bytes::{BufMut, Bytes, BytesMut}; -use fallible_iterator::FallibleIterator; use futures_util::{ready, Stream}; use log::{debug, log_enabled, Level}; use pin_project_lite::pin_project; use postgres_protocol::message::backend::Message; use postgres_protocol::message::frontend; -use postgres_types::Type; +use postgres_types::Format; use std::fmt; use std::marker::PhantomPinned; use std::pin::Pin; @@ -57,30 +55,29 @@ where statement, responses, command_tag: None, + status: None, + output_format: Format::Binary, _p: PhantomPinned, }) } pub async fn query_txt( client: &Arc, - query: S, + statement: Statement, params: I, ) -> Result where - S: AsRef + Sync + Send, + S: AsRef, I: IntoIterator>, I::IntoIter: ExactSizeIterator, { let params = params.into_iter(); - let params_len = params.len(); let buf = client.with_buf(|buf| { - // Parse, anonymous portal - frontend::parse("", query.as_ref(), std::iter::empty(), buf).map_err(Error::encode)?; // Bind, pass params as text, retrieve as binary match frontend::bind( "", // empty string selects the unnamed portal - "", // empty string selects the unnamed prepared statement + statement.name(), // named prepared statement std::iter::empty(), // all parameters use the default format (text) params, |param, buf| match param { @@ -98,8 +95,6 @@ where Err(frontend::BindError::Serialization(e)) => Err(Error::encode(e)), }?; - // Describe portal to typecast results - frontend::describe(b'P', "", buf).map_err(Error::encode)?; // Execute frontend::execute("", 0, buf).map_err(Error::encode)?; // Sync @@ -108,43 +103,16 @@ where Ok(buf.split().freeze()) })?; - let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; - // now read the responses - - match responses.next().await? { - Message::ParseComplete => {} - _ => return Err(Error::unexpected_message()), - } - match responses.next().await? { - Message::BindComplete => {} - _ => return Err(Error::unexpected_message()), - } - let row_description = match responses.next().await? { - Message::RowDescription(body) => Some(body), - Message::NoData => None, - _ => return Err(Error::unexpected_message()), - }; - - // construct statement object - - let parameters = vec![Type::UNKNOWN; params_len]; - - let mut columns = vec![]; - if let Some(row_description) = row_description { - let mut it = row_description.fields(); - while let Some(field) = it.next().map_err(Error::parse)? { - // NB: for some types that function may send a query to the server. At least in - // raw text mode we don't need that info and can skip this. - let type_ = get_type(client, field.type_oid()).await?; - let column = Column::new(field.name().to_string(), type_, field); - columns.push(column); - } - } - - let statement = Statement::new_text(client, "".to_owned(), parameters, columns); - - Ok(RowStream::new(statement, responses)) + let responses = start(client, buf).await?; + Ok(RowStream { + statement, + responses, + command_tag: None, + status: None, + output_format: Format::Text, + _p: PhantomPinned, + }) } pub async fn query_portal( @@ -164,6 +132,8 @@ pub async fn query_portal( statement: portal.statement().clone(), responses, command_tag: None, + status: None, + output_format: Format::Binary, _p: PhantomPinned, }) } @@ -295,23 +265,13 @@ pin_project! { statement: Statement, responses: Responses, command_tag: Option, + output_format: Format, + status: Option, #[pin] _p: PhantomPinned, } } -impl RowStream { - /// Creates a new `RowStream`. - pub fn new(statement: Statement, responses: Responses) -> Self { - RowStream { - statement, - responses, - command_tag: None, - _p: PhantomPinned, - } - } -} - impl Stream for RowStream { type Item = Result; @@ -320,7 +280,11 @@ impl Stream for RowStream { loop { match ready!(this.responses.poll_next(cx)?) { Message::DataRow(body) => { - return Poll::Ready(Some(Ok(Row::new(this.statement.clone(), body)?))) + return Poll::Ready(Some(Ok(Row::new( + this.statement.clone(), + body, + *this.output_format, + )?))) } Message::EmptyQueryResponse | Message::PortalSuspended => {} Message::CommandComplete(body) => { @@ -328,7 +292,10 @@ impl Stream for RowStream { *this.command_tag = Some(tag.to_string()); } } - Message::ReadyForQuery(_) => return Poll::Ready(None), + Message::ReadyForQuery(status) => { + *this.status = Some(status.status()); + return Poll::Ready(None); + } _ => return Poll::Ready(Some(Err(Error::unexpected_message()))), } } @@ -342,4 +309,11 @@ impl RowStream { pub fn command_tag(&self) -> Option { self.command_tag.clone() } + + /// Returns if the connection is ready for querying, with the status of the connection. + /// + /// This might be available only after the stream has been exhausted. + pub fn ready_status(&self) -> Option { + self.status + } } diff --git a/tokio-postgres/src/row.rs b/tokio-postgres/src/row.rs index ce4efed7e..754b5f28c 100644 --- a/tokio-postgres/src/row.rs +++ b/tokio-postgres/src/row.rs @@ -98,6 +98,7 @@ where /// A row of data returned from the database by a query. pub struct Row { statement: Statement, + output_format: Format, body: DataRowBody, ranges: Vec>>, } @@ -111,12 +112,17 @@ impl fmt::Debug for Row { } impl Row { - pub(crate) fn new(statement: Statement, body: DataRowBody) -> Result { + pub(crate) fn new( + statement: Statement, + body: DataRowBody, + output_format: Format, + ) -> Result { let ranges = body.ranges().collect().map_err(Error::parse)?; Ok(Row { statement, body, ranges, + output_format, }) } @@ -193,7 +199,7 @@ impl Row { /// /// Useful when using query_raw_txt() which sets text transfer mode pub fn as_text(&self, idx: usize) -> Result, Error> { - if self.statement.output_format() == Format::Text { + if self.output_format == Format::Text { match self.col_buffer(idx) { Some(raw) => { FromSql::from_sql(&Type::TEXT, raw).map_err(|e| Error::from_sql(e, idx)) diff --git a/tokio-postgres/src/statement.rs b/tokio-postgres/src/statement.rs index 8743f00f0..246d36a57 100644 --- a/tokio-postgres/src/statement.rs +++ b/tokio-postgres/src/statement.rs @@ -6,7 +6,6 @@ use postgres_protocol::{ message::{backend::Field, frontend}, Oid, }; -use postgres_types::Format; use std::{ fmt, sync::{Arc, Weak}, @@ -17,7 +16,6 @@ struct StatementInner { name: String, params: Vec, columns: Vec, - output_format: Format, } impl Drop for StatementInner { @@ -51,22 +49,6 @@ impl Statement { name, params, columns, - output_format: Format::Binary, - })) - } - - pub(crate) fn new_text( - inner: &Arc, - name: String, - params: Vec, - columns: Vec, - ) -> Statement { - Statement(Arc::new(StatementInner { - client: Arc::downgrade(inner), - name, - params, - columns, - output_format: Format::Text, })) } @@ -83,11 +65,6 @@ impl Statement { pub fn columns(&self) -> &[Column] { &self.0.columns } - - /// Returns output format for the statement. - pub fn output_format(&self) -> Format { - self.0.output_format - } } /// Information about a column of a query. diff --git a/tokio-postgres/src/transaction.rs b/tokio-postgres/src/transaction.rs index 806196aa3..ca386974e 100644 --- a/tokio-postgres/src/transaction.rs +++ b/tokio-postgres/src/transaction.rs @@ -150,13 +150,14 @@ impl<'a> Transaction<'a> { } /// Like `Client::query_raw_txt`. - pub async fn query_raw_txt(&self, query: S, params: I) -> Result + pub async fn query_raw_txt(&self, statement: &T, params: I) -> Result where - S: AsRef + Sync + Send, + T: ?Sized + ToStatement, + S: AsRef, I: IntoIterator>, - I::IntoIter: ExactSizeIterator + Sync + Send, + I::IntoIter: ExactSizeIterator, { - self.client.query_raw_txt(query, params).await + self.client.query_raw_txt(statement, params).await } /// Like `Client::execute`. diff --git a/tokio-postgres/tests/test/main.rs b/tokio-postgres/tests/test/main.rs index 20345c7b4..0a8aff1d7 100644 --- a/tokio-postgres/tests/test/main.rs +++ b/tokio-postgres/tests/test/main.rs @@ -314,7 +314,7 @@ async fn query_raw_txt_nulls() { async fn limit_max_backend_message_size() { let client = connect("user=postgres max_backend_message_size=10000").await; let small: Vec = client - .query_raw_txt("SELECT REPEAT('a', 20)", []) + .query_raw_txt("SELECT REPEAT('a', 20)", [] as [Option<&str>; 0]) .await .unwrap() .try_collect() @@ -325,7 +325,7 @@ async fn limit_max_backend_message_size() { assert_eq!(small[0].as_text(0).unwrap().unwrap().len(), 20); let large: Result, Error> = client - .query_raw_txt("SELECT REPEAT('a', 2000000)", []) + .query_raw_txt("SELECT REPEAT('a', 2000000)", [] as [Option<&str>; 0]) .await .unwrap() .try_collect() @@ -339,7 +339,7 @@ async fn command_tag() { let client = connect("user=postgres").await; let row_stream = client - .query_raw_txt("select unnest('{1,2,3}'::int[]);", []) + .query_raw_txt("select unnest('{1,2,3}'::int[]);", [] as [Option<&str>; 0]) .await .unwrap(); @@ -353,12 +353,40 @@ async fn command_tag() { assert_eq!(row_stream.command_tag(), Some("SELECT 3".to_string())); } +#[tokio::test] +async fn ready_for_query() { + let client = connect("user=postgres").await; + + let row_stream = client + .query_raw_txt("START TRANSACTION", [] as [Option<&str>; 0]) + .await + .unwrap(); + + pin_mut!(row_stream); + while row_stream.next().await.is_none() {} + + assert_eq!(row_stream.ready_status(), Some(b'T')); + + let row_stream = client + .query_raw_txt("ROLLBACK", [] as [Option<&str>; 0]) + .await + .unwrap(); + + pin_mut!(row_stream); + while row_stream.next().await.is_none() {} + + assert_eq!(row_stream.ready_status(), Some(b'I')); +} + #[tokio::test] async fn column_extras() { let client = connect("user=postgres").await; let rows: Vec = client - .query_raw_txt("select relacl, relname from pg_class limit 1", []) + .query_raw_txt( + "select relacl, relname from pg_class limit 1", + [] as [Option<&str>; 0], + ) .await .unwrap() .try_collect()