Skip to content

Commit 2e9b5f1

Browse files
kelvichpetuhovskiy
andauthored
Add text protocol based query method (#14)
Add query_raw_txt client method It takes all the extended protocol params as text and passes them to postgres to sort out types. With that we can avoid situations when postgres derived different type compared to what was passed in arguments. There is also propare_typed method, but since we receive data in text format anyway it makes more sense to avoid dealing with types in params. This way we also can save on roundtrip and send Parse+Bind+Describe+Execute right away without waiting for params description before Bind. Use text protocol for responses -- that allows to grab postgres-provided serializations for types. Catch command tag. Expose row buffer size and add `max_backend_message_size` option to prevent handling and storing in memory large messages from the backend. Co-authored-by: Arthur Petukhovsky <petuhovskiy@yandex.ru>
1 parent 0bc41d8 commit 2e9b5f1

File tree

14 files changed

+293
-13
lines changed

14 files changed

+293
-13
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ jobs:
5757
- run: docker compose up -d
5858
- uses: sfackler/actions/rustup@master
5959
with:
60-
version: 1.62.0
60+
version: 1.65.0
6161
- run: echo "::set-output name=version::$(rustc --version)"
6262
id: rust-version
6363
- uses: actions/cache@v1

postgres-derive-test/src/lib.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ where
1414
T: PartialEq + FromSqlOwned + ToSql + Sync,
1515
S: fmt::Display,
1616
{
17-
for &(ref val, ref repr) in checks.iter() {
17+
for (val, repr) in checks.iter() {
1818
let stmt = conn
1919
.prepare(&format!("SELECT {}::{}", *repr, sql_type))
2020
.unwrap();
@@ -38,7 +38,7 @@ pub fn test_type_asymmetric<T, F, S, C>(
3838
S: fmt::Display,
3939
C: Fn(&T, &F) -> bool,
4040
{
41-
for &(ref val, ref repr) in checks.iter() {
41+
for (val, repr) in checks.iter() {
4242
let stmt = conn
4343
.prepare(&format!("SELECT {}::{}", *repr, sql_type))
4444
.unwrap();

postgres-protocol/src/authentication/sasl.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,7 @@ impl<'a> Parser<'a> {
389389
}
390390

391391
fn posit_number(&mut self) -> io::Result<u32> {
392-
let n = self.take_while(|c| matches!(c, '0'..='9'))?;
392+
let n = self.take_while(|c| c.is_ascii_digit())?;
393393
n.parse()
394394
.map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))
395395
}

postgres-protocol/src/message/backend.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -707,7 +707,7 @@ impl<'a> FallibleIterator for DataRowRanges<'a> {
707707
));
708708
}
709709
let base = self.len - self.buf.len();
710-
self.buf = &self.buf[len as usize..];
710+
self.buf = &self.buf[len..];
711711
Ok(Some(Some(base..base + len)))
712712
}
713713
}

postgres-types/src/lib.rs

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,22 @@ impl WrongType {
395395
}
396396
}
397397

398+
/// An error indicating that a as_text conversion was attempted on a binary
399+
/// result.
400+
#[derive(Debug)]
401+
pub struct WrongFormat {}
402+
403+
impl Error for WrongFormat {}
404+
405+
impl fmt::Display for WrongFormat {
406+
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
407+
write!(
408+
fmt,
409+
"cannot read column as text while it is in binary format"
410+
)
411+
}
412+
}
413+
398414
/// A trait for types that can be created from a Postgres value.
399415
///
400416
/// # Types
@@ -846,7 +862,7 @@ pub trait ToSql: fmt::Debug {
846862
/// Supported Postgres message format types
847863
///
848864
/// Using Text format in a message assumes a Postgres `SERVER_ENCODING` of `UTF8`
849-
#[derive(Clone, Copy, Debug)]
865+
#[derive(Clone, Copy, Debug, PartialEq)]
850866
pub enum Format {
851867
/// Text format (UTF-8)
852868
Text,

tokio-postgres/src/client.rs

Lines changed: 84 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@ use crate::copy_both::CopyBothDuplex;
77
use crate::copy_out::CopyOutStream;
88
#[cfg(feature = "runtime")]
99
use crate::keepalive::KeepaliveConfig;
10+
use crate::prepare::get_type;
1011
use crate::query::RowStream;
1112
use crate::simple_query::SimpleQueryStream;
13+
use crate::statement::Column;
1214
#[cfg(feature = "runtime")]
1315
use crate::tls::MakeTlsConnect;
1416
use crate::tls::TlsConnect;
@@ -20,7 +22,7 @@ use crate::{
2022
CopyInSink, Error, Row, SimpleQueryMessage, Statement, ToStatement, Transaction,
2123
TransactionBuilder,
2224
};
23-
use bytes::{Buf, BytesMut};
25+
use bytes::{Buf, BufMut, BytesMut};
2426
use fallible_iterator::FallibleIterator;
2527
use futures_channel::mpsc;
2628
use futures_util::{future, pin_mut, ready, StreamExt, TryStreamExt};
@@ -374,6 +376,87 @@ impl Client {
374376
query::query(&self.inner, statement, params).await
375377
}
376378

379+
/// Pass text directly to the Postgres backend to allow it to sort out typing itself and
380+
/// to save a roundtrip
381+
pub async fn query_raw_txt<'a, S, I>(&self, query: S, params: I) -> Result<RowStream, Error>
382+
where
383+
S: AsRef<str>,
384+
I: IntoIterator<Item = S>,
385+
I::IntoIter: ExactSizeIterator,
386+
{
387+
let params = params.into_iter();
388+
let params_len = params.len();
389+
390+
let buf = self.inner.with_buf(|buf| {
391+
// Parse, anonymous portal
392+
frontend::parse("", query.as_ref(), std::iter::empty(), buf).map_err(Error::encode)?;
393+
// Bind, pass params as text, retrieve as binary
394+
match frontend::bind(
395+
"", // empty string selects the unnamed portal
396+
"", // empty string selects the unnamed prepared statement
397+
std::iter::empty(), // all parameters use the default format (text)
398+
params,
399+
|param, buf| {
400+
buf.put_slice(param.as_ref().as_bytes());
401+
Ok(postgres_protocol::IsNull::No)
402+
},
403+
Some(0), // all text
404+
buf,
405+
) {
406+
Ok(()) => Ok(()),
407+
Err(frontend::BindError::Conversion(e)) => Err(Error::to_sql(e, 0)),
408+
Err(frontend::BindError::Serialization(e)) => Err(Error::encode(e)),
409+
}?;
410+
411+
// Describe portal to typecast results
412+
frontend::describe(b'P', "", buf).map_err(Error::encode)?;
413+
// Execute
414+
frontend::execute("", 0, buf).map_err(Error::encode)?;
415+
// Sync
416+
frontend::sync(buf);
417+
418+
Ok(buf.split().freeze())
419+
})?;
420+
421+
let mut responses = self
422+
.inner
423+
.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
424+
425+
// now read the responses
426+
427+
match responses.next().await? {
428+
Message::ParseComplete => {}
429+
_ => return Err(Error::unexpected_message()),
430+
}
431+
match responses.next().await? {
432+
Message::BindComplete => {}
433+
_ => return Err(Error::unexpected_message()),
434+
}
435+
let row_description = match responses.next().await? {
436+
Message::RowDescription(body) => Some(body),
437+
Message::NoData => None,
438+
_ => return Err(Error::unexpected_message()),
439+
};
440+
441+
// construct statement object
442+
443+
let parameters = vec![Type::UNKNOWN; params_len];
444+
445+
let mut columns = vec![];
446+
if let Some(row_description) = row_description {
447+
let mut it = row_description.fields();
448+
while let Some(field) = it.next().map_err(Error::parse)? {
449+
let type_ = get_type(&self.inner, field.type_oid()).await?;
450+
let column = Column::new(field.name().to_string(), type_);
451+
columns.push(column);
452+
}
453+
}
454+
455+
let statement = Statement::new_text(&self.inner, "".to_owned(), parameters, columns);
456+
457+
Ok(RowStream::new(statement, responses))
458+
}
459+
377460
/// Executes a statement, returning the number of rows modified.
378461
///
379462
/// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list

tokio-postgres/src/codec.rs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,9 @@ impl FallibleIterator for BackendMessages {
3535
}
3636
}
3737

38-
pub struct PostgresCodec;
38+
pub struct PostgresCodec {
39+
pub max_message_size: Option<usize>,
40+
}
3941

4042
impl Encoder<FrontendMessage> for PostgresCodec {
4143
type Error = io::Error;
@@ -64,6 +66,15 @@ impl Decoder for PostgresCodec {
6466
break;
6567
}
6668

69+
if let Some(max) = self.max_message_size {
70+
if len > max {
71+
return Err(io::Error::new(
72+
io::ErrorKind::InvalidInput,
73+
"message too large",
74+
));
75+
}
76+
}
77+
6778
match header.tag() {
6879
backend::NOTICE_RESPONSE_TAG
6980
| backend::NOTIFICATION_RESPONSE_TAG

tokio-postgres/src/config.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@ pub struct Config {
185185
pub(crate) target_session_attrs: TargetSessionAttrs,
186186
pub(crate) channel_binding: ChannelBinding,
187187
pub(crate) replication_mode: Option<ReplicationMode>,
188+
pub(crate) max_backend_message_size: Option<usize>,
188189
}
189190

190191
impl Default for Config {
@@ -217,6 +218,7 @@ impl Config {
217218
target_session_attrs: TargetSessionAttrs::Any,
218219
channel_binding: ChannelBinding::Prefer,
219220
replication_mode: None,
221+
max_backend_message_size: None,
220222
}
221223
}
222224

@@ -472,6 +474,17 @@ impl Config {
472474
self.replication_mode
473475
}
474476

477+
/// Set limit for backend messages size.
478+
pub fn max_backend_message_size(&mut self, max_backend_message_size: usize) -> &mut Config {
479+
self.max_backend_message_size = Some(max_backend_message_size);
480+
self
481+
}
482+
483+
/// Get limit for backend messages size.
484+
pub fn get_max_backend_message_size(&self) -> Option<usize> {
485+
self.max_backend_message_size
486+
}
487+
475488
fn param(&mut self, key: &str, value: &str) -> Result<(), Error> {
476489
match key {
477490
"user" => {
@@ -586,6 +599,14 @@ impl Config {
586599
self.replication_mode(mode);
587600
}
588601
}
602+
"max_backend_message_size" => {
603+
let limit = value.parse::<usize>().map_err(|_| {
604+
Error::config_parse(Box::new(InvalidValue("max_backend_message_size")))
605+
})?;
606+
if limit > 0 {
607+
self.max_backend_message_size(limit);
608+
}
609+
}
589610
key => {
590611
return Err(Error::config_parse(Box::new(UnknownOption(
591612
key.to_string(),

tokio-postgres/src/connect_raw.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,12 @@ where
9090
let stream = connect_tls(stream, config.ssl_mode, tls).await?;
9191

9292
let mut stream = StartupStream {
93-
inner: Framed::new(stream, PostgresCodec),
93+
inner: Framed::new(
94+
stream,
95+
PostgresCodec {
96+
max_message_size: config.max_backend_message_size,
97+
},
98+
),
9499
buf: BackendMessages::empty(),
95100
delayed: VecDeque::new(),
96101
};

tokio-postgres/src/prepare.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ fn encode(client: &InnerClient, name: &str, query: &str, types: &[Type]) -> Resu
126126
})
127127
}
128128

129-
async fn get_type(client: &Arc<InnerClient>, oid: Oid) -> Result<Type, Error> {
129+
pub async fn get_type(client: &Arc<InnerClient>, oid: Oid) -> Result<Type, Error> {
130130
if let Some(type_) = Type::from_oid(oid) {
131131
return Ok(type_);
132132
}

tokio-postgres/src/query.rs

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ where
5252
Ok(RowStream {
5353
statement,
5454
responses,
55+
command_tag: None,
5556
_p: PhantomPinned,
5657
})
5758
}
@@ -72,6 +73,7 @@ pub async fn query_portal(
7273
Ok(RowStream {
7374
statement: portal.statement().clone(),
7475
responses,
76+
command_tag: None,
7577
_p: PhantomPinned,
7678
})
7779
}
@@ -202,11 +204,24 @@ pin_project! {
202204
pub struct RowStream {
203205
statement: Statement,
204206
responses: Responses,
207+
command_tag: Option<String>,
205208
#[pin]
206209
_p: PhantomPinned,
207210
}
208211
}
209212

213+
impl RowStream {
214+
/// Creates a new `RowStream`.
215+
pub fn new(statement: Statement, responses: Responses) -> Self {
216+
RowStream {
217+
statement,
218+
responses,
219+
command_tag: None,
220+
_p: PhantomPinned,
221+
}
222+
}
223+
}
224+
210225
impl Stream for RowStream {
211226
type Item = Result<Row, Error>;
212227

@@ -217,12 +232,24 @@ impl Stream for RowStream {
217232
Message::DataRow(body) => {
218233
return Poll::Ready(Some(Ok(Row::new(this.statement.clone(), body)?)))
219234
}
220-
Message::EmptyQueryResponse
221-
| Message::CommandComplete(_)
222-
| Message::PortalSuspended => {}
235+
Message::EmptyQueryResponse | Message::PortalSuspended => {}
236+
Message::CommandComplete(body) => {
237+
if let Ok(tag) = body.tag() {
238+
*this.command_tag = Some(tag.to_string());
239+
}
240+
}
223241
Message::ReadyForQuery(_) => return Poll::Ready(None),
224242
_ => return Poll::Ready(Some(Err(Error::unexpected_message()))),
225243
}
226244
}
227245
}
228246
}
247+
248+
impl RowStream {
249+
/// Returns the command tag of this query.
250+
///
251+
/// This is only available after the stream has been exhausted.
252+
pub fn command_tag(&self) -> Option<String> {
253+
self.command_tag.clone()
254+
}
255+
}

tokio-postgres/src/row.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use crate::types::{FromSql, Type, WrongType};
77
use crate::{Error, Statement};
88
use fallible_iterator::FallibleIterator;
99
use postgres_protocol::message::backend::DataRowBody;
10+
use postgres_types::{Format, WrongFormat};
1011
use std::fmt;
1112
use std::ops::Range;
1213
use std::str;
@@ -187,6 +188,27 @@ impl Row {
187188
let range = self.ranges[idx].to_owned()?;
188189
Some(&self.body.buffer()[range])
189190
}
191+
192+
/// Interpret the column at the given index as text
193+
///
194+
/// Useful when using query_raw_txt() which sets text transfer mode
195+
pub fn as_text(&self, idx: usize) -> Result<Option<&str>, Error> {
196+
if self.statement.output_format() == Format::Text {
197+
match self.col_buffer(idx) {
198+
Some(raw) => {
199+
FromSql::from_sql(&Type::TEXT, raw).map_err(|e| Error::from_sql(e, idx))
200+
}
201+
None => Ok(None),
202+
}
203+
} else {
204+
Err(Error::from_sql(Box::new(WrongFormat {}), idx))
205+
}
206+
}
207+
208+
/// Row byte size
209+
pub fn body_len(&self) -> usize {
210+
self.body.buffer().len()
211+
}
190212
}
191213

192214
impl AsName for SimpleColumn {

0 commit comments

Comments
 (0)