diff --git a/postgres/src/client.rs b/postgres/src/client.rs index 29cac840d..34b60537c 100644 --- a/postgres/src/client.rs +++ b/postgres/src/client.rs @@ -439,7 +439,9 @@ impl Client { /// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass /// them to this method! pub fn batch_execute(&mut self, query: &str) -> Result<(), Error> { - self.connection.block_on(self.client.batch_execute(query)) + self.connection + .block_on(self.client.batch_execute(query)) + .map(|_| ()) } /// Begins a new database transaction. diff --git a/postgres/src/transaction.rs b/postgres/src/transaction.rs index 17c49c406..3d147cd01 100644 --- a/postgres/src/transaction.rs +++ b/postgres/src/transaction.rs @@ -35,6 +35,7 @@ impl<'a> Transaction<'a> { pub fn commit(mut self) -> Result<(), Error> { self.connection .block_on(self.transaction.take().unwrap().commit()) + .map(|_| ()) } /// Rolls the transaction back, discarding all changes made within it. @@ -43,6 +44,7 @@ impl<'a> Transaction<'a> { pub fn rollback(mut self) -> Result<(), Error> { self.connection .block_on(self.transaction.take().unwrap().rollback()) + .map(|_| ()) } /// Like `Client::prepare`. @@ -193,6 +195,7 @@ impl<'a> Transaction<'a> { pub fn batch_execute(&mut self, query: &str) -> Result<(), Error> { self.connection .block_on(self.transaction.as_ref().unwrap().batch_execute(query)) + .map(|_| ()) } /// Like `Client::cancel_token`. diff --git a/tokio-postgres/Cargo.toml b/tokio-postgres/Cargo.toml index 68737f738..39e41d85c 100644 --- a/tokio-postgres/Cargo.toml +++ b/tokio-postgres/Cargo.toml @@ -55,7 +55,7 @@ pin-project-lite = "0.2" phf = "0.11" postgres-protocol = { version = "0.6.4", path = "../postgres-protocol" } postgres-types = { version = "0.2.4", path = "../postgres-types" } -socket2 = "0.4" +socket2 = { version = "0.5", features = ["all"] } tokio = { version = "1.0", features = ["io-util"] } tokio-util = { version = "0.7", features = ["codec"] } diff --git a/tokio-postgres/src/client.rs b/tokio-postgres/src/client.rs index 38588f8f7..f75fb4c1a 100644 --- a/tokio-postgres/src/client.rs +++ b/tokio-postgres/src/client.rs @@ -17,8 +17,8 @@ use crate::types::{Oid, ToSql, Type}; use crate::Socket; use crate::{ copy_both, copy_in, copy_out, prepare, query, simple_query, slice_iter, CancelToken, - CopyInSink, Error, Row, SimpleQueryMessage, Statement, ToStatement, Transaction, - TransactionBuilder, + CopyInSink, Error, ReadyForQueryStatus, Row, SimpleQueryMessage, Statement, ToStatement, + Transaction, TransactionBuilder, }; use bytes::{Buf, BytesMut}; use fallible_iterator::FallibleIterator; @@ -526,7 +526,7 @@ impl Client { /// Prepared statements should be use for any query which contains user-specified data, as they provided the /// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass /// them to this method! - pub async fn batch_execute(&self, query: &str) -> Result<(), Error> { + pub async fn batch_execute(&self, query: &str) -> Result { simple_query::batch_execute(self.inner(), query).await } diff --git a/tokio-postgres/src/lib.rs b/tokio-postgres/src/lib.rs index 17bb28409..94af11658 100644 --- a/tokio-postgres/src/lib.rs +++ b/tokio-postgres/src/lib.rs @@ -119,6 +119,8 @@ #![doc(html_root_url = "https://docs.rs/tokio-postgres/0.7")] #![warn(rust_2018_idioms, clippy::all, missing_docs)] +use postgres_protocol::message::backend::ReadyForQueryBody; + pub use crate::cancel_token::CancelToken; pub use crate::client::Client; pub use crate::config::Config; @@ -143,6 +145,31 @@ pub use crate::transaction::Transaction; pub use crate::transaction_builder::{IsolationLevel, TransactionBuilder}; use crate::types::ToSql; +/// After executing a query, the connection will be in one of these states +#[derive(Clone, Copy, Debug, PartialEq)] +#[repr(u8)] +pub enum ReadyForQueryStatus { + /// Connection state is unknown + Unknown, + /// Connection is idle (no transactions) + Idle = b'I', + /// Connection is in a transaction block + Transaction = b'T', + /// Connection is in a failed transaction block + FailedTransaction = b'E', +} + +impl From for ReadyForQueryStatus { + fn from(value: ReadyForQueryBody) -> Self { + match value.status() { + b'I' => Self::Idle, + b'T' => Self::Transaction, + b'E' => Self::FailedTransaction, + _ => Self::Unknown, + } + } +} + pub mod binary_copy; mod bind; #[cfg(feature = "runtime")] diff --git a/tokio-postgres/src/query.rs b/tokio-postgres/src/query.rs index f4abd41a5..9df48b773 100644 --- a/tokio-postgres/src/query.rs +++ b/tokio-postgres/src/query.rs @@ -2,7 +2,7 @@ use crate::client::{InnerClient, Responses}; use crate::codec::FrontendMessage; use crate::connection::RequestMessages; use crate::types::{BorrowToSql, IsNull}; -use crate::{Error, Portal, Row, Statement}; +use crate::{Error, Portal, ReadyForQueryStatus, Row, Statement}; use bytes::{BufMut, Bytes, BytesMut}; use futures_util::{ready, Stream}; use log::{debug, log_enabled, Level}; @@ -55,7 +55,7 @@ where statement, responses, command_tag: None, - status: None, + status: ReadyForQueryStatus::Unknown, output_format: Format::Binary, _p: PhantomPinned, }) @@ -109,7 +109,7 @@ where statement, responses, command_tag: None, - status: None, + status: ReadyForQueryStatus::Unknown, output_format: Format::Text, _p: PhantomPinned, }) @@ -132,7 +132,7 @@ pub async fn query_portal( statement: portal.statement().clone(), responses, command_tag: None, - status: None, + status: ReadyForQueryStatus::Unknown, output_format: Format::Binary, _p: PhantomPinned, }) @@ -266,7 +266,7 @@ pin_project! { responses: Responses, command_tag: Option, output_format: Format, - status: Option, + status: ReadyForQueryStatus, #[pin] _p: PhantomPinned, } @@ -293,7 +293,7 @@ impl Stream for RowStream { } } Message::ReadyForQuery(status) => { - *this.status = Some(status.status()); + *this.status = status.into(); return Poll::Ready(None); } _ => return Poll::Ready(Some(Err(Error::unexpected_message()))), @@ -313,7 +313,7 @@ impl RowStream { /// Returns if the connection is ready for querying, with the status of the connection. /// /// This might be available only after the stream has been exhausted. - pub fn ready_status(&self) -> Option { + pub fn ready_status(&self) -> ReadyForQueryStatus { self.status } } diff --git a/tokio-postgres/src/simple_query.rs b/tokio-postgres/src/simple_query.rs index 70f48a7d8..28e7010eb 100644 --- a/tokio-postgres/src/simple_query.rs +++ b/tokio-postgres/src/simple_query.rs @@ -1,7 +1,7 @@ use crate::client::{InnerClient, Responses}; use crate::codec::FrontendMessage; use crate::connection::RequestMessages; -use crate::{Error, SimpleQueryMessage, SimpleQueryRow}; +use crate::{Error, ReadyForQueryStatus, SimpleQueryMessage, SimpleQueryRow}; use bytes::Bytes; use fallible_iterator::FallibleIterator; use futures_util::{ready, Stream}; @@ -40,11 +40,15 @@ pub async fn simple_query(client: &InnerClient, query: &str) -> Result Result<(), Error> { +pub async fn batch_execute( + client: &InnerClient, + query: &str, +) -> Result { debug!("executing statement batch: {}", query); let buf = encode(client, query)?; @@ -52,7 +56,7 @@ pub async fn batch_execute(client: &InnerClient, query: &str) -> Result<(), Erro loop { match responses.next().await? { - Message::ReadyForQuery(_) => return Ok(()), + Message::ReadyForQuery(status) => return Ok(status.into()), Message::CommandComplete(_) | Message::EmptyQueryResponse | Message::RowDescription(_) @@ -74,11 +78,21 @@ pin_project! { pub struct SimpleQueryStream { responses: Responses, columns: Option>, + status: ReadyForQueryStatus, #[pin] _p: PhantomPinned, } } +impl SimpleQueryStream { + /// Returns if the connection is ready for querying, with the status of the connection. + /// + /// This might be available only after the stream has been exhausted. + pub fn ready_status(&self) -> ReadyForQueryStatus { + self.status + } +} + impl Stream for SimpleQueryStream { type Item = Result; @@ -117,7 +131,10 @@ impl Stream for SimpleQueryStream { }; return Poll::Ready(Some(Ok(SimpleQueryMessage::Row(row)))); } - Message::ReadyForQuery(_) => return Poll::Ready(None), + Message::ReadyForQuery(s) => { + *this.status = s.into(); + return Poll::Ready(None); + } _ => return Poll::Ready(Some(Err(Error::unexpected_message()))), } } diff --git a/tokio-postgres/src/transaction.rs b/tokio-postgres/src/transaction.rs index ca386974e..c6a42dd7d 100644 --- a/tokio-postgres/src/transaction.rs +++ b/tokio-postgres/src/transaction.rs @@ -9,8 +9,8 @@ use crate::types::{BorrowToSql, ToSql, Type}; #[cfg(feature = "runtime")] use crate::Socket; use crate::{ - bind, query, slice_iter, CancelToken, Client, CopyInSink, Error, Portal, Row, - SimpleQueryMessage, Statement, ToStatement, + bind, query, slice_iter, CancelToken, Client, CopyInSink, Error, Portal, ReadyForQueryStatus, + Row, SimpleQueryMessage, Statement, ToStatement, }; use bytes::Buf; use futures_util::TryStreamExt; @@ -65,7 +65,7 @@ impl<'a> Transaction<'a> { } /// Consumes the transaction, committing all changes made within it. - pub async fn commit(mut self) -> Result<(), Error> { + pub async fn commit(mut self) -> Result { self.done = true; let query = if let Some(sp) = self.savepoint.as_ref() { format!("RELEASE {}", sp.name) @@ -78,7 +78,7 @@ impl<'a> Transaction<'a> { /// Rolls the transaction back, discarding all changes made within it. /// /// This is equivalent to `Transaction`'s `Drop` implementation, but provides any error encountered to the caller. - pub async fn rollback(mut self) -> Result<(), Error> { + pub async fn rollback(mut self) -> Result { self.done = true; let query = if let Some(sp) = self.savepoint.as_ref() { format!("ROLLBACK TO {}", sp.name) @@ -261,7 +261,7 @@ impl<'a> Transaction<'a> { } /// Like `Client::batch_execute`. - pub async fn batch_execute(&self, query: &str) -> Result<(), Error> { + pub async fn batch_execute(&self, query: &str) -> Result { self.client.batch_execute(query).await } diff --git a/tokio-postgres/tests/test/main.rs b/tokio-postgres/tests/test/main.rs index 0a8aff1d7..e21e2379c 100644 --- a/tokio-postgres/tests/test/main.rs +++ b/tokio-postgres/tests/test/main.rs @@ -16,7 +16,8 @@ use tokio_postgres::error::SqlState; use tokio_postgres::tls::{NoTls, NoTlsStream}; use tokio_postgres::types::{Kind, Type}; use tokio_postgres::{ - AsyncMessage, Client, Config, Connection, Error, IsolationLevel, SimpleQueryMessage, + AsyncMessage, Client, Config, Connection, Error, IsolationLevel, ReadyForQueryStatus, + SimpleQueryMessage, }; mod binary_copy; @@ -365,7 +366,7 @@ async fn ready_for_query() { pin_mut!(row_stream); while row_stream.next().await.is_none() {} - assert_eq!(row_stream.ready_status(), Some(b'T')); + assert_eq!(row_stream.ready_status(), ReadyForQueryStatus::Transaction); let row_stream = client .query_raw_txt("ROLLBACK", [] as [Option<&str>; 0]) @@ -375,7 +376,7 @@ async fn ready_for_query() { pin_mut!(row_stream); while row_stream.next().await.is_none() {} - assert_eq!(row_stream.ready_status(), Some(b'I')); + assert_eq!(row_stream.ready_status(), ReadyForQueryStatus::Idle); } #[tokio::test]