Skip to content

Commit f397668

Browse files
committed
Work with pools that don't support prepared statements
Introduce a new `query_with_param_types` method that allows to specify Postgres type parameters. This obviated the need to use prepared statementsjust to obtain parameter types for a query. It then combines parse, bind, and execute in a single packet. Related: sfackler#1017, sfackler#1067
1 parent 98f5a11 commit f397668

File tree

7 files changed

+416
-6
lines changed

7 files changed

+416
-6
lines changed

tokio-postgres/src/client.rs

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,88 @@ impl Client {
364364
query::query(&self.inner, statement, params).await
365365
}
366366

367+
/// Like `query`, but requires the types of query parameters to be explicitly specified.
368+
///
369+
/// Compared to `query`, this method allows performing queries without three round trips (for prepare, execute, and close). Thus,
370+
/// this is suitable in environments where prepared statements aren't supported (such as Cloudflare Workers with Hyperdrive).
371+
///
372+
/// # Examples
373+
///
374+
/// ```no_run
375+
/// # async fn async_main(client: &tokio_postgres::Client) -> Result<(), tokio_postgres::Error> {
376+
/// use tokio_postgres::types::ToSql;
377+
/// use tokio_postgres::types::Type;
378+
/// use futures_util::{pin_mut, TryStreamExt};
379+
///
380+
/// let rows = client.query_with_param_types(
381+
/// "SELECT foo FROM bar WHERE biz = $1 AND baz = $2",
382+
/// &[(&"first param", Type::TEXT), (&2i32, Type::INT4)],
383+
/// ).await?;
384+
///
385+
/// for row in rows {
386+
/// let foo: i32 = row.get("foo");
387+
/// println!("foo: {}", foo);
388+
/// }
389+
/// # Ok(())
390+
/// # }
391+
/// ```
392+
pub async fn query_with_param_types(
393+
&self,
394+
statement: &str,
395+
params: &[(&(dyn ToSql + Sync), Type)],
396+
) -> Result<Vec<Row>, Error> {
397+
self.query_raw_with_param_types(statement, params)
398+
.await?
399+
.try_collect()
400+
.await
401+
}
402+
403+
/// The maximally flexible version of [`query_with_param_types`].
404+
///
405+
/// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
406+
/// provided, 1-indexed.
407+
///
408+
/// The parameters must specify value along with their Postgres type. This allows performing
409+
/// queries without three round trips (for prepare, execute, and close).
410+
///
411+
/// [`query_with_param_types`]: #method.query_with_param_types
412+
///
413+
/// # Examples
414+
///
415+
/// ```no_run
416+
/// # async fn async_main(client: &tokio_postgres::Client) -> Result<(), tokio_postgres::Error> {
417+
/// use tokio_postgres::types::ToSql;
418+
/// use tokio_postgres::types::Type;
419+
/// use futures_util::{pin_mut, TryStreamExt};
420+
///
421+
/// let mut it = client.query_raw_with_param_types(
422+
/// "SELECT foo FROM bar WHERE biz = $1 AND baz = $2",
423+
/// &[(&"first param", Type::TEXT), (&2i32, Type::INT4)],
424+
/// ).await?;
425+
///
426+
/// pin_mut!(it);
427+
/// while let Some(row) = it.try_next().await? {
428+
/// let foo: i32 = row.get("foo");
429+
/// println!("foo: {}", foo);
430+
/// }
431+
/// # Ok(())
432+
/// # }
433+
/// ```
434+
pub async fn query_raw_with_param_types(
435+
&self,
436+
statement: &str,
437+
params: &[(&(dyn ToSql + Sync), Type)],
438+
) -> Result<RowStream, Error> {
439+
fn slice_iter<'a>(
440+
s: &'a [(&'a (dyn ToSql + Sync), Type)],
441+
) -> impl ExactSizeIterator<Item = (&'a dyn ToSql, Type)> + 'a {
442+
s.iter()
443+
.map(|(param, param_type)| (*param as _, param_type.clone()))
444+
}
445+
446+
query::query_with_param_types(&self.inner, statement, slice_iter(params)).await
447+
}
448+
367449
/// Executes a statement, returning the number of rows modified.
368450
///
369451
/// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list

tokio-postgres/src/generic_client.rs

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,20 @@ pub trait GenericClient: private::Sealed {
5656
I: IntoIterator<Item = P> + Sync + Send,
5757
I::IntoIter: ExactSizeIterator;
5858

59+
/// Like `Client::query_with_param_types`
60+
async fn query_with_param_types(
61+
&self,
62+
statement: &str,
63+
params: &[(&(dyn ToSql + Sync), Type)],
64+
) -> Result<Vec<Row>, Error>;
65+
66+
/// Like `Client::query_raw_with_param_types`.
67+
async fn query_raw_with_param_types(
68+
&self,
69+
statement: &str,
70+
params: &[(&(dyn ToSql + Sync), Type)],
71+
) -> Result<RowStream, Error>;
72+
5973
/// Like `Client::prepare`.
6074
async fn prepare(&self, query: &str) -> Result<Statement, Error>;
6175

@@ -136,6 +150,22 @@ impl GenericClient for Client {
136150
self.query_raw(statement, params).await
137151
}
138152

153+
async fn query_with_param_types(
154+
&self,
155+
statement: &str,
156+
params: &[(&(dyn ToSql + Sync), Type)],
157+
) -> Result<Vec<Row>, Error> {
158+
self.query_with_param_types(statement, params).await
159+
}
160+
161+
async fn query_raw_with_param_types(
162+
&self,
163+
statement: &str,
164+
params: &[(&(dyn ToSql + Sync), Type)],
165+
) -> Result<RowStream, Error> {
166+
self.query_raw_with_param_types(statement, params).await
167+
}
168+
139169
async fn prepare(&self, query: &str) -> Result<Statement, Error> {
140170
self.prepare(query).await
141171
}
@@ -222,6 +252,22 @@ impl GenericClient for Transaction<'_> {
222252
self.query_raw(statement, params).await
223253
}
224254

255+
async fn query_with_param_types(
256+
&self,
257+
statement: &str,
258+
params: &[(&(dyn ToSql + Sync), Type)],
259+
) -> Result<Vec<Row>, Error> {
260+
self.query_with_param_types(statement, params).await
261+
}
262+
263+
async fn query_raw_with_param_types(
264+
&self,
265+
statement: &str,
266+
params: &[(&(dyn ToSql + Sync), Type)],
267+
) -> Result<RowStream, Error> {
268+
self.query_raw_with_param_types(statement, params).await
269+
}
270+
225271
async fn prepare(&self, query: &str) -> Result<Statement, Error> {
226272
self.prepare(query).await
227273
}

tokio-postgres/src/prepare.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ fn encode(client: &InnerClient, name: &str, query: &str, types: &[Type]) -> Resu
131131
})
132132
}
133133

134-
async fn get_type(client: &Arc<InnerClient>, oid: Oid) -> Result<Type, Error> {
134+
pub(crate) async fn get_type(client: &Arc<InnerClient>, oid: Oid) -> Result<Type, Error> {
135135
if let Some(type_) = Type::from_oid(oid) {
136136
return Ok(type_);
137137
}

tokio-postgres/src/query.rs

Lines changed: 141 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,21 @@
11
use crate::client::{InnerClient, Responses};
22
use crate::codec::FrontendMessage;
33
use crate::connection::RequestMessages;
4+
use crate::prepare::get_type;
45
use crate::types::{BorrowToSql, IsNull};
5-
use crate::{Error, Portal, Row, Statement};
6+
use crate::{Column, Error, Portal, Row, Statement};
67
use bytes::{Bytes, BytesMut};
8+
use fallible_iterator::FallibleIterator;
79
use futures_util::{ready, Stream};
810
use log::{debug, log_enabled, Level};
911
use pin_project_lite::pin_project;
10-
use postgres_protocol::message::backend::{CommandCompleteBody, Message};
12+
use postgres_protocol::message::backend::{CommandCompleteBody, Message, RowDescriptionBody};
1113
use postgres_protocol::message::frontend;
14+
use postgres_types::Type;
1215
use std::fmt;
1316
use std::marker::PhantomPinned;
1417
use std::pin::Pin;
18+
use std::sync::Arc;
1519
use std::task::{Context, Poll};
1620

1721
struct BorrowToSqlParamsDebug<'a, T>(&'a [T]);
@@ -50,13 +54,125 @@ where
5054
};
5155
let responses = start(client, buf).await?;
5256
Ok(RowStream {
53-
statement,
57+
statement: statement,
5458
responses,
5559
rows_affected: None,
5660
_p: PhantomPinned,
5761
})
5862
}
5963

64+
enum QueryProcessingState {
65+
Empty,
66+
ParseCompleted,
67+
BindCompleted,
68+
ParameterDescribed,
69+
Final(Vec<Column>),
70+
}
71+
72+
/// State machine for processing messages for `query_with_param_types`.
73+
impl QueryProcessingState {
74+
pub async fn process_message(
75+
self,
76+
client: &Arc<InnerClient>,
77+
message: Message,
78+
) -> Result<Self, Error> {
79+
match (self, message) {
80+
(QueryProcessingState::Empty, Message::ParseComplete) => {
81+
Ok(QueryProcessingState::ParseCompleted)
82+
}
83+
(QueryProcessingState::ParseCompleted, Message::BindComplete) => {
84+
Ok(QueryProcessingState::BindCompleted)
85+
}
86+
(QueryProcessingState::BindCompleted, Message::ParameterDescription(_)) => {
87+
Ok(QueryProcessingState::ParameterDescribed)
88+
}
89+
(
90+
QueryProcessingState::ParameterDescribed,
91+
Message::RowDescription(row_description),
92+
) => Self::form_final(client, Some(row_description)).await,
93+
(QueryProcessingState::ParameterDescribed, Message::NoData) => {
94+
Self::form_final(client, None).await
95+
}
96+
(_, Message::ErrorResponse(body)) => Err(Error::db(body)),
97+
_ => Err(Error::unexpected_message()),
98+
}
99+
}
100+
101+
async fn form_final(
102+
client: &Arc<InnerClient>,
103+
row_description: Option<RowDescriptionBody>,
104+
) -> Result<Self, Error> {
105+
let mut columns = vec![];
106+
if let Some(row_description) = row_description {
107+
let mut it = row_description.fields();
108+
while let Some(field) = it.next().map_err(Error::parse)? {
109+
let type_ = get_type(client, field.type_oid()).await?;
110+
let column = Column {
111+
name: field.name().to_string(),
112+
table_oid: Some(field.table_oid()).filter(|n| *n != 0),
113+
column_id: Some(field.column_id()).filter(|n| *n != 0),
114+
r#type: type_,
115+
};
116+
columns.push(column);
117+
}
118+
}
119+
120+
Ok(Self::Final(columns))
121+
}
122+
}
123+
124+
pub async fn query_with_param_types<'a, P, I>(
125+
client: &Arc<InnerClient>,
126+
query: &str,
127+
params: I,
128+
) -> Result<RowStream, Error>
129+
where
130+
P: BorrowToSql,
131+
I: IntoIterator<Item = (P, Type)>,
132+
I::IntoIter: ExactSizeIterator,
133+
{
134+
let (params, param_types): (Vec<_>, Vec<_>) = params.into_iter().unzip();
135+
136+
let params = params.into_iter();
137+
138+
let param_oids = param_types.iter().map(|t| t.oid()).collect::<Vec<_>>();
139+
140+
let params = params.into_iter();
141+
142+
let buf = client.with_buf(|buf| {
143+
frontend::parse("", query, param_oids.into_iter(), buf).map_err(Error::parse)?;
144+
145+
encode_bind_with_statement_name_and_param_types("", &param_types, params, "", buf)?;
146+
147+
frontend::describe(b'S', "", buf).map_err(Error::encode)?;
148+
149+
frontend::execute("", 0, buf).map_err(Error::encode)?;
150+
151+
frontend::sync(buf);
152+
153+
Ok(buf.split().freeze())
154+
})?;
155+
156+
let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
157+
158+
let mut state = QueryProcessingState::Empty;
159+
160+
loop {
161+
let message = responses.next().await?;
162+
163+
state = state.process_message(client, message).await?;
164+
165+
if let QueryProcessingState::Final(columns) = state {
166+
return Ok(RowStream {
167+
statement: Statement::unnamed(vec![], columns),
168+
responses,
169+
rows_affected: None,
170+
_p: PhantomPinned,
171+
});
172+
}
173+
}
174+
}
175+
60176
pub async fn query_portal(
61177
client: &InnerClient,
62178
portal: &Portal,
@@ -164,7 +280,27 @@ where
164280
I: IntoIterator<Item = P>,
165281
I::IntoIter: ExactSizeIterator,
166282
{
167-
let param_types = statement.params();
283+
encode_bind_with_statement_name_and_param_types(
284+
statement.name(),
285+
statement.params(),
286+
params,
287+
portal,
288+
buf,
289+
)
290+
}
291+
292+
fn encode_bind_with_statement_name_and_param_types<P, I>(
293+
statement_name: &str,
294+
param_types: &[Type],
295+
params: I,
296+
portal: &str,
297+
buf: &mut BytesMut,
298+
) -> Result<(), Error>
299+
where
300+
P: BorrowToSql,
301+
I: IntoIterator<Item = P>,
302+
I::IntoIter: ExactSizeIterator,
303+
{
168304
let params = params.into_iter();
169305

170306
if param_types.len() != params.len() {
@@ -181,7 +317,7 @@ where
181317
let mut error_idx = 0;
182318
let r = frontend::bind(
183319
portal,
184-
statement.name(),
320+
statement_name,
185321
param_formats,
186322
params.zip(param_types).enumerate(),
187323
|(idx, (param, ty)), buf| match param.borrow_to_sql().to_sql_checked(ty, buf) {

tokio-postgres/src/statement.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@ struct StatementInner {
1414

1515
impl Drop for StatementInner {
1616
fn drop(&mut self) {
17+
if self.name.is_empty() {
18+
// Unnamed statements don't need to be closed
19+
return;
20+
}
1721
if let Some(client) = self.client.upgrade() {
1822
let buf = client.with_buf(|buf| {
1923
frontend::close(b'S', &self.name, buf).unwrap();
@@ -46,6 +50,15 @@ impl Statement {
4650
}))
4751
}
4852

53+
pub(crate) fn unnamed(params: Vec<Type>, columns: Vec<Column>) -> Statement {
54+
Statement(Arc::new(StatementInner {
55+
client: Weak::new(),
56+
name: String::new(),
57+
params,
58+
columns,
59+
}))
60+
}
61+
4962
pub(crate) fn name(&self) -> &str {
5063
&self.0.name
5164
}

0 commit comments

Comments
 (0)