Skip to content

Commit 9011f71

Browse files
authored
add query_raw_txt for transaction (#20)
Signed-off-by: Alex Chi <iskyzh@gmail.com>
1 parent 1aaedab commit 9011f71

File tree

4 files changed

+131
-83
lines changed

4 files changed

+131
-83
lines changed

tokio-postgres/src/client.rs

Lines changed: 4 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,8 @@ use crate::copy_both::CopyBothDuplex;
77
use crate::copy_out::CopyOutStream;
88
#[cfg(feature = "runtime")]
99
use crate::keepalive::KeepaliveConfig;
10-
use crate::prepare::get_type;
1110
use crate::query::RowStream;
1211
use crate::simple_query::SimpleQueryStream;
13-
use crate::statement::Column;
1412
#[cfg(feature = "runtime")]
1513
use crate::tls::MakeTlsConnect;
1614
use crate::tls::TlsConnect;
@@ -22,7 +20,7 @@ use crate::{
2220
CopyInSink, Error, Row, SimpleQueryMessage, Statement, ToStatement, Transaction,
2321
TransactionBuilder,
2422
};
25-
use bytes::{Buf, BufMut, BytesMut};
23+
use bytes::{Buf, BytesMut};
2624
use fallible_iterator::FallibleIterator;
2725
use futures_channel::mpsc;
2826
use futures_util::{future, pin_mut, ready, StreamExt, TryStreamExt};
@@ -380,86 +378,11 @@ impl Client {
380378
/// to save a roundtrip
381379
pub async fn query_raw_txt<'a, S, I>(&self, query: S, params: I) -> Result<RowStream, Error>
382380
where
383-
S: AsRef<str>,
381+
S: AsRef<str> + Sync + Send,
384382
I: IntoIterator<Item = Option<S>>,
385-
I::IntoIter: ExactSizeIterator,
383+
I::IntoIter: ExactSizeIterator + Sync + Send,
386384
{
387-
let params = params.into_iter();
388-
let params_len = params.len();
389-
390-
let buf = self.inner.with_buf(|buf| {
391-
// Parse, anonymous portal
392-
frontend::parse("", query.as_ref(), std::iter::empty(), buf).map_err(Error::encode)?;
393-
// Bind, pass params as text, retrieve as binary
394-
match frontend::bind(
395-
"", // empty string selects the unnamed portal
396-
"", // empty string selects the unnamed prepared statement
397-
std::iter::empty(), // all parameters use the default format (text)
398-
params,
399-
|param, buf| match param {
400-
Some(param) => {
401-
buf.put_slice(param.as_ref().as_bytes());
402-
Ok(postgres_protocol::IsNull::No)
403-
}
404-
None => Ok(postgres_protocol::IsNull::Yes),
405-
},
406-
Some(0), // all text
407-
buf,
408-
) {
409-
Ok(()) => Ok(()),
410-
Err(frontend::BindError::Conversion(e)) => Err(Error::to_sql(e, 0)),
411-
Err(frontend::BindError::Serialization(e)) => Err(Error::encode(e)),
412-
}?;
413-
414-
// Describe portal to typecast results
415-
frontend::describe(b'P', "", buf).map_err(Error::encode)?;
416-
// Execute
417-
frontend::execute("", 0, buf).map_err(Error::encode)?;
418-
// Sync
419-
frontend::sync(buf);
420-
421-
Ok(buf.split().freeze())
422-
})?;
423-
424-
let mut responses = self
425-
.inner
426-
.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
427-
428-
// now read the responses
429-
430-
match responses.next().await? {
431-
Message::ParseComplete => {}
432-
_ => return Err(Error::unexpected_message()),
433-
}
434-
match responses.next().await? {
435-
Message::BindComplete => {}
436-
_ => return Err(Error::unexpected_message()),
437-
}
438-
let row_description = match responses.next().await? {
439-
Message::RowDescription(body) => Some(body),
440-
Message::NoData => None,
441-
_ => return Err(Error::unexpected_message()),
442-
};
443-
444-
// construct statement object
445-
446-
let parameters = vec![Type::UNKNOWN; params_len];
447-
448-
let mut columns = vec![];
449-
if let Some(row_description) = row_description {
450-
let mut it = row_description.fields();
451-
while let Some(field) = it.next().map_err(Error::parse)? {
452-
// NB: for some types that function may send a query to the server. At least in
453-
// raw text mode we don't need that info and can skip this.
454-
let type_ = get_type(&self.inner, field.type_oid()).await?;
455-
let column = Column::new(field.name().to_string(), type_, field);
456-
columns.push(column);
457-
}
458-
}
459-
460-
let statement = Statement::new_text(&self.inner, "".to_owned(), parameters, columns);
461-
462-
Ok(RowStream::new(statement, responses))
385+
query::query_txt(&self.inner, query, params).await
463386
}
464387

465388
/// Executes a statement, returning the number of rows modified.

tokio-postgres/src/generic_client.rs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,13 @@ pub trait GenericClient: private::Sealed {
5656
I: IntoIterator<Item = P> + Sync + Send,
5757
I::IntoIter: ExactSizeIterator;
5858

59+
/// Like `Client::query_raw_txt`.
60+
async fn query_raw_txt<'a, S, I>(&self, query: S, params: I) -> Result<RowStream, Error>
61+
where
62+
S: AsRef<str> + Sync + Send,
63+
I: IntoIterator<Item = Option<S>> + Sync + Send,
64+
I::IntoIter: ExactSizeIterator + Sync + Send;
65+
5966
/// Like `Client::prepare`.
6067
async fn prepare(&self, query: &str) -> Result<Statement, Error>;
6168

@@ -133,6 +140,15 @@ impl GenericClient for Client {
133140
self.query_raw(statement, params).await
134141
}
135142

143+
async fn query_raw_txt<'a, S, I>(&self, query: S, params: I) -> Result<RowStream, Error>
144+
where
145+
S: AsRef<str> + Sync + Send,
146+
I: IntoIterator<Item = Option<S>> + Sync + Send,
147+
I::IntoIter: ExactSizeIterator + Sync + Send,
148+
{
149+
self.query_raw_txt(query, params).await
150+
}
151+
136152
async fn prepare(&self, query: &str) -> Result<Statement, Error> {
137153
self.prepare(query).await
138154
}
@@ -215,6 +231,15 @@ impl GenericClient for Transaction<'_> {
215231
self.query_raw(statement, params).await
216232
}
217233

234+
async fn query_raw_txt<'a, S, I>(&self, query: S, params: I) -> Result<RowStream, Error>
235+
where
236+
S: AsRef<str> + Sync + Send,
237+
I: IntoIterator<Item = Option<S>> + Sync + Send,
238+
I::IntoIter: ExactSizeIterator + Sync + Send,
239+
{
240+
self.query_raw_txt(query, params).await
241+
}
242+
218243
async fn prepare(&self, query: &str) -> Result<Statement, Error> {
219244
self.prepare(query).await
220245
}

tokio-postgres/src/query.rs

Lines changed: 92 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,21 @@
11
use crate::client::{InnerClient, Responses};
22
use crate::codec::FrontendMessage;
33
use crate::connection::RequestMessages;
4+
use crate::prepare::get_type;
45
use crate::types::{BorrowToSql, IsNull};
5-
use crate::{Error, Portal, Row, Statement};
6-
use bytes::{Bytes, BytesMut};
6+
use crate::{Column, Error, Portal, Row, Statement};
7+
use bytes::{BufMut, Bytes, BytesMut};
8+
use fallible_iterator::FallibleIterator;
79
use futures_util::{ready, Stream};
810
use log::{debug, log_enabled, Level};
911
use pin_project_lite::pin_project;
1012
use postgres_protocol::message::backend::Message;
1113
use postgres_protocol::message::frontend;
14+
use postgres_types::Type;
1215
use std::fmt;
1316
use std::marker::PhantomPinned;
1417
use std::pin::Pin;
18+
use std::sync::Arc;
1519
use std::task::{Context, Poll};
1620

1721
struct BorrowToSqlParamsDebug<'a, T>(&'a [T]);
@@ -57,6 +61,92 @@ where
5761
})
5862
}
5963

64+
pub async fn query_txt<S, I>(
65+
client: &Arc<InnerClient>,
66+
query: S,
67+
params: I,
68+
) -> Result<RowStream, Error>
69+
where
70+
S: AsRef<str> + Sync + Send,
71+
I: IntoIterator<Item = Option<S>>,
72+
I::IntoIter: ExactSizeIterator,
73+
{
74+
let params = params.into_iter();
75+
let params_len = params.len();
76+
77+
let buf = client.with_buf(|buf| {
78+
// Parse, anonymous portal
79+
frontend::parse("", query.as_ref(), std::iter::empty(), buf).map_err(Error::encode)?;
80+
// Bind, pass params as text, retrieve as binary
81+
match frontend::bind(
82+
"", // empty string selects the unnamed portal
83+
"", // empty string selects the unnamed prepared statement
84+
std::iter::empty(), // all parameters use the default format (text)
85+
params,
86+
|param, buf| match param {
87+
Some(param) => {
88+
buf.put_slice(param.as_ref().as_bytes());
89+
Ok(postgres_protocol::IsNull::No)
90+
}
91+
None => Ok(postgres_protocol::IsNull::Yes),
92+
},
93+
Some(0), // all text
94+
buf,
95+
) {
96+
Ok(()) => Ok(()),
97+
Err(frontend::BindError::Conversion(e)) => Err(Error::to_sql(e, 0)),
98+
Err(frontend::BindError::Serialization(e)) => Err(Error::encode(e)),
99+
}?;
100+
101+
// Describe portal to typecast results
102+
frontend::describe(b'P', "", buf).map_err(Error::encode)?;
103+
// Execute
104+
frontend::execute("", 0, buf).map_err(Error::encode)?;
105+
// Sync
106+
frontend::sync(buf);
107+
108+
Ok(buf.split().freeze())
109+
})?;
110+
111+
let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
112+
113+
// now read the responses
114+
115+
match responses.next().await? {
116+
Message::ParseComplete => {}
117+
_ => return Err(Error::unexpected_message()),
118+
}
119+
match responses.next().await? {
120+
Message::BindComplete => {}
121+
_ => return Err(Error::unexpected_message()),
122+
}
123+
let row_description = match responses.next().await? {
124+
Message::RowDescription(body) => Some(body),
125+
Message::NoData => None,
126+
_ => return Err(Error::unexpected_message()),
127+
};
128+
129+
// construct statement object
130+
131+
let parameters = vec![Type::UNKNOWN; params_len];
132+
133+
let mut columns = vec![];
134+
if let Some(row_description) = row_description {
135+
let mut it = row_description.fields();
136+
while let Some(field) = it.next().map_err(Error::parse)? {
137+
// NB: for some types that function may send a query to the server. At least in
138+
// raw text mode we don't need that info and can skip this.
139+
let type_ = get_type(client, field.type_oid()).await?;
140+
let column = Column::new(field.name().to_string(), type_, field);
141+
columns.push(column);
142+
}
143+
}
144+
145+
let statement = Statement::new_text(client, "".to_owned(), parameters, columns);
146+
147+
Ok(RowStream::new(statement, responses))
148+
}
149+
60150
pub async fn query_portal(
61151
client: &InnerClient,
62152
portal: &Portal,

tokio-postgres/src/transaction.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,16 @@ impl<'a> Transaction<'a> {
149149
self.client.query_raw(statement, params).await
150150
}
151151

152+
/// Like `Client::query_raw_txt`.
153+
pub async fn query_raw_txt<S, I>(&self, query: S, params: I) -> Result<RowStream, Error>
154+
where
155+
S: AsRef<str> + Sync + Send,
156+
I: IntoIterator<Item = Option<S>>,
157+
I::IntoIter: ExactSizeIterator + Sync + Send,
158+
{
159+
self.client.query_raw_txt(query, params).await
160+
}
161+
152162
/// Like `Client::execute`.
153163
pub async fn execute<T>(
154164
&self,

0 commit comments

Comments
 (0)