From d0cbcc85041db97800fe2e5be1aac13a9b2e4313 Mon Sep 17 00:00:00 2001 From: Dmitry Ivanov Date: Fri, 15 Apr 2022 21:06:01 +0300 Subject: [PATCH 01/22] Enable CI for `neon` branch We'd like to check our patches. --- .github/workflows/ci.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 008158fb0..4df62ebbc 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -3,9 +3,11 @@ name: CI on: pull_request: branches: + - neon - master push: branches: + - neon - master env: From 8e2d19f88ddd359115a193996098e4de0d564f95 Mon Sep 17 00:00:00 2001 From: Dmitry Ivanov Date: Mon, 14 Dec 2020 12:33:29 -0800 Subject: [PATCH 02/22] Support for physical and logical replication This patch was implemented by Petros Angelatos and Jeff Davis to support physical and logical replication in rust-postgres (see https://github.com/sfackler/rust-postgres/pull/752). The original PR never made it to the upstream, but we (Neon) still use it in our own fork of rust-postgres. The following commits were squashed together: * Image configuration updates. * Make simple_query::encode() pub(crate). * decoding logic for replication protocol * Connection string config for replication. 
* add copy_both_simple method * helper ReplicationStream type for replication protocol This can be optionally used with a CopyBoth stream to decode the replication protocol * decoding logic for logical replication protocol * helper LogicalReplicationStream type to decode logical replication * add postgres replication integration test * add simple query versions of copy operations * replication: use SystemTime for timestamps at API boundary Co-authored-by: Petros Angelatos Co-authored-by: Jeff Davis Co-authored-by: Dmitry Ivanov --- docker/sql_setup.sh | 2 + flake.lock | 61 ++ flake.nix | 23 + postgres-protocol/Cargo.toml | 1 + postgres-protocol/src/lib.rs | 7 + postgres-protocol/src/message/backend.rs | 776 ++++++++++++++++++++++- tokio-postgres/src/client.rs | 28 +- tokio-postgres/src/config.rs | 35 + tokio-postgres/src/connect_raw.rs | 8 +- tokio-postgres/src/connection.rs | 20 + tokio-postgres/src/copy_both.rs | 248 ++++++++ tokio-postgres/src/copy_in.rs | 38 +- tokio-postgres/src/copy_out.rs | 25 +- tokio-postgres/src/lib.rs | 2 + tokio-postgres/src/replication.rs | 184 ++++++ tokio-postgres/src/simple_query.rs | 2 +- tokio-postgres/tests/test/main.rs | 2 + tokio-postgres/tests/test/replication.rs | 146 +++++ 18 files changed, 1587 insertions(+), 21 deletions(-) create mode 100644 flake.lock create mode 100644 flake.nix create mode 100644 tokio-postgres/src/copy_both.rs create mode 100644 tokio-postgres/src/replication.rs create mode 100644 tokio-postgres/tests/test/replication.rs diff --git a/docker/sql_setup.sh b/docker/sql_setup.sh index 0315ac805..051a12000 100755 --- a/docker/sql_setup.sh +++ b/docker/sql_setup.sh @@ -64,6 +64,7 @@ port = 5433 ssl = on ssl_cert_file = 'server.crt' ssl_key_file = 'server.key' +wal_level = logical EOCONF cat > "$PGDATA/pg_hba.conf" <<-EOCONF @@ -82,6 +83,7 @@ host all ssl_user ::0/0 reject # IPv4 local connections: host all postgres 0.0.0.0/0 trust +host replication postgres 0.0.0.0/0 trust # IPv6 local connections: host all 
postgres ::0/0 trust # Unix socket connections: diff --git a/flake.lock b/flake.lock new file mode 100644 index 000000000..919cbc19e --- /dev/null +++ b/flake.lock @@ -0,0 +1,61 @@ +{ + "nodes": { + "flake-utils": { + "inputs": { + "systems": "systems" + }, + "locked": { + "lastModified": 1694529238, + "narHash": "sha256-zsNZZGTGnMOf9YpHKJqMSsa0dXbfmxeoJ7xHlrt+xmY=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "ff7b65b44d01cf9ba6a71320833626af21126384", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "nixpkgs": { + "locked": { + "lastModified": 1701074022, + "narHash": "sha256-yJvhSs+AswFyVsacGdrl+ASakN7ZNcC12Fh2Hp9lcXs=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "bbb4adee2f9bb15d85ea0670db12f85ba85c94d3", + "type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "master", + "repo": "nixpkgs", + "type": "github" + } + }, + "root": { + "inputs": { + "flake-utils": "flake-utils", + "nixpkgs": "nixpkgs" + } + }, + "systems": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/flake.nix b/flake.nix new file mode 100644 index 000000000..b400fb2b8 --- /dev/null +++ b/flake.nix @@ -0,0 +1,23 @@ +{ + description = "A prisma test project"; + inputs.nixpkgs.url = "github:NixOS/nixpkgs/master"; + inputs.flake-utils.url = "github:numtide/flake-utils"; + + outputs = { + self, + nixpkgs, + flake-utils, + }: + flake-utils.lib.eachDefaultSystem (system: let + pkgs = nixpkgs.legacyPackages.${system}; + in { + devShell = pkgs.mkShell { + nativeBuildInputs = [pkgs.bashInteractive]; + buildInputs = with pkgs; [ + openssl + rustup + pkg-config + ]; + }; + }); +} 
diff --git a/postgres-protocol/Cargo.toml b/postgres-protocol/Cargo.toml index b44994811..a559c97c3 100644 --- a/postgres-protocol/Cargo.toml +++ b/postgres-protocol/Cargo.toml @@ -18,6 +18,7 @@ byteorder = "1.0" bytes = "1.0" fallible-iterator = "0.2" hmac = "0.12" +lazy_static = "1.4" md-5 = "0.10" memchr = "2.0" rand = "0.8" diff --git a/postgres-protocol/src/lib.rs b/postgres-protocol/src/lib.rs index 83d9bf55c..f65ed84f3 100644 --- a/postgres-protocol/src/lib.rs +++ b/postgres-protocol/src/lib.rs @@ -13,7 +13,9 @@ use byteorder::{BigEndian, ByteOrder}; use bytes::{BufMut, BytesMut}; +use lazy_static::lazy_static; use std::io; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; pub mod authentication; pub mod escape; @@ -27,6 +29,11 @@ pub type Oid = u32; /// A Postgres Log Sequence Number (LSN). pub type Lsn = u64; +lazy_static! { + /// Postgres epoch is 2000-01-01T00:00:00Z + pub static ref PG_EPOCH: SystemTime = UNIX_EPOCH + Duration::from_secs(946_684_800); +} + /// An enum indicating if a value is `NULL` or not. pub enum IsNull { /// The value is `NULL`. 
diff --git a/postgres-protocol/src/message/backend.rs b/postgres-protocol/src/message/backend.rs index 1b5be1098..42534d2f8 100644 --- a/postgres-protocol/src/message/backend.rs +++ b/postgres-protocol/src/message/backend.rs @@ -8,9 +8,11 @@ use std::cmp; use std::io::{self, Read}; use std::ops::Range; use std::str; +use std::time::{Duration, SystemTime}; -use crate::Oid; +use crate::{Lsn, Oid, PG_EPOCH}; +// top-level message tags pub const PARSE_COMPLETE_TAG: u8 = b'1'; pub const BIND_COMPLETE_TAG: u8 = b'2'; pub const CLOSE_COMPLETE_TAG: u8 = b'3'; @@ -22,6 +24,7 @@ pub const DATA_ROW_TAG: u8 = b'D'; pub const ERROR_RESPONSE_TAG: u8 = b'E'; pub const COPY_IN_RESPONSE_TAG: u8 = b'G'; pub const COPY_OUT_RESPONSE_TAG: u8 = b'H'; +pub const COPY_BOTH_RESPONSE_TAG: u8 = b'W'; pub const EMPTY_QUERY_RESPONSE_TAG: u8 = b'I'; pub const BACKEND_KEY_DATA_TAG: u8 = b'K'; pub const NO_DATA_TAG: u8 = b'n'; @@ -33,6 +36,33 @@ pub const PARAMETER_DESCRIPTION_TAG: u8 = b't'; pub const ROW_DESCRIPTION_TAG: u8 = b'T'; pub const READY_FOR_QUERY_TAG: u8 = b'Z'; +// replication message tags +pub const XLOG_DATA_TAG: u8 = b'w'; +pub const PRIMARY_KEEPALIVE_TAG: u8 = b'k'; + +// logical replication message tags +const BEGIN_TAG: u8 = b'B'; +const COMMIT_TAG: u8 = b'C'; +const ORIGIN_TAG: u8 = b'O'; +const RELATION_TAG: u8 = b'R'; +const TYPE_TAG: u8 = b'Y'; +const INSERT_TAG: u8 = b'I'; +const UPDATE_TAG: u8 = b'U'; +const DELETE_TAG: u8 = b'D'; +const TRUNCATE_TAG: u8 = b'T'; +const TUPLE_NEW_TAG: u8 = b'N'; +const TUPLE_KEY_TAG: u8 = b'K'; +const TUPLE_OLD_TAG: u8 = b'O'; +const TUPLE_DATA_NULL_TAG: u8 = b'n'; +const TUPLE_DATA_TOAST_TAG: u8 = b'u'; +const TUPLE_DATA_TEXT_TAG: u8 = b't'; + +// replica identity tags +const REPLICA_IDENTITY_DEFAULT_TAG: u8 = b'd'; +const REPLICA_IDENTITY_NOTHING_TAG: u8 = b'n'; +const REPLICA_IDENTITY_FULL_TAG: u8 = b'f'; +const REPLICA_IDENTITY_INDEX_TAG: u8 = b'i'; + #[derive(Debug, Copy, Clone)] pub struct Header { tag: u8, @@ -93,6 +123,7 @@ pub 
enum Message { CopyDone, CopyInResponse(CopyInResponseBody), CopyOutResponse(CopyOutResponseBody), + CopyBothResponse(CopyBothResponseBody), DataRow(DataRowBody), EmptyQueryResponse, ErrorResponse(ErrorResponseBody), @@ -190,6 +221,16 @@ impl Message { storage, }) } + COPY_BOTH_RESPONSE_TAG => { + let format = buf.read_u8()?; + let len = buf.read_u16::()?; + let storage = buf.read_all(); + Message::CopyBothResponse(CopyBothResponseBody { + format, + len, + storage, + }) + } EMPTY_QUERY_RESPONSE_TAG => Message::EmptyQueryResponse, BACKEND_KEY_DATA_TAG => { let process_id = buf.read_i32::()?; @@ -278,6 +319,69 @@ impl Message { } } +/// An enum representing Postgres backend replication messages. +#[non_exhaustive] +#[derive(Debug)] +pub enum ReplicationMessage { + XLogData(XLogDataBody), + PrimaryKeepAlive(PrimaryKeepAliveBody), +} + +impl ReplicationMessage { + #[inline] + pub fn parse(buf: &Bytes) -> io::Result { + let mut buf = Buffer { + bytes: buf.clone(), + idx: 0, + }; + + let tag = buf.read_u8()?; + + let replication_message = match tag { + XLOG_DATA_TAG => { + let wal_start = buf.read_u64::()?; + let wal_end = buf.read_u64::()?; + let ts = buf.read_i64::()?; + let timestamp = if ts > 0 { + *PG_EPOCH + Duration::from_micros(ts as u64) + } else { + *PG_EPOCH - Duration::from_micros(-ts as u64) + }; + let data = buf.read_all(); + ReplicationMessage::XLogData(XLogDataBody { + wal_start, + wal_end, + timestamp, + data, + }) + } + PRIMARY_KEEPALIVE_TAG => { + let wal_end = buf.read_u64::()?; + let ts = buf.read_i64::()?; + let timestamp = if ts > 0 { + *PG_EPOCH + Duration::from_micros(ts as u64) + } else { + *PG_EPOCH - Duration::from_micros(-ts as u64) + }; + let reply = buf.read_u8()?; + ReplicationMessage::PrimaryKeepAlive(PrimaryKeepAliveBody { + wal_end, + timestamp, + reply, + }) + } + tag => { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + format!("unknown replication message tag `{}`", tag), + )); + } + }; + + Ok(replication_message) + } +} 
+ struct Buffer { bytes: Bytes, idx: usize, @@ -524,6 +628,27 @@ impl CopyOutResponseBody { } } +pub struct CopyBothResponseBody { + storage: Bytes, + len: u16, + format: u8, +} + +impl CopyBothResponseBody { + #[inline] + pub fn format(&self) -> u8 { + self.format + } + + #[inline] + pub fn column_formats(&self) -> ColumnFormats<'_> { + ColumnFormats { + remaining: self.len, + buf: &self.storage, + } + } +} + #[derive(Debug)] pub struct DataRowBody { storage: Bytes, @@ -782,6 +907,655 @@ impl RowDescriptionBody { } } +#[derive(Debug)] +pub struct XLogDataBody { + wal_start: u64, + wal_end: u64, + timestamp: SystemTime, + data: D, +} + +impl XLogDataBody { + #[inline] + pub fn wal_start(&self) -> u64 { + self.wal_start + } + + #[inline] + pub fn wal_end(&self) -> u64 { + self.wal_end + } + + #[inline] + pub fn timestamp(&self) -> SystemTime { + self.timestamp + } + + #[inline] + pub fn data(&self) -> &D { + &self.data + } + + #[inline] + pub fn into_data(self) -> D { + self.data + } + + pub fn map_data(self, f: F) -> Result, E> + where + F: Fn(D) -> Result, + { + let data = f(self.data)?; + Ok(XLogDataBody { + wal_start: self.wal_start, + wal_end: self.wal_end, + timestamp: self.timestamp, + data, + }) + } +} + +#[derive(Debug)] +pub struct PrimaryKeepAliveBody { + wal_end: u64, + timestamp: SystemTime, + reply: u8, +} + +impl PrimaryKeepAliveBody { + #[inline] + pub fn wal_end(&self) -> u64 { + self.wal_end + } + + #[inline] + pub fn timestamp(&self) -> SystemTime { + self.timestamp + } + + #[inline] + pub fn reply(&self) -> u8 { + self.reply + } +} + +#[non_exhaustive] +/// A message of the logical replication stream +#[derive(Debug)] +pub enum LogicalReplicationMessage { + /// A BEGIN statement + Begin(BeginBody), + /// A BEGIN statement + Commit(CommitBody), + /// An Origin replication message + /// Note that there can be multiple Origin messages inside a single transaction. 
+ Origin(OriginBody), + /// A Relation replication message + Relation(RelationBody), + /// A Type replication message + Type(TypeBody), + /// An INSERT statement + Insert(InsertBody), + /// An UPDATE statement + Update(UpdateBody), + /// A DELETE statement + Delete(DeleteBody), + /// A TRUNCATE statement + Truncate(TruncateBody), +} + +impl LogicalReplicationMessage { + pub fn parse(buf: &Bytes) -> io::Result { + let mut buf = Buffer { + bytes: buf.clone(), + idx: 0, + }; + + let tag = buf.read_u8()?; + + let logical_replication_message = match tag { + BEGIN_TAG => Self::Begin(BeginBody { + final_lsn: buf.read_u64::()?, + timestamp: buf.read_i64::()?, + xid: buf.read_u32::()?, + }), + COMMIT_TAG => Self::Commit(CommitBody { + flags: buf.read_i8()?, + commit_lsn: buf.read_u64::()?, + end_lsn: buf.read_u64::()?, + timestamp: buf.read_i64::()?, + }), + ORIGIN_TAG => Self::Origin(OriginBody { + commit_lsn: buf.read_u64::()?, + name: buf.read_cstr()?, + }), + RELATION_TAG => { + let rel_id = buf.read_u32::()?; + let namespace = buf.read_cstr()?; + let name = buf.read_cstr()?; + let replica_identity = match buf.read_u8()? 
{ + REPLICA_IDENTITY_DEFAULT_TAG => ReplicaIdentity::Default, + REPLICA_IDENTITY_NOTHING_TAG => ReplicaIdentity::Nothing, + REPLICA_IDENTITY_FULL_TAG => ReplicaIdentity::Full, + REPLICA_IDENTITY_INDEX_TAG => ReplicaIdentity::Index, + tag => { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + format!("unknown replica identity tag `{}`", tag), + )); + } + }; + let column_len = buf.read_i16::()?; + + let mut columns = Vec::with_capacity(column_len as usize); + for _ in 0..column_len { + columns.push(Column::parse(&mut buf)?); + } + + Self::Relation(RelationBody { + rel_id, + namespace, + name, + replica_identity, + columns, + }) + } + TYPE_TAG => Self::Type(TypeBody { + id: buf.read_u32::()?, + namespace: buf.read_cstr()?, + name: buf.read_cstr()?, + }), + INSERT_TAG => { + let rel_id = buf.read_u32::()?; + let tag = buf.read_u8()?; + + let tuple = match tag { + TUPLE_NEW_TAG => Tuple::parse(&mut buf)?, + tag => { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + format!("unexpected tuple tag `{}`", tag), + )); + } + }; + + Self::Insert(InsertBody { rel_id, tuple }) + } + UPDATE_TAG => { + let rel_id = buf.read_u32::()?; + let tag = buf.read_u8()?; + + let mut key_tuple = None; + let mut old_tuple = None; + + let new_tuple = match tag { + TUPLE_NEW_TAG => Tuple::parse(&mut buf)?, + TUPLE_OLD_TAG | TUPLE_KEY_TAG => { + if tag == TUPLE_OLD_TAG { + old_tuple = Some(Tuple::parse(&mut buf)?); + } else { + key_tuple = Some(Tuple::parse(&mut buf)?); + } + + match buf.read_u8()? 
{ + TUPLE_NEW_TAG => Tuple::parse(&mut buf)?, + tag => { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + format!("unexpected tuple tag `{}`", tag), + )); + } + } + } + tag => { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + format!("unknown tuple tag `{}`", tag), + )); + } + }; + + Self::Update(UpdateBody { + rel_id, + key_tuple, + old_tuple, + new_tuple, + }) + } + DELETE_TAG => { + let rel_id = buf.read_u32::()?; + let tag = buf.read_u8()?; + + let mut key_tuple = None; + let mut old_tuple = None; + + match tag { + TUPLE_OLD_TAG => old_tuple = Some(Tuple::parse(&mut buf)?), + TUPLE_KEY_TAG => key_tuple = Some(Tuple::parse(&mut buf)?), + tag => { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + format!("unknown tuple tag `{}`", tag), + )); + } + } + + Self::Delete(DeleteBody { + rel_id, + key_tuple, + old_tuple, + }) + } + TRUNCATE_TAG => { + let relation_len = buf.read_i32::()?; + let options = buf.read_i8()?; + + let mut rel_ids = Vec::with_capacity(relation_len as usize); + for _ in 0..relation_len { + rel_ids.push(buf.read_u32::()?); + } + + Self::Truncate(TruncateBody { options, rel_ids }) + } + tag => { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + format!("unknown replication message tag `{}`", tag), + )); + } + }; + + Ok(logical_replication_message) + } +} + +/// A row as it appears in the replication stream +#[derive(Debug)] +pub struct Tuple(Vec); + +impl Tuple { + #[inline] + /// The tuple data of this tuple + pub fn tuple_data(&self) -> &[TupleData] { + &self.0 + } +} + +impl Tuple { + fn parse(buf: &mut Buffer) -> io::Result { + let col_len = buf.read_i16::()?; + let mut tuple = Vec::with_capacity(col_len as usize); + for _ in 0..col_len { + tuple.push(TupleData::parse(buf)?); + } + + Ok(Tuple(tuple)) + } +} + +/// A column as it appears in the replication stream +#[derive(Debug)] +pub struct Column { + flags: i8, + name: Bytes, + type_id: i32, + type_modifier: i32, +} + +impl Column { + 
#[inline] + /// Flags for the column. Currently can be either 0 for no flags or 1 which marks the column as + /// part of the key. + pub fn flags(&self) -> i8 { + self.flags + } + + #[inline] + /// Name of the column. + pub fn name(&self) -> io::Result<&str> { + get_str(&self.name) + } + + #[inline] + /// ID of the column's data type. + pub fn type_id(&self) -> i32 { + self.type_id + } + + #[inline] + /// Type modifier of the column (`atttypmod`). + pub fn type_modifier(&self) -> i32 { + self.type_modifier + } +} + +impl Column { + fn parse(buf: &mut Buffer) -> io::Result { + Ok(Self { + flags: buf.read_i8()?, + name: buf.read_cstr()?, + type_id: buf.read_i32::()?, + type_modifier: buf.read_i32::()?, + }) + } +} + +/// The data of an individual column as it appears in the replication stream +#[derive(Debug)] +pub enum TupleData { + /// Represents a NULL value + Null, + /// Represents an unchanged TOASTed value (the actual value is not sent). + UnchangedToast, + /// Column data as text formatted value. + Text(Bytes), +} + +impl TupleData { + fn parse(buf: &mut Buffer) -> io::Result { + let type_tag = buf.read_u8()?; + + let tuple = match type_tag { + TUPLE_DATA_NULL_TAG => TupleData::Null, + TUPLE_DATA_TOAST_TAG => TupleData::UnchangedToast, + TUPLE_DATA_TEXT_TAG => { + let len = buf.read_i32::()?; + let mut data = vec![0; len as usize]; + buf.read_exact(&mut data)?; + TupleData::Text(data.into()) + } + tag => { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + format!("unknown replication message tag `{}`", tag), + )); + } + }; + + Ok(tuple) + } +} + +/// A BEGIN statement +#[derive(Debug)] +pub struct BeginBody { + final_lsn: u64, + timestamp: i64, + xid: u32, +} + +impl BeginBody { + #[inline] + /// Gets the final lsn of the transaction + pub fn final_lsn(&self) -> Lsn { + self.final_lsn + } + + #[inline] + /// Commit timestamp of the transaction. The value is in number of microseconds since PostgreSQL epoch (2000-01-01). 
+ pub fn timestamp(&self) -> i64 { + self.timestamp + } + + #[inline] + /// Xid of the transaction. + pub fn xid(&self) -> u32 { + self.xid + } +} + +/// A COMMIT statement +#[derive(Debug)] +pub struct CommitBody { + flags: i8, + commit_lsn: u64, + end_lsn: u64, + timestamp: i64, +} + +impl CommitBody { + #[inline] + /// The LSN of the commit. + pub fn commit_lsn(&self) -> Lsn { + self.commit_lsn + } + + #[inline] + /// The end LSN of the transaction. + pub fn end_lsn(&self) -> Lsn { + self.end_lsn + } + + #[inline] + /// Commit timestamp of the transaction. The value is in number of microseconds since PostgreSQL epoch (2000-01-01). + pub fn timestamp(&self) -> i64 { + self.timestamp + } + + #[inline] + /// Flags; currently unused (will be 0). + pub fn flags(&self) -> i8 { + self.flags + } +} + +/// An Origin replication message +/// +/// Note that there can be multiple Origin messages inside a single transaction. +#[derive(Debug)] +pub struct OriginBody { + commit_lsn: u64, + name: Bytes, +} + +impl OriginBody { + #[inline] + /// The LSN of the commit on the origin server. + pub fn commit_lsn(&self) -> Lsn { + self.commit_lsn + } + + #[inline] + /// Name of the origin. + pub fn name(&self) -> io::Result<&str> { + get_str(&self.name) + } +} + +/// Describes the REPLICA IDENTITY setting of a table +#[derive(Debug)] +pub enum ReplicaIdentity { + /// default selection for replica identity (primary key or nothing) + Default, + /// no replica identity is logged for this relation + Nothing, + /// all columns are logged as replica identity + Full, + /// An explicitly chosen candidate key's columns are used as replica identity. + /// Note this will still be set if the index has been dropped; in that case it + /// has the same meaning as 'd'. 
+ Index, +} + +/// A Relation replication message +#[derive(Debug)] +pub struct RelationBody { + rel_id: u32, + namespace: Bytes, + name: Bytes, + replica_identity: ReplicaIdentity, + columns: Vec, +} + +impl RelationBody { + #[inline] + /// ID of the relation. + pub fn rel_id(&self) -> u32 { + self.rel_id + } + + #[inline] + /// Namespace (empty string for pg_catalog). + pub fn namespace(&self) -> io::Result<&str> { + get_str(&self.namespace) + } + + #[inline] + /// Relation name. + pub fn name(&self) -> io::Result<&str> { + get_str(&self.name) + } + + #[inline] + /// Replica identity setting for the relation + pub fn replica_identity(&self) -> &ReplicaIdentity { + &self.replica_identity + } + + #[inline] + /// The column definitions of this relation + pub fn columns(&self) -> &[Column] { + &self.columns + } +} + +/// A Type replication message +#[derive(Debug)] +pub struct TypeBody { + id: u32, + namespace: Bytes, + name: Bytes, +} + +impl TypeBody { + #[inline] + /// ID of the data type. + pub fn id(&self) -> Oid { + self.id + } + + #[inline] + /// Namespace (empty string for pg_catalog). + pub fn namespace(&self) -> io::Result<&str> { + get_str(&self.namespace) + } + + #[inline] + /// Name of the data type. + pub fn name(&self) -> io::Result<&str> { + get_str(&self.name) + } +} + +/// An INSERT statement +#[derive(Debug)] +pub struct InsertBody { + rel_id: u32, + tuple: Tuple, +} + +impl InsertBody { + #[inline] + /// ID of the relation corresponding to the ID in the relation message. + pub fn rel_id(&self) -> u32 { + self.rel_id + } + + #[inline] + /// The inserted tuple + pub fn tuple(&self) -> &Tuple { + &self.tuple + } +} + +/// An UPDATE statement +#[derive(Debug)] +pub struct UpdateBody { + rel_id: u32, + old_tuple: Option, + key_tuple: Option, + new_tuple: Tuple, +} + +impl UpdateBody { + #[inline] + /// ID of the relation corresponding to the ID in the relation message. 
+ pub fn rel_id(&self) -> u32 { + self.rel_id + } + + #[inline] + /// This field is optional and is only present if the update changed data in any of the + /// column(s) that are part of the REPLICA IDENTITY index. + pub fn key_tuple(&self) -> Option<&Tuple> { + self.key_tuple.as_ref() + } + + #[inline] + /// This field is optional and is only present if table in which the update happened has + /// REPLICA IDENTITY set to FULL. + pub fn old_tuple(&self) -> Option<&Tuple> { + self.old_tuple.as_ref() + } + + #[inline] + /// The new tuple + pub fn new_tuple(&self) -> &Tuple { + &self.new_tuple + } +} + +/// A DELETE statement +#[derive(Debug)] +pub struct DeleteBody { + rel_id: u32, + old_tuple: Option, + key_tuple: Option, +} + +impl DeleteBody { + #[inline] + /// ID of the relation corresponding to the ID in the relation message. + pub fn rel_id(&self) -> u32 { + self.rel_id + } + + #[inline] + /// This field is present if the table in which the delete has happened uses an index as + /// REPLICA IDENTITY. + pub fn key_tuple(&self) -> Option<&Tuple> { + self.key_tuple.as_ref() + } + + #[inline] + /// This field is present if the table in which the delete has happened has REPLICA IDENTITY + /// set to FULL. 
+ pub fn old_tuple(&self) -> Option<&Tuple> { + self.old_tuple.as_ref() + } +} + +/// A TRUNCATE statement +#[derive(Debug)] +pub struct TruncateBody { + options: i8, + rel_ids: Vec, +} + +impl TruncateBody { + #[inline] + /// The IDs of the relations corresponding to the ID in the relation messages + pub fn rel_ids(&self) -> &[u32] { + &self.rel_ids + } + + #[inline] + /// Option bits for TRUNCATE: 1 for CASCADE, 2 for RESTART IDENTITY + pub fn options(&self) -> i8 { + self.options + } +} + pub struct Fields<'a> { buf: &'a [u8], remaining: u16, diff --git a/tokio-postgres/src/client.rs b/tokio-postgres/src/client.rs index 427a05049..6b7067ee8 100644 --- a/tokio-postgres/src/client.rs +++ b/tokio-postgres/src/client.rs @@ -1,6 +1,7 @@ use crate::codec::{BackendMessages, FrontendMessage}; use crate::config::SslMode; use crate::connection::{Request, RequestMessages}; +use crate::copy_both::CopyBothDuplex; use crate::copy_out::CopyOutStream; #[cfg(feature = "runtime")] use crate::keepalive::KeepaliveConfig; @@ -13,8 +14,9 @@ use crate::types::{Oid, ToSql, Type}; #[cfg(feature = "runtime")] use crate::Socket; use crate::{ - copy_in, copy_out, prepare, query, simple_query, slice_iter, CancelToken, CopyInSink, Error, - Row, SimpleQueryMessage, Statement, ToStatement, Transaction, TransactionBuilder, + copy_both, copy_in, copy_out, prepare, query, simple_query, slice_iter, CancelToken, + CopyInSink, Error, Row, SimpleQueryMessage, Statement, ToStatement, Transaction, + TransactionBuilder, }; use bytes::{Buf, BytesMut}; use fallible_iterator::FallibleIterator; @@ -423,6 +425,14 @@ impl Client { copy_in::copy_in(self.inner(), statement).await } + /// Executes a `COPY FROM STDIN` query, returning a sink used to write the copy data. 
+ pub async fn copy_in_simple(&self, query: &str) -> Result, Error> + where + U: Buf + 'static + Send, + { + copy_in::copy_in_simple(self.inner(), query).await + } + /// Executes a `COPY TO STDOUT` statement, returning a stream of the resulting data. /// /// PostgreSQL does not support parameters in `COPY` statements, so this method does not take any. @@ -434,6 +444,20 @@ impl Client { copy_out::copy_out(self.inner(), statement).await } + /// Executes a `COPY TO STDOUT` query, returning a stream of the resulting data. + pub async fn copy_out_simple(&self, query: &str) -> Result { + copy_out::copy_out_simple(self.inner(), query).await + } + + /// Executes a CopyBoth query, returning a combined Stream+Sink type to read and write copy + /// data. + pub async fn copy_both_simple(&self, query: &str) -> Result, Error> + where + T: Buf + 'static + Send, + { + copy_both::copy_both_simple(self.inner(), query).await + } + /// Executes a sequence of SQL statements using the simple query protocol, returning the resulting rows. /// /// Statements should be separated by semicolons. If an error occurs, execution of the sequence will stop at that diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs index b178eac80..0f2d2e748 100644 --- a/tokio-postgres/src/config.rs +++ b/tokio-postgres/src/config.rs @@ -70,6 +70,16 @@ pub enum LoadBalanceHosts { Random, } +/// Replication mode configuration. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[non_exhaustive] +pub enum ReplicationMode { + /// Physical replication. + Physical, + /// Logical replication. + Logical, +} + /// A host specification. 
#[derive(Debug, Clone, PartialEq, Eq)] pub enum Host { @@ -207,6 +217,7 @@ pub struct Config { pub(crate) target_session_attrs: TargetSessionAttrs, pub(crate) channel_binding: ChannelBinding, pub(crate) load_balance_hosts: LoadBalanceHosts, + pub(crate) replication_mode: Option, } impl Default for Config { @@ -240,6 +251,7 @@ impl Config { target_session_attrs: TargetSessionAttrs::Any, channel_binding: ChannelBinding::Prefer, load_balance_hosts: LoadBalanceHosts::Disable, + replication_mode: None, } } @@ -520,6 +532,17 @@ impl Config { self.load_balance_hosts } + /// Set replication mode. + pub fn replication_mode(&mut self, replication_mode: ReplicationMode) -> &mut Config { + self.replication_mode = Some(replication_mode); + self + } + + /// Get replication mode. + pub fn get_replication_mode(&self) -> Option { + self.replication_mode + } + fn param(&mut self, key: &str, value: &str) -> Result<(), Error> { match key { "user" => { @@ -655,6 +678,17 @@ impl Config { }; self.load_balance_hosts(load_balance_hosts); } + "replication" => { + let mode = match value { + "off" => None, + "true" => Some(ReplicationMode::Physical), + "database" => Some(ReplicationMode::Logical), + _ => return Err(Error::config_parse(Box::new(InvalidValue("replication")))), + }; + if let Some(mode) = mode { + self.replication_mode(mode); + } + } key => { return Err(Error::config_parse(Box::new(UnknownOption( key.to_string(), @@ -739,6 +773,7 @@ impl fmt::Debug for Config { config_dbg .field("target_session_attrs", &self.target_session_attrs) .field("channel_binding", &self.channel_binding) + .field("replication", &self.replication_mode) .finish() } } diff --git a/tokio-postgres/src/connect_raw.rs b/tokio-postgres/src/connect_raw.rs index 19be9eb01..8edf45937 100644 --- a/tokio-postgres/src/connect_raw.rs +++ b/tokio-postgres/src/connect_raw.rs @@ -1,5 +1,5 @@ use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec}; -use crate::config::{self, Config}; +use 
crate::config::{self, Config, ReplicationMode}; use crate::connect_tls::connect_tls; use crate::maybe_tls_stream::MaybeTlsStream; use crate::tls::{TlsConnect, TlsStream}; @@ -133,6 +133,12 @@ where if let Some(application_name) = &config.application_name { params.push(("application_name", &**application_name)); } + if let Some(replication_mode) = &config.replication_mode { + match replication_mode { + ReplicationMode::Physical => params.push(("replication", "true")), + ReplicationMode::Logical => params.push(("replication", "database")), + } + } let mut buf = BytesMut::new(); frontend::startup_message(params, &mut buf).map_err(Error::encode)?; diff --git a/tokio-postgres/src/connection.rs b/tokio-postgres/src/connection.rs index 414335955..a3449f88b 100644 --- a/tokio-postgres/src/connection.rs +++ b/tokio-postgres/src/connection.rs @@ -1,4 +1,5 @@ use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec}; +use crate::copy_both::CopyBothReceiver; use crate::copy_in::CopyInReceiver; use crate::error::DbError; use crate::maybe_tls_stream::MaybeTlsStream; @@ -20,6 +21,7 @@ use tokio_util::codec::Framed; pub enum RequestMessages { Single(FrontendMessage), CopyIn(CopyInReceiver), + CopyBoth(CopyBothReceiver), } pub struct Request { @@ -258,6 +260,24 @@ where .map_err(Error::io)?; self.pending_request = Some(RequestMessages::CopyIn(receiver)); } + RequestMessages::CopyBoth(mut receiver) => { + let message = match receiver.poll_next_unpin(cx) { + Poll::Ready(Some(message)) => message, + Poll::Ready(None) => { + trace!("poll_write: finished copy_both request"); + continue; + } + Poll::Pending => { + trace!("poll_write: waiting on copy_both stream"); + self.pending_request = Some(RequestMessages::CopyBoth(receiver)); + return Ok(true); + } + }; + Pin::new(&mut self.stream) + .start_send(message) + .map_err(Error::io)?; + self.pending_request = Some(RequestMessages::CopyBoth(receiver)); + } } } } diff --git a/tokio-postgres/src/copy_both.rs 
b/tokio-postgres/src/copy_both.rs new file mode 100644 index 000000000..79a7be34a --- /dev/null +++ b/tokio-postgres/src/copy_both.rs @@ -0,0 +1,248 @@ +use crate::client::{InnerClient, Responses}; +use crate::codec::FrontendMessage; +use crate::connection::RequestMessages; +use crate::{simple_query, Error}; +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use futures_channel::mpsc; +use futures_util::{future, ready, Sink, SinkExt, Stream, StreamExt}; +use log::debug; +use pin_project_lite::pin_project; +use postgres_protocol::message::backend::Message; +use postgres_protocol::message::frontend; +use postgres_protocol::message::frontend::CopyData; +use std::marker::{PhantomData, PhantomPinned}; +use std::pin::Pin; +use std::task::{Context, Poll}; + +pub(crate) enum CopyBothMessage { + Message(FrontendMessage), + Done, +} + +pub struct CopyBothReceiver { + receiver: mpsc::Receiver, + done: bool, +} + +impl CopyBothReceiver { + pub(crate) fn new(receiver: mpsc::Receiver) -> CopyBothReceiver { + CopyBothReceiver { + receiver, + done: false, + } + } +} + +impl Stream for CopyBothReceiver { + type Item = FrontendMessage; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if self.done { + return Poll::Ready(None); + } + + match ready!(self.receiver.poll_next_unpin(cx)) { + Some(CopyBothMessage::Message(message)) => Poll::Ready(Some(message)), + Some(CopyBothMessage::Done) => { + self.done = true; + let mut buf = BytesMut::new(); + frontend::copy_done(&mut buf); + frontend::sync(&mut buf); + Poll::Ready(Some(FrontendMessage::Raw(buf.freeze()))) + } + None => { + self.done = true; + let mut buf = BytesMut::new(); + frontend::copy_fail("", &mut buf).unwrap(); + frontend::sync(&mut buf); + Poll::Ready(Some(FrontendMessage::Raw(buf.freeze()))) + } + } + } +} + +enum SinkState { + Active, + Closing, + Reading, +} + +pin_project! { + /// A sink for `COPY ... FROM STDIN` query data. 
+ /// + /// The copy *must* be explicitly completed via the `Sink::close` or `finish` methods. If it is + /// not, the copy will be aborted. + pub struct CopyBothDuplex { + #[pin] + sender: mpsc::Sender, + responses: Responses, + buf: BytesMut, + state: SinkState, + #[pin] + _p: PhantomPinned, + _p2: PhantomData, + } +} + +impl CopyBothDuplex +where + T: Buf + 'static + Send, +{ + pub(crate) fn new(sender: mpsc::Sender, responses: Responses) -> Self { + Self { + sender, + responses, + buf: BytesMut::new(), + state: SinkState::Active, + _p: PhantomPinned, + _p2: PhantomData, + } + } + + /// A poll-based version of `finish`. + pub fn poll_finish(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + loop { + match self.state { + SinkState::Active => { + ready!(self.as_mut().poll_flush(cx))?; + let mut this = self.as_mut().project(); + ready!(this.sender.as_mut().poll_ready(cx)).map_err(|_| Error::closed())?; + this.sender + .start_send(CopyBothMessage::Done) + .map_err(|_| Error::closed())?; + *this.state = SinkState::Closing; + } + SinkState::Closing => { + let this = self.as_mut().project(); + ready!(this.sender.poll_close(cx)).map_err(|_| Error::closed())?; + *this.state = SinkState::Reading; + } + SinkState::Reading => { + let this = self.as_mut().project(); + match ready!(this.responses.poll_next(cx))? { + Message::CommandComplete(body) => { + let rows = body + .tag() + .map_err(Error::parse)? + .rsplit(' ') + .next() + .unwrap() + .parse() + .unwrap_or(0); + return Poll::Ready(Ok(rows)); + } + _ => return Poll::Ready(Err(Error::unexpected_message())), + } + } + } + } + } + + /// Completes the copy, returning the number of rows inserted. + /// + /// The `Sink::close` method is equivalent to `finish`, except that it does not return the + /// number of rows. 
+ pub async fn finish(mut self: Pin<&mut Self>) -> Result { + future::poll_fn(|cx| self.as_mut().poll_finish(cx)).await + } +} + +impl Stream for CopyBothDuplex { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + + match ready!(this.responses.poll_next(cx)?) { + Message::CopyData(body) => Poll::Ready(Some(Ok(body.into_bytes()))), + Message::CopyDone => Poll::Ready(None), + _ => Poll::Ready(Some(Err(Error::unexpected_message()))), + } + } +} + +impl Sink for CopyBothDuplex +where + T: Buf + 'static + Send, +{ + type Error = Error; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project() + .sender + .poll_ready(cx) + .map_err(|_| Error::closed()) + } + + fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Error> { + let this = self.project(); + + let data: Box = if item.remaining() > 4096 { + if this.buf.is_empty() { + Box::new(item) + } else { + Box::new(this.buf.split().freeze().chain(item)) + } + } else { + this.buf.put(item); + if this.buf.len() > 4096 { + Box::new(this.buf.split().freeze()) + } else { + return Ok(()); + } + }; + + let data = CopyData::new(data).map_err(Error::encode)?; + this.sender + .start_send(CopyBothMessage::Message(FrontendMessage::CopyData(data))) + .map_err(|_| Error::closed()) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + + if !this.buf.is_empty() { + ready!(this.sender.as_mut().poll_ready(cx)).map_err(|_| Error::closed())?; + let data: Box = Box::new(this.buf.split().freeze()); + let data = CopyData::new(data).map_err(Error::encode)?; + this.sender + .as_mut() + .start_send(CopyBothMessage::Message(FrontendMessage::CopyData(data))) + .map_err(|_| Error::closed())?; + } + + this.sender.poll_flush(cx).map_err(|_| Error::closed()) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.poll_finish(cx).map_ok(|_| ()) + } +} + +pub 
async fn copy_both_simple( + client: &InnerClient, + query: &str, +) -> Result, Error> +where + T: Buf + 'static + Send, +{ + debug!("executing copy both query {}", query); + + let buf = simple_query::encode(client, query)?; + + let (mut sender, receiver) = mpsc::channel(1); + let receiver = CopyBothReceiver::new(receiver); + let mut responses = client.send(RequestMessages::CopyBoth(receiver))?; + + sender + .send(CopyBothMessage::Message(FrontendMessage::Raw(buf))) + .await + .map_err(|_| Error::closed())?; + + match responses.next().await? { + Message::CopyBothResponse(_) => {} + _ => return Err(Error::unexpected_message()), + } + + Ok(CopyBothDuplex::new(sender, responses)) +} diff --git a/tokio-postgres/src/copy_in.rs b/tokio-postgres/src/copy_in.rs index 59e31fea6..b3fdba84a 100644 --- a/tokio-postgres/src/copy_in.rs +++ b/tokio-postgres/src/copy_in.rs @@ -2,8 +2,8 @@ use crate::client::{InnerClient, Responses}; use crate::codec::FrontendMessage; use crate::connection::RequestMessages; use crate::query::extract_row_affected; -use crate::{query, slice_iter, Error, Statement}; -use bytes::{Buf, BufMut, BytesMut}; +use crate::{query, simple_query, slice_iter, Error, Statement}; +use bytes::{Buf, BufMut, Bytes, BytesMut}; use futures_channel::mpsc; use futures_util::{future, ready, Sink, SinkExt, Stream, StreamExt}; use log::debug; @@ -188,14 +188,10 @@ where } } -pub async fn copy_in(client: &InnerClient, statement: Statement) -> Result, Error> +async fn start(client: &InnerClient, buf: Bytes, simple: bool) -> Result, Error> where T: Buf + 'static + Send, { - debug!("executing copy in statement {}", statement.name()); - - let buf = query::encode(client, &statement, slice_iter(&[]))?; - let (mut sender, receiver) = mpsc::channel(1); let receiver = CopyInReceiver::new(receiver); let mut responses = client.send(RequestMessages::CopyIn(receiver))?; @@ -205,9 +201,11 @@ where .await .map_err(|_| Error::closed())?; - match responses.next().await? 
{ - Message::BindComplete => {} - _ => return Err(Error::unexpected_message()), + if !simple { + match responses.next().await? { + Message::BindComplete => {} + _ => return Err(Error::unexpected_message()), + } } match responses.next().await? { @@ -224,3 +222,23 @@ where _p2: PhantomData, }) } + +pub async fn copy_in(client: &InnerClient, statement: Statement) -> Result, Error> +where + T: Buf + 'static + Send, +{ + debug!("executing copy in statement {}", statement.name()); + + let buf = query::encode(client, &statement, slice_iter(&[]))?; + start(client, buf, false).await +} + +pub async fn copy_in_simple(client: &InnerClient, query: &str) -> Result, Error> +where + T: Buf + 'static + Send, +{ + debug!("executing copy in query {}", query); + + let buf = simple_query::encode(client, query)?; + start(client, buf, true).await +} diff --git a/tokio-postgres/src/copy_out.rs b/tokio-postgres/src/copy_out.rs index 1e6949252..981f9365e 100644 --- a/tokio-postgres/src/copy_out.rs +++ b/tokio-postgres/src/copy_out.rs @@ -1,7 +1,7 @@ use crate::client::{InnerClient, Responses}; use crate::codec::FrontendMessage; use crate::connection::RequestMessages; -use crate::{query, slice_iter, Error, Statement}; +use crate::{query, simple_query, slice_iter, Error, Statement}; use bytes::Bytes; use futures_util::{ready, Stream}; use log::debug; @@ -11,23 +11,36 @@ use std::marker::PhantomPinned; use std::pin::Pin; use std::task::{Context, Poll}; +pub async fn copy_out_simple(client: &InnerClient, query: &str) -> Result { + debug!("executing copy out query {}", query); + + let buf = simple_query::encode(client, query)?; + let responses = start(client, buf, true).await?; + Ok(CopyOutStream { + responses, + _p: PhantomPinned, + }) +} + pub async fn copy_out(client: &InnerClient, statement: Statement) -> Result { debug!("executing copy out statement {}", statement.name()); let buf = query::encode(client, &statement, slice_iter(&[]))?; - let responses = start(client, buf).await?; + let 
responses = start(client, buf, false).await?; Ok(CopyOutStream { responses, _p: PhantomPinned, }) } -async fn start(client: &InnerClient, buf: Bytes) -> Result { +async fn start(client: &InnerClient, buf: Bytes, simple: bool) -> Result { let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; - match responses.next().await? { - Message::BindComplete => {} - _ => return Err(Error::unexpected_message()), + if !simple { + match responses.next().await? { + Message::BindComplete => {} + _ => return Err(Error::unexpected_message()), + } } match responses.next().await? { diff --git a/tokio-postgres/src/lib.rs b/tokio-postgres/src/lib.rs index 2973d33b0..c221454a6 100644 --- a/tokio-postgres/src/lib.rs +++ b/tokio-postgres/src/lib.rs @@ -158,6 +158,7 @@ mod connect_raw; mod connect_socket; mod connect_tls; mod connection; +mod copy_both; mod copy_in; mod copy_out; pub mod error; @@ -168,6 +169,7 @@ mod maybe_tls_stream; mod portal; mod prepare; mod query; +pub mod replication; pub mod row; mod simple_query; #[cfg(feature = "runtime")] diff --git a/tokio-postgres/src/replication.rs b/tokio-postgres/src/replication.rs new file mode 100644 index 000000000..7e67de0d6 --- /dev/null +++ b/tokio-postgres/src/replication.rs @@ -0,0 +1,184 @@ +//! Utilities for working with the PostgreSQL replication copy both format. + +use crate::copy_both::CopyBothDuplex; +use crate::Error; +use bytes::{BufMut, Bytes, BytesMut}; +use futures_util::{ready, SinkExt, Stream}; +use pin_project_lite::pin_project; +use postgres_protocol::message::backend::{LogicalReplicationMessage, ReplicationMessage}; +use postgres_protocol::PG_EPOCH; +use postgres_types::PgLsn; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::time::SystemTime; + +const STANDBY_STATUS_UPDATE_TAG: u8 = b'r'; +const HOT_STANDBY_FEEDBACK_TAG: u8 = b'h'; + +pin_project! { + /// A type which deserializes the postgres replication protocol. 
This type can be used with + /// both physical and logical replication to get access to the byte content of each replication + /// message. + /// + /// The replication *must* be explicitly completed via the `finish` method. + pub struct ReplicationStream { + #[pin] + stream: CopyBothDuplex, + } +} + +impl ReplicationStream { + /// Creates a new ReplicationStream that will wrap the underlying CopyBoth stream + pub fn new(stream: CopyBothDuplex) -> Self { + Self { stream } + } + + /// Send standby update to server. + pub async fn standby_status_update( + self: Pin<&mut Self>, + write_lsn: PgLsn, + flush_lsn: PgLsn, + apply_lsn: PgLsn, + timestamp: SystemTime, + reply: u8, + ) -> Result<(), Error> { + let mut this = self.project(); + + let timestamp = match timestamp.duration_since(*PG_EPOCH) { + Ok(d) => d.as_micros() as i64, + Err(e) => -(e.duration().as_micros() as i64), + }; + + let mut buf = BytesMut::new(); + buf.put_u8(STANDBY_STATUS_UPDATE_TAG); + buf.put_u64(write_lsn.into()); + buf.put_u64(flush_lsn.into()); + buf.put_u64(apply_lsn.into()); + buf.put_i64(timestamp); + buf.put_u8(reply); + + this.stream.send(buf.freeze()).await + } + + /// Send hot standby feedback message to server. 
+ pub async fn hot_standby_feedback( + self: Pin<&mut Self>, + timestamp: SystemTime, + global_xmin: u32, + global_xmin_epoch: u32, + catalog_xmin: u32, + catalog_xmin_epoch: u32, + ) -> Result<(), Error> { + let mut this = self.project(); + + let timestamp = match timestamp.duration_since(*PG_EPOCH) { + Ok(d) => d.as_micros() as i64, + Err(e) => -(e.duration().as_micros() as i64), + }; + + let mut buf = BytesMut::new(); + buf.put_u8(HOT_STANDBY_FEEDBACK_TAG); + buf.put_i64(timestamp); + buf.put_u32(global_xmin); + buf.put_u32(global_xmin_epoch); + buf.put_u32(catalog_xmin); + buf.put_u32(catalog_xmin_epoch); + + this.stream.send(buf.freeze()).await + } +} + +impl Stream for ReplicationStream { + type Item = Result, Error>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + + match ready!(this.stream.poll_next(cx)) { + Some(Ok(buf)) => { + Poll::Ready(Some(ReplicationMessage::parse(&buf).map_err(Error::parse))) + } + Some(Err(err)) => Poll::Ready(Some(Err(err))), + None => Poll::Ready(None), + } + } +} + +pin_project! { + /// A type which deserializes the postgres logical replication protocol. This type gives access + /// to a high level representation of the changes in transaction commit order. + /// + /// The replication *must* be explicitly completed via the `finish` method. + pub struct LogicalReplicationStream { + #[pin] + stream: ReplicationStream, + } +} + +impl LogicalReplicationStream { + /// Creates a new LogicalReplicationStream that will wrap the underlying CopyBoth stream + pub fn new(stream: CopyBothDuplex) -> Self { + Self { + stream: ReplicationStream::new(stream), + } + } + + /// Send standby update to server. 
+ pub async fn standby_status_update( + self: Pin<&mut Self>, + write_lsn: PgLsn, + flush_lsn: PgLsn, + apply_lsn: PgLsn, + timestamp: SystemTime, + reply: u8, + ) -> Result<(), Error> { + let this = self.project(); + this.stream + .standby_status_update(write_lsn, flush_lsn, apply_lsn, timestamp, reply) + .await + } + + /// Send hot standby feedback message to server. + pub async fn hot_standby_feedback( + self: Pin<&mut Self>, + timestamp: SystemTime, + global_xmin: u32, + global_xmin_epoch: u32, + catalog_xmin: u32, + catalog_xmin_epoch: u32, + ) -> Result<(), Error> { + let this = self.project(); + this.stream + .hot_standby_feedback( + timestamp, + global_xmin, + global_xmin_epoch, + catalog_xmin, + catalog_xmin_epoch, + ) + .await + } +} + +impl Stream for LogicalReplicationStream { + type Item = Result, Error>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + + match ready!(this.stream.poll_next(cx)) { + Some(Ok(ReplicationMessage::XLogData(body))) => { + let body = body + .map_data(|buf| LogicalReplicationMessage::parse(&buf)) + .map_err(Error::parse)?; + Poll::Ready(Some(Ok(ReplicationMessage::XLogData(body)))) + } + Some(Ok(ReplicationMessage::PrimaryKeepAlive(body))) => { + Poll::Ready(Some(Ok(ReplicationMessage::PrimaryKeepAlive(body)))) + } + Some(Ok(_)) => Poll::Ready(Some(Err(Error::unexpected_message()))), + Some(Err(err)) => Poll::Ready(Some(Err(err))), + None => Poll::Ready(None), + } + } +} diff --git a/tokio-postgres/src/simple_query.rs b/tokio-postgres/src/simple_query.rs index bcc6d928b..a97ee126c 100644 --- a/tokio-postgres/src/simple_query.rs +++ b/tokio-postgres/src/simple_query.rs @@ -63,7 +63,7 @@ pub async fn batch_execute(client: &InnerClient, query: &str) -> Result<(), Erro } } -fn encode(client: &InnerClient, query: &str) -> Result { +pub(crate) fn encode(client: &InnerClient, query: &str) -> Result { client.with_buf(|buf| { frontend::query(query, buf).map_err(Error::encode)?; 
Ok(buf.split().freeze()) diff --git a/tokio-postgres/tests/test/main.rs b/tokio-postgres/tests/test/main.rs index 0ab4a7bab..8de2b75a2 100644 --- a/tokio-postgres/tests/test/main.rs +++ b/tokio-postgres/tests/test/main.rs @@ -22,6 +22,8 @@ use tokio_postgres::{ mod binary_copy; mod parse; #[cfg(feature = "runtime")] +mod replication; +#[cfg(feature = "runtime")] mod runtime; mod types; diff --git a/tokio-postgres/tests/test/replication.rs b/tokio-postgres/tests/test/replication.rs new file mode 100644 index 000000000..c176a4104 --- /dev/null +++ b/tokio-postgres/tests/test/replication.rs @@ -0,0 +1,146 @@ +use futures_util::StreamExt; +use std::time::SystemTime; + +use postgres_protocol::message::backend::LogicalReplicationMessage::{Begin, Commit, Insert}; +use postgres_protocol::message::backend::ReplicationMessage::*; +use postgres_protocol::message::backend::TupleData; +use postgres_types::PgLsn; +use tokio_postgres::replication::LogicalReplicationStream; +use tokio_postgres::NoTls; +use tokio_postgres::SimpleQueryMessage::Row; + +#[tokio::test] +async fn test_replication() { + // form SQL connection + let conninfo = "host=127.0.0.1 port=5433 user=postgres replication=database"; + let (client, connection) = tokio_postgres::connect(conninfo, NoTls).await.unwrap(); + tokio::spawn(async move { + if let Err(e) = connection.await { + eprintln!("connection error: {}", e); + } + }); + + client + .simple_query("DROP TABLE IF EXISTS test_logical_replication") + .await + .unwrap(); + client + .simple_query("CREATE TABLE test_logical_replication(i int)") + .await + .unwrap(); + let res = client + .simple_query("SELECT 'test_logical_replication'::regclass::oid") + .await + .unwrap(); + let rel_id: u32 = if let Row(row) = &res[0] { + row.get("oid").unwrap().parse().unwrap() + } else { + panic!("unexpeced query message"); + }; + + client + .simple_query("DROP PUBLICATION IF EXISTS test_pub") + .await + .unwrap(); + client + .simple_query("CREATE PUBLICATION test_pub FOR ALL 
TABLES") + .await + .unwrap(); + + let slot = "test_logical_slot"; + + let query = format!( + r#"CREATE_REPLICATION_SLOT {:?} TEMPORARY LOGICAL "pgoutput""#, + slot + ); + let slot_query = client.simple_query(&query).await.unwrap(); + let lsn = if let Row(row) = &slot_query[0] { + row.get("consistent_point").unwrap() + } else { + panic!("unexpeced query message"); + }; + + // issue a query that will appear in the slot's stream since it happened after its creation + client + .simple_query("INSERT INTO test_logical_replication VALUES (42)") + .await + .unwrap(); + + let options = r#"("proto_version" '1', "publication_names" 'test_pub')"#; + let query = format!( + r#"START_REPLICATION SLOT {:?} LOGICAL {} {}"#, + slot, lsn, options + ); + let copy_stream = client + .copy_both_simple::(&query) + .await + .unwrap(); + + let stream = LogicalReplicationStream::new(copy_stream); + tokio::pin!(stream); + + // verify that we can observe the transaction in the replication stream + let begin = loop { + match stream.next().await { + Some(Ok(XLogData(body))) => { + if let Begin(begin) = body.into_data() { + break begin; + } + } + Some(Ok(_)) => (), + Some(Err(_)) => panic!("unexpected replication stream error"), + None => panic!("unexpected replication stream end"), + } + }; + let insert = loop { + match stream.next().await { + Some(Ok(XLogData(body))) => { + if let Insert(insert) = body.into_data() { + break insert; + } + } + Some(Ok(_)) => (), + Some(Err(_)) => panic!("unexpected replication stream error"), + None => panic!("unexpected replication stream end"), + } + }; + + let commit = loop { + match stream.next().await { + Some(Ok(XLogData(body))) => { + if let Commit(commit) = body.into_data() { + break commit; + } + } + Some(Ok(_)) => (), + Some(Err(_)) => panic!("unexpected replication stream error"), + None => panic!("unexpected replication stream end"), + } + }; + + assert_eq!(begin.final_lsn(), commit.commit_lsn()); + assert_eq!(insert.rel_id(), rel_id); + + let 
tuple_data = insert.tuple().tuple_data(); + assert_eq!(tuple_data.len(), 1); + assert!(matches!(tuple_data[0], TupleData::Text(_))); + if let TupleData::Text(data) = &tuple_data[0] { + assert_eq!(data, &b"42"[..]); + } + + // Send a standby status update and require a keep alive response + let lsn: PgLsn = lsn.parse().unwrap(); + stream + .as_mut() + .standby_status_update(lsn, lsn, lsn, SystemTime::now(), 1) + .await + .unwrap(); + loop { + match stream.next().await { + Some(Ok(PrimaryKeepAlive(_))) => break, + Some(Ok(_)) => (), + Some(Err(e)) => panic!("unexpected replication stream error: {}", e), + None => panic!("unexpected replication stream end"), + } + } +} From 19a7726924f5beb34218343aa57dae41c8630e2e Mon Sep 17 00:00:00 2001 From: anastasia Date: Mon, 20 Dec 2021 22:11:37 +0300 Subject: [PATCH 03/22] Extend replication protocol with ZenithStatusUpdate message --- tokio-postgres/src/replication.rs | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tokio-postgres/src/replication.rs b/tokio-postgres/src/replication.rs index 7e67de0d6..e7c845958 100644 --- a/tokio-postgres/src/replication.rs +++ b/tokio-postgres/src/replication.rs @@ -14,6 +14,7 @@ use std::time::SystemTime; const STANDBY_STATUS_UPDATE_TAG: u8 = b'r'; const HOT_STANDBY_FEEDBACK_TAG: u8 = b'h'; +const ZENITH_STATUS_UPDATE_TAG_BYTE: u8 = b'z'; pin_project! { /// A type which deserializes the postgres replication protocol. This type can be used with @@ -33,6 +34,22 @@ impl ReplicationStream { Self { stream } } + /// Send zenith status update to server. + pub async fn zenith_status_update( + self: Pin<&mut Self>, + len: u64, + data: &[u8], + ) -> Result<(), Error> { + let mut this = self.project(); + + let mut buf = BytesMut::new(); + buf.put_u8(ZENITH_STATUS_UPDATE_TAG_BYTE); + buf.put_u64(len); + buf.put_slice(data); + + this.stream.send(buf.freeze()).await + } + /// Send standby update to server. 
pub async fn standby_status_update( self: Pin<&mut Self>, From 9a437283e164a9f1edcb12cac2e28f5157ec10a3 Mon Sep 17 00:00:00 2001 From: Dmitry Ivanov Date: Tue, 19 Apr 2022 16:39:05 +0300 Subject: [PATCH 04/22] Allow passing precomputed SCRAM keys via Config According to https://datatracker.ietf.org/doc/html/rfc5802#section-3, SCRAM protocol explicitly allows client to use a `ClientKey` & `ServerKey` pair instead of a password to perform authentication. This is also useful for proxy implementations which would like to leverage `rust-postgres`. This patch adds the ability to do that. --- postgres-protocol/src/authentication/sasl.rs | 110 +++++++++++++------ postgres/src/config.rs | 17 ++- tokio-postgres/src/config.rs | 25 +++++ tokio-postgres/src/connect_raw.rs | 15 +-- 4 files changed, 123 insertions(+), 44 deletions(-) diff --git a/postgres-protocol/src/authentication/sasl.rs b/postgres-protocol/src/authentication/sasl.rs index 4a77507e9..ac18ed8a8 100644 --- a/postgres-protocol/src/authentication/sasl.rs +++ b/postgres-protocol/src/authentication/sasl.rs @@ -99,14 +99,32 @@ impl ChannelBinding { } } +/// A pair of keys for the SCRAM-SHA-256 mechanism. +/// See for details. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct ScramKeys { + /// Used by server to authenticate client. + pub client_key: [u8; N], + /// Used by client to verify server's signature. + pub server_key: [u8; N], +} + +/// Password or keys which were derived from it. +enum Credentials { + /// A regular password as a vector of bytes. + Password(Vec), + /// A precomputed pair of keys. 
+ Keys(Box>), +} + enum State { Update { nonce: String, - password: Vec, + password: Credentials<32>, channel_binding: ChannelBinding, }, Finish { - salted_password: [u8; 32], + server_key: [u8; 32], auth_message: String, }, Done, @@ -132,30 +150,43 @@ pub struct ScramSha256 { state: State, } +fn nonce() -> String { + // rand 0.5's ThreadRng is cryptographically secure + let mut rng = rand::thread_rng(); + (0..NONCE_LENGTH) + .map(|_| { + let mut v = rng.gen_range(0x21u8..0x7e); + if v == 0x2c { + v = 0x7e + } + v as char + }) + .collect() +} + impl ScramSha256 { /// Constructs a new instance which will use the provided password for authentication. pub fn new(password: &[u8], channel_binding: ChannelBinding) -> ScramSha256 { - // rand 0.5's ThreadRng is cryptographically secure - let mut rng = rand::thread_rng(); - let nonce = (0..NONCE_LENGTH) - .map(|_| { - let mut v = rng.gen_range(0x21u8..0x7e); - if v == 0x2c { - v = 0x7e - } - v as char - }) - .collect::(); + let password = Credentials::Password(normalize(password)); + ScramSha256::new_inner(password, channel_binding, nonce()) + } - ScramSha256::new_inner(password, channel_binding, nonce) + /// Constructs a new instance which will use the provided key pair for authentication. 
+ pub fn new_with_keys(keys: ScramKeys<32>, channel_binding: ChannelBinding) -> ScramSha256 { + let password = Credentials::Keys(keys.into()); + ScramSha256::new_inner(password, channel_binding, nonce()) } - fn new_inner(password: &[u8], channel_binding: ChannelBinding, nonce: String) -> ScramSha256 { + fn new_inner( + password: Credentials<32>, + channel_binding: ChannelBinding, + nonce: String, + ) -> ScramSha256 { ScramSha256 { message: format!("{}n=,r={}", channel_binding.gs2_header(), nonce), state: State::Update { nonce, - password: normalize(password), + password, channel_binding, }, } @@ -192,20 +223,32 @@ impl ScramSha256 { return Err(io::Error::new(io::ErrorKind::InvalidInput, "invalid nonce")); } - let salt = match STANDARD.decode(parsed.salt) { - Ok(salt) => salt, - Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)), - }; + let (client_key, server_key) = match password { + Credentials::Password(password) => { + let salt = match base64::decode(parsed.salt) { + Ok(salt) => salt, + Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)), + }; - let salted_password = hi(&password, &salt, parsed.iteration_count); + let salted_password = hi(&password, &salt, parsed.iteration_count); - let mut hmac = Hmac::::new_from_slice(&salted_password) - .expect("HMAC is able to accept all key sizes"); - hmac.update(b"Client Key"); - let client_key = hmac.finalize().into_bytes(); + let make_key = |name| { + let mut hmac = Hmac::::new_from_slice(&salted_password) + .expect("HMAC is able to accept all key sizes"); + hmac.update(name); + + let mut key = [0u8; 32]; + key.copy_from_slice(hmac.finalize().into_bytes().as_slice()); + key + }; + + (make_key(b"Client Key"), make_key(b"Server Key")) + } + Credentials::Keys(keys) => (keys.client_key, keys.server_key), + }; let mut hash = Sha256::default(); - hash.update(client_key.as_slice()); + hash.update(client_key); let stored_key = hash.finalize_fixed(); let mut cbind_input = vec![]; @@ -236,7 
+279,7 @@ impl ScramSha256 { .unwrap(); self.state = State::Finish { - salted_password, + server_key, auth_message, }; Ok(()) @@ -247,11 +290,11 @@ impl ScramSha256 { /// This should be called when the backend sends an `AuthenticationSASLFinal` message. /// Authentication has only succeeded if this method returns `Ok(())`. pub fn finish(&mut self, message: &[u8]) -> io::Result<()> { - let (salted_password, auth_message) = match mem::replace(&mut self.state, State::Done) { + let (server_key, auth_message) = match mem::replace(&mut self.state, State::Done) { State::Finish { - salted_password, + server_key, auth_message, - } => (salted_password, auth_message), + } => (server_key, auth_message), _ => return Err(io::Error::new(io::ErrorKind::Other, "invalid SCRAM state")), }; @@ -275,11 +318,6 @@ impl ScramSha256 { Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)), }; - let mut hmac = Hmac::::new_from_slice(&salted_password) - .expect("HMAC is able to accept all key sizes"); - hmac.update(b"Server Key"); - let server_key = hmac.finalize().into_bytes(); - let mut hmac = Hmac::::new_from_slice(&server_key) .expect("HMAC is able to accept all key sizes"); hmac.update(auth_message.as_bytes()); @@ -466,7 +504,7 @@ mod test { let server_final = "v=U+ppxD5XUKtradnv8e2MkeupiA8FU87Sg8CXzXHDAzw="; let mut scram = ScramSha256::new_inner( - password.as_bytes(), + Credentials::Password(normalize(password.as_bytes())), ChannelBinding::unsupported(), nonce.to_string(), ); diff --git a/postgres/src/config.rs b/postgres/src/config.rs index a32ddc78e..b0fb9359a 100644 --- a/postgres/src/config.rs +++ b/postgres/src/config.rs @@ -10,9 +10,10 @@ use std::str::FromStr; use std::sync::Arc; use std::time::Duration; use tokio::runtime; +use tokio_postgres::config::LoadBalanceHosts; #[doc(inline)] pub use tokio_postgres::config::{ - ChannelBinding, Host, LoadBalanceHosts, SslMode, TargetSessionAttrs, + AuthKeys, ChannelBinding, Host, ScramKeys, SslMode, TargetSessionAttrs, }; 
use tokio_postgres::error::DbError; use tokio_postgres::tls::{MakeTlsConnect, TlsConnect}; @@ -179,6 +180,20 @@ impl Config { self.config.get_password() } + /// Sets precomputed protocol-specific keys to authenticate with. + /// When set, this option will override `password`. + /// See [`AuthKeys`] for more information. + pub fn auth_keys(&mut self, keys: AuthKeys) -> &mut Config { + self.config.auth_keys(keys); + self + } + + /// Gets precomputed protocol-specific keys to authenticate with. + /// if one has been configured with the `auth_keys` method. + pub fn get_auth_keys(&self) -> Option { + self.config.get_auth_keys() + } + /// Sets the name of the database to connect to. /// /// Defaults to the user. diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs index 0f2d2e748..76287d1c5 100644 --- a/tokio-postgres/src/config.rs +++ b/tokio-postgres/src/config.rs @@ -26,6 +26,8 @@ use std::time::Duration; use std::{error, fmt, iter, mem}; use tokio::io::{AsyncRead, AsyncWrite}; +pub use postgres_protocol::authentication::sasl::ScramKeys; + /// Properties required of a session. #[derive(Debug, Copy, Clone, PartialEq, Eq)] #[non_exhaustive] @@ -92,6 +94,13 @@ pub enum Host { Unix(PathBuf), } +/// Precomputed keys which may override password during auth. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum AuthKeys { + /// A `ClientKey` & `ServerKey` pair for `SCRAM-SHA-256`. + ScramSha256(ScramKeys<32>), +} + /// Connection configuration. /// /// Configuration can be parsed from libpq-style connection strings. 
These strings come in two formats: @@ -202,6 +211,7 @@ pub enum Host { pub struct Config { pub(crate) user: Option, pub(crate) password: Option>, + pub(crate) auth_keys: Option>, pub(crate) dbname: Option, pub(crate) options: Option, pub(crate) application_name: Option, @@ -232,6 +242,7 @@ impl Config { Config { user: None, password: None, + auth_keys: None, dbname: None, options: None, application_name: None, @@ -284,6 +295,20 @@ impl Config { self.password.as_deref() } + /// Sets precomputed protocol-specific keys to authenticate with. + /// When set, this option will override `password`. + /// See [`AuthKeys`] for more information. + pub fn auth_keys(&mut self, keys: AuthKeys) -> &mut Config { + self.auth_keys = Some(Box::new(keys)); + self + } + + /// Gets precomputed protocol-specific keys to authenticate with. + /// if one has been configured with the `auth_keys` method. + pub fn get_auth_keys(&self) -> Option { + self.auth_keys.as_deref().copied() + } + /// Sets the name of the database to connect to. /// /// Defaults to the user. 
diff --git a/tokio-postgres/src/connect_raw.rs b/tokio-postgres/src/connect_raw.rs index 8edf45937..671e7917f 100644 --- a/tokio-postgres/src/connect_raw.rs +++ b/tokio-postgres/src/connect_raw.rs @@ -1,5 +1,5 @@ use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec}; -use crate::config::{self, Config, ReplicationMode}; +use crate::config::{self, AuthKeys, Config, ReplicationMode}; use crate::connect_tls::connect_tls; use crate::maybe_tls_stream::MaybeTlsStream; use crate::tls::{TlsConnect, TlsStream}; @@ -243,11 +243,6 @@ where S: AsyncRead + AsyncWrite + Unpin, T: TlsStream + Unpin, { - let password = config - .password - .as_ref() - .ok_or_else(|| Error::config("password missing".into()))?; - let mut has_scram = false; let mut has_scram_plus = false; let mut mechanisms = body.mechanisms(); @@ -285,7 +280,13 @@ where can_skip_channel_binding(config)?; } - let mut scram = ScramSha256::new(password, channel_binding); + let mut scram = if let Some(AuthKeys::ScramSha256(keys)) = config.get_auth_keys() { + ScramSha256::new_with_keys(keys, channel_binding) + } else if let Some(password) = config.get_password() { + ScramSha256::new(password, channel_binding) + } else { + return Err(Error::config("password or auth keys missing".into())); + }; let mut buf = BytesMut::new(); frontend::sasl_initial_response(mechanism, scram.message(), &mut buf).map_err(Error::encode)?; From e49339d6dc42b0aa880e35ec0eda7e437609efe4 Mon Sep 17 00:00:00 2001 From: Dmitry Ivanov Date: Thu, 15 Dec 2022 18:03:31 +0300 Subject: [PATCH 05/22] Make tokio-postgres connection parameters public We need this to enable parameter forwarding in Neon Proxy. This is less than ideal, but we'll probably revert the patch once a proper fix has been implemented. 
--- tokio-postgres/src/connection.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tokio-postgres/src/connection.rs b/tokio-postgres/src/connection.rs index a3449f88b..3b2833bd3 100644 --- a/tokio-postgres/src/connection.rs +++ b/tokio-postgres/src/connection.rs @@ -50,7 +50,8 @@ enum State { #[must_use = "futures do nothing unless polled"] pub struct Connection { stream: Framed, PostgresCodec>, - parameters: HashMap, + /// HACK: we need this in the Neon Proxy to forward params. + pub parameters: HashMap, receiver: mpsc::UnboundedReceiver, pending_request: Option, pending_responses: VecDeque, From d6a5e8286027e3433a65bac57e9b024b51149f44 Mon Sep 17 00:00:00 2001 From: Stas Kelvich Date: Fri, 28 Apr 2023 02:02:44 +0300 Subject: [PATCH 06/22] Expose conection.stream That way our proxy can take back stream for proxying. --- tokio-postgres/src/connection.rs | 3 ++- tokio-postgres/src/lib.rs | 2 +- tokio-postgres/src/maybe_tls_stream.rs | 6 ++++++ 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/tokio-postgres/src/connection.rs b/tokio-postgres/src/connection.rs index 3b2833bd3..9e58b4e05 100644 --- a/tokio-postgres/src/connection.rs +++ b/tokio-postgres/src/connection.rs @@ -49,7 +49,8 @@ enum State { /// occurred, or because its associated `Client` has dropped and all outstanding work has completed. #[must_use = "futures do nothing unless polled"] pub struct Connection { - stream: Framed, PostgresCodec>, + /// HACK: we need this in the Neon Proxy. + pub stream: Framed, PostgresCodec>, /// HACK: we need this in the Neon Proxy to forward params. 
pub parameters: HashMap, receiver: mpsc::UnboundedReceiver, diff --git a/tokio-postgres/src/lib.rs b/tokio-postgres/src/lib.rs index c221454a6..19c4f80a1 100644 --- a/tokio-postgres/src/lib.rs +++ b/tokio-postgres/src/lib.rs @@ -165,7 +165,7 @@ pub mod error; mod generic_client; #[cfg(not(target_arch = "wasm32"))] mod keepalive; -mod maybe_tls_stream; +pub mod maybe_tls_stream; mod portal; mod prepare; mod query; diff --git a/tokio-postgres/src/maybe_tls_stream.rs b/tokio-postgres/src/maybe_tls_stream.rs index 73b0c4721..9a7e24899 100644 --- a/tokio-postgres/src/maybe_tls_stream.rs +++ b/tokio-postgres/src/maybe_tls_stream.rs @@ -1,11 +1,17 @@ +//! MaybeTlsStream. +//! +//! Represents a stream that may or may not be encrypted with TLS. use crate::tls::{ChannelBinding, TlsStream}; use std::io; use std::pin::Pin; use std::task::{Context, Poll}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +/// A stream that may or may not be encrypted with TLS. pub enum MaybeTlsStream { + /// An unencrypted stream. Raw(S), + /// An encrypted stream. Tls(T), } From 4eb8036a9a9eabd671d7449ba1978b5d0c38f695 Mon Sep 17 00:00:00 2001 From: Stas Kelvich Date: Tue, 23 May 2023 11:32:41 +0300 Subject: [PATCH 07/22] Add text protocol based query method (#14) Add query_raw_txt client method It takes all the extended protocol params as text and passes them to postgres to sort out types. With that we can avoid situations when postgres derived different type compared to what was passed in arguments. There is also prepare_typed method, but since we receive data in text format anyway it makes more sense to avoid dealing with types in params. This way we also can save on roundtrip and send Parse+Bind+Describe+Execute right away without waiting for params description before Bind. Use text protocol for responses -- that allows to grab postgres-provided serializations for types. Catch command tag. 
Expose row buffer size and add `max_backend_message_size` option to prevent handling and storing in memory large messages from the backend. Co-authored-by: Arthur Petukhovsky --- .github/workflows/ci.yml | 4 +- postgres-types/src/lib.rs | 18 ++++++- postgres/src/row_iter.rs | 7 --- tokio-postgres/src/client.rs | 85 ++++++++++++++++++++++++++++++- tokio-postgres/src/codec.rs | 13 ++++- tokio-postgres/src/config.rs | 21 ++++++++ tokio-postgres/src/connect_raw.rs | 7 ++- tokio-postgres/src/prepare.rs | 2 +- tokio-postgres/src/query.rs | 32 ++++++++---- tokio-postgres/src/row.rs | 22 ++++++++ tokio-postgres/src/statement.rs | 23 +++++++++ tokio-postgres/tests/test/main.rs | 72 ++++++++++++++++++++++++++ 12 files changed, 284 insertions(+), 22 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 4df62ebbc..8e37690fa 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -55,7 +55,9 @@ jobs: steps: - uses: actions/checkout@v3 - uses: sfackler/actions/rustup@master - - run: echo "version=$(rustc --version)" >> $GITHUB_OUTPUT + with: + version: 1.65.0 + - run: echo "::set-output name=version::$(rustc --version)" id: rust-version - run: rustup target add wasm32-unknown-unknown - uses: actions/cache@v3 diff --git a/postgres-types/src/lib.rs b/postgres-types/src/lib.rs index 2f02f6e5f..4f3ae1f11 100644 --- a/postgres-types/src/lib.rs +++ b/postgres-types/src/lib.rs @@ -449,6 +449,22 @@ impl WrongType { } } +/// An error indicating that a as_text conversion was attempted on a binary +/// result. +#[derive(Debug)] +pub struct WrongFormat {} + +impl Error for WrongFormat {} + +impl fmt::Display for WrongFormat { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + fmt, + "cannot read column as text while it is in binary format" + ) + } +} + /// A trait for types that can be created from a Postgres value. 
/// /// # Types @@ -900,7 +916,7 @@ pub trait ToSql: fmt::Debug { /// Supported Postgres message format types /// /// Using Text format in a message assumes a Postgres `SERVER_ENCODING` of `UTF8` -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Copy, Debug, PartialEq)] pub enum Format { /// Text format (UTF-8) Text, diff --git a/postgres/src/row_iter.rs b/postgres/src/row_iter.rs index 221fdfc68..772e9893c 100644 --- a/postgres/src/row_iter.rs +++ b/postgres/src/row_iter.rs @@ -17,13 +17,6 @@ impl<'a> RowIter<'a> { it: Box::pin(stream), } } - - /// Returns the number of rows affected by the query. - /// - /// This function will return `None` until the iterator has been exhausted. - pub fn rows_affected(&self) -> Option { - self.it.rows_affected() - } } impl FallibleIterator for RowIter<'_> { diff --git a/tokio-postgres/src/client.rs b/tokio-postgres/src/client.rs index 6b7067ee8..b0ad062d6 100644 --- a/tokio-postgres/src/client.rs +++ b/tokio-postgres/src/client.rs @@ -5,8 +5,10 @@ use crate::copy_both::CopyBothDuplex; use crate::copy_out::CopyOutStream; #[cfg(feature = "runtime")] use crate::keepalive::KeepaliveConfig; +use crate::prepare::get_type; use crate::query::RowStream; use crate::simple_query::SimpleQueryStream; +use crate::statement::Column; #[cfg(feature = "runtime")] use crate::tls::MakeTlsConnect; use crate::tls::TlsConnect; @@ -18,7 +20,7 @@ use crate::{ CopyInSink, Error, Row, SimpleQueryMessage, Statement, ToStatement, Transaction, TransactionBuilder, }; -use bytes::{Buf, BytesMut}; +use bytes::{Buf, BufMut, BytesMut}; use fallible_iterator::FallibleIterator; use futures_channel::mpsc; use futures_util::{future, pin_mut, ready, StreamExt, TryStreamExt}; @@ -370,6 +372,87 @@ impl Client { query::query(&self.inner, statement, params).await } + /// Pass text directly to the Postgres backend to allow it to sort out typing itself and + /// to save a roundtrip + pub async fn query_raw_txt<'a, S, I>(&self, query: S, params: I) -> Result + where + S: AsRef, 
+ I: IntoIterator, + I::IntoIter: ExactSizeIterator, + { + let params = params.into_iter(); + let params_len = params.len(); + + let buf = self.inner.with_buf(|buf| { + // Parse, anonymous portal + frontend::parse("", query.as_ref(), std::iter::empty(), buf).map_err(Error::encode)?; + // Bind, pass params as text, retrieve as binary + match frontend::bind( + "", // empty string selects the unnamed portal + "", // empty string selects the unnamed prepared statement + std::iter::empty(), // all parameters use the default format (text) + params, + |param, buf| { + buf.put_slice(param.as_ref().as_bytes()); + Ok(postgres_protocol::IsNull::No) + }, + Some(0), // all text + buf, + ) { + Ok(()) => Ok(()), + Err(frontend::BindError::Conversion(e)) => Err(Error::to_sql(e, 0)), + Err(frontend::BindError::Serialization(e)) => Err(Error::encode(e)), + }?; + + // Describe portal to typecast results + frontend::describe(b'P', "", buf).map_err(Error::encode)?; + // Execute + frontend::execute("", 0, buf).map_err(Error::encode)?; + // Sync + frontend::sync(buf); + + Ok(buf.split().freeze()) + })?; + + let mut responses = self + .inner + .send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; + + // now read the responses + + match responses.next().await? { + Message::ParseComplete => {} + _ => return Err(Error::unexpected_message()), + } + match responses.next().await? { + Message::BindComplete => {} + _ => return Err(Error::unexpected_message()), + } + let row_description = match responses.next().await? { + Message::RowDescription(body) => Some(body), + Message::NoData => None, + _ => return Err(Error::unexpected_message()), + }; + + // construct statement object + + let parameters = vec![Type::UNKNOWN; params_len]; + + let mut columns = vec![]; + if let Some(row_description) = row_description { + let mut it = row_description.fields(); + while let Some(field) = it.next().map_err(Error::parse)? 
{ + let type_ = get_type(&self.inner, field.type_oid()).await?; + let column = Column::new(field.name().to_string(), type_); + columns.push(column); + } + } + + let statement = Statement::new_text(&self.inner, "".to_owned(), parameters, columns); + + Ok(RowStream::new(statement, responses)) + } + /// Executes a statement, returning the number of rows modified. /// /// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list diff --git a/tokio-postgres/src/codec.rs b/tokio-postgres/src/codec.rs index 9d078044b..23c371542 100644 --- a/tokio-postgres/src/codec.rs +++ b/tokio-postgres/src/codec.rs @@ -35,7 +35,9 @@ impl FallibleIterator for BackendMessages { } } -pub struct PostgresCodec; +pub struct PostgresCodec { + pub max_message_size: Option, +} impl Encoder for PostgresCodec { type Error = io::Error; @@ -64,6 +66,15 @@ impl Decoder for PostgresCodec { break; } + if let Some(max) = self.max_message_size { + if len > max { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "message too large", + )); + } + } + match header.tag() { backend::NOTICE_RESPONSE_TAG | backend::NOTIFICATION_RESPONSE_TAG diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs index 76287d1c5..421f7c5aa 100644 --- a/tokio-postgres/src/config.rs +++ b/tokio-postgres/src/config.rs @@ -228,6 +228,7 @@ pub struct Config { pub(crate) channel_binding: ChannelBinding, pub(crate) load_balance_hosts: LoadBalanceHosts, pub(crate) replication_mode: Option, + pub(crate) max_backend_message_size: Option, } impl Default for Config { @@ -263,6 +264,7 @@ impl Config { channel_binding: ChannelBinding::Prefer, load_balance_hosts: LoadBalanceHosts::Disable, replication_mode: None, + max_backend_message_size: None, } } @@ -568,6 +570,17 @@ impl Config { self.replication_mode } + /// Set limit for backend messages size. 
+ pub fn max_backend_message_size(&mut self, max_backend_message_size: usize) -> &mut Config { + self.max_backend_message_size = Some(max_backend_message_size); + self + } + + /// Get limit for backend messages size. + pub fn get_max_backend_message_size(&self) -> Option { + self.max_backend_message_size + } + fn param(&mut self, key: &str, value: &str) -> Result<(), Error> { match key { "user" => { @@ -714,6 +727,14 @@ impl Config { self.replication_mode(mode); } } + "max_backend_message_size" => { + let limit = value.parse::().map_err(|_| { + Error::config_parse(Box::new(InvalidValue("max_backend_message_size"))) + })?; + if limit > 0 { + self.max_backend_message_size(limit); + } + } key => { return Err(Error::config_parse(Box::new(UnknownOption( key.to_string(), diff --git a/tokio-postgres/src/connect_raw.rs b/tokio-postgres/src/connect_raw.rs index 671e7917f..19219c8ac 100644 --- a/tokio-postgres/src/connect_raw.rs +++ b/tokio-postgres/src/connect_raw.rs @@ -92,7 +92,12 @@ where let stream = connect_tls(stream, config.ssl_mode, tls, has_hostname).await?; let mut stream = StartupStream { - inner: Framed::new(stream, PostgresCodec), + inner: Framed::new( + stream, + PostgresCodec { + max_message_size: config.max_backend_message_size, + }, + ), buf: BackendMessages::empty(), delayed: VecDeque::new(), }; diff --git a/tokio-postgres/src/prepare.rs b/tokio-postgres/src/prepare.rs index e3f09a7c2..ba8d5a43e 100644 --- a/tokio-postgres/src/prepare.rs +++ b/tokio-postgres/src/prepare.rs @@ -126,7 +126,7 @@ fn encode(client: &InnerClient, name: &str, query: &str, types: &[Type]) -> Resu }) } -async fn get_type(client: &Arc, oid: Oid) -> Result { +pub async fn get_type(client: &Arc, oid: Oid) -> Result { if let Some(type_) = Type::from_oid(oid) { return Ok(type_); } diff --git a/tokio-postgres/src/query.rs b/tokio-postgres/src/query.rs index e6e1d00a8..661f802f6 100644 --- a/tokio-postgres/src/query.rs +++ b/tokio-postgres/src/query.rs @@ -52,7 +52,7 @@ where Ok(RowStream 
{ statement, responses, - rows_affected: None, + command_tag: None, _p: PhantomPinned, }) } @@ -73,7 +73,7 @@ pub async fn query_portal( Ok(RowStream { statement: portal.statement().clone(), responses, - rows_affected: None, + command_tag: None, _p: PhantomPinned, }) } @@ -207,12 +207,24 @@ pin_project! { pub struct RowStream { statement: Statement, responses: Responses, - rows_affected: Option, + command_tag: Option, #[pin] _p: PhantomPinned, } } +impl RowStream { + /// Creates a new `RowStream`. + pub fn new(statement: Statement, responses: Responses) -> Self { + RowStream { + statement, + responses, + command_tag: None, + _p: PhantomPinned, + } + } +} + impl Stream for RowStream { type Item = Result; @@ -223,10 +235,12 @@ impl Stream for RowStream { Message::DataRow(body) => { return Poll::Ready(Some(Ok(Row::new(this.statement.clone(), body)?))) } + Message::EmptyQueryResponse | Message::PortalSuspended => {} Message::CommandComplete(body) => { - *this.rows_affected = Some(extract_row_affected(&body)?); + if let Ok(tag) = body.tag() { + *this.command_tag = Some(tag.to_string()); + } } - Message::EmptyQueryResponse | Message::PortalSuspended => {} Message::ReadyForQuery(_) => return Poll::Ready(None), _ => return Poll::Ready(Some(Err(Error::unexpected_message()))), } @@ -235,10 +249,10 @@ impl Stream for RowStream { } impl RowStream { - /// Returns the number of rows affected by the query. + /// Returns the command tag of this query. /// - /// This function will return `None` until the stream has been exhausted. - pub fn rows_affected(&self) -> Option { - self.rows_affected + /// This is only available after the stream has been exhausted. 
+ pub fn command_tag(&self) -> Option { + self.command_tag.clone() } } diff --git a/tokio-postgres/src/row.rs b/tokio-postgres/src/row.rs index db179b432..ce4efed7e 100644 --- a/tokio-postgres/src/row.rs +++ b/tokio-postgres/src/row.rs @@ -7,6 +7,7 @@ use crate::types::{FromSql, Type, WrongType}; use crate::{Error, Statement}; use fallible_iterator::FallibleIterator; use postgres_protocol::message::backend::DataRowBody; +use postgres_types::{Format, WrongFormat}; use std::fmt; use std::ops::Range; use std::str; @@ -187,6 +188,27 @@ impl Row { let range = self.ranges[idx].to_owned()?; Some(&self.body.buffer()[range]) } + + /// Interpret the column at the given index as text + /// + /// Useful when using query_raw_txt() which sets text transfer mode + pub fn as_text(&self, idx: usize) -> Result, Error> { + if self.statement.output_format() == Format::Text { + match self.col_buffer(idx) { + Some(raw) => { + FromSql::from_sql(&Type::TEXT, raw).map_err(|e| Error::from_sql(e, idx)) + } + None => Ok(None), + } + } else { + Err(Error::from_sql(Box::new(WrongFormat {}), idx)) + } + } + + /// Row byte size + pub fn body_len(&self) -> usize { + self.body.buffer().len() + } } impl AsName for SimpleColumn { diff --git a/tokio-postgres/src/statement.rs b/tokio-postgres/src/statement.rs index 97561a8e4..b7ab11866 100644 --- a/tokio-postgres/src/statement.rs +++ b/tokio-postgres/src/statement.rs @@ -3,6 +3,7 @@ use crate::codec::FrontendMessage; use crate::connection::RequestMessages; use crate::types::Type; use postgres_protocol::message::frontend; +use postgres_types::Format; use std::{ fmt, sync::{Arc, Weak}, @@ -13,6 +14,7 @@ struct StatementInner { name: String, params: Vec, columns: Vec, + output_format: Format, } impl Drop for StatementInner { @@ -46,6 +48,22 @@ impl Statement { name, params, columns, + output_format: Format::Binary, + })) + } + + pub(crate) fn new_text( + inner: &Arc, + name: String, + params: Vec, + columns: Vec, + ) -> Statement { + 
Statement(Arc::new(StatementInner { + client: Arc::downgrade(inner), + name, + params, + columns, + output_format: Format::Text, })) } @@ -62,6 +80,11 @@ impl Statement { pub fn columns(&self) -> &[Column] { &self.0.columns } + + /// Returns output format for the statement. + pub fn output_format(&self) -> Format { + self.0.output_format + } } /// Information about a column of a query. diff --git a/tokio-postgres/tests/test/main.rs b/tokio-postgres/tests/test/main.rs index 8de2b75a2..551f6ec5c 100644 --- a/tokio-postgres/tests/test/main.rs +++ b/tokio-postgres/tests/test/main.rs @@ -251,6 +251,78 @@ async fn custom_array() { } } +#[tokio::test] +async fn query_raw_txt() { + let client = connect("user=postgres").await; + + let rows: Vec = client + .query_raw_txt("SELECT 55 * $1", ["42"]) + .await + .unwrap() + .try_collect() + .await + .unwrap(); + + assert_eq!(rows.len(), 1); + let res: i32 = rows[0].as_text(0).unwrap().unwrap().parse::().unwrap(); + assert_eq!(res, 55 * 42); + + let rows: Vec = client + .query_raw_txt("SELECT $1", ["42"]) + .await + .unwrap() + .try_collect() + .await + .unwrap(); + + assert_eq!(rows.len(), 1); + assert_eq!(rows[0].get::<_, &str>(0), "42"); + assert!(rows[0].body_len() > 0); +} + +#[tokio::test] +async fn limit_max_backend_message_size() { + let client = connect("user=postgres max_backend_message_size=10000").await; + let small: Vec = client + .query_raw_txt("SELECT REPEAT('a', 20)", []) + .await + .unwrap() + .try_collect() + .await + .unwrap(); + + assert_eq!(small.len(), 1); + assert_eq!(small[0].as_text(0).unwrap().unwrap().len(), 20); + + let large: Result, Error> = client + .query_raw_txt("SELECT REPEAT('a', 2000000)", []) + .await + .unwrap() + .try_collect() + .await; + + assert!(large.is_err()); +} + +#[tokio::test] +async fn command_tag() { + let client = connect("user=postgres").await; + + let row_stream = client + .query_raw_txt("select unnest('{1,2,3}'::int[]);", []) + .await + .unwrap(); + + pin_mut!(row_stream); + + 
let mut rows: Vec = Vec::new(); + while let Some(row) = row_stream.next().await { + rows.push(row.unwrap()); + } + + assert_eq!(row_stream.command_tag(), Some("SELECT 3".to_string())); +} + #[tokio::test] async fn custom_composite() { let client = connect("user=postgres").await; From ecb9d7b26e2c6af4d1dca009b98b4c97e8b3e22d Mon Sep 17 00:00:00 2001 From: Stas Kelvich Date: Thu, 8 Jun 2023 10:58:15 +0300 Subject: [PATCH 08/22] Allow passing null params in query_raw_txt() Previous coding only allowed passing vector of text values as params, but that does not allow to distinguish between nulls and 4-byte strings with "null" written in them. Change query_raw_txt params argument to accept Vec> instead. --- tokio-postgres/src/client.rs | 11 ++++++---- tokio-postgres/tests/test/main.rs | 34 +++++++++++++++++++++++++++++-- 2 files changed, 39 insertions(+), 6 deletions(-) diff --git a/tokio-postgres/src/client.rs b/tokio-postgres/src/client.rs index b0ad062d6..ed7ec7913 100644 --- a/tokio-postgres/src/client.rs +++ b/tokio-postgres/src/client.rs @@ -377,7 +377,7 @@ impl Client { pub async fn query_raw_txt<'a, S, I>(&self, query: S, params: I) -> Result where S: AsRef, - I: IntoIterator, + I: IntoIterator>, I::IntoIter: ExactSizeIterator, { let params = params.into_iter(); @@ -392,9 +392,12 @@ impl Client { "", // empty string selects the unnamed prepared statement std::iter::empty(), // all parameters use the default format (text) params, - |param, buf| { - buf.put_slice(param.as_ref().as_bytes()); - Ok(postgres_protocol::IsNull::No) + |param, buf| match param { + Some(param) => { + buf.put_slice(param.as_ref().as_bytes()); + Ok(postgres_protocol::IsNull::No) + } + None => Ok(postgres_protocol::IsNull::Yes), }, Some(0), // all text buf, diff --git a/tokio-postgres/tests/test/main.rs b/tokio-postgres/tests/test/main.rs index 551f6ec5c..b36c51a9f 100644 --- a/tokio-postgres/tests/test/main.rs +++ b/tokio-postgres/tests/test/main.rs @@ -256,7 +256,7 @@ async fn 
query_raw_txt() { let client = connect("user=postgres").await; let rows: Vec = client - .query_raw_txt("SELECT 55 * $1", ["42"]) + .query_raw_txt("SELECT 55 * $1", [Some("42")]) .await .unwrap() .try_collect() .await .unwrap(); @@ -268,7 +268,7 @@ async fn query_raw_txt() { assert_eq!(res, 55 * 42); let rows: Vec = client - .query_raw_txt("SELECT $1", ["42"]) + .query_raw_txt("SELECT $1", [Some("42")]) .await .unwrap() .try_collect() .await .unwrap(); @@ -280,6 +280,36 @@ async fn query_raw_txt() { assert!(rows[0].body_len() > 0); } +#[tokio::test] +async fn query_raw_txt_nulls() { + let client = connect("user=postgres").await; + + let rows: Vec = client + .query_raw_txt( + "SELECT $1 as str, $2 as n, 'null' as str2, null as n2", + [Some("null"), None], + ) + .await + .unwrap() + .try_collect() + .await + .unwrap(); + + assert_eq!(rows.len(), 1); + + let res = rows[0].as_text(0).unwrap(); + assert_eq!(res, Some("null")); + + let res = rows[0].as_text(1).unwrap(); + assert_eq!(res, None); + + let res = rows[0].as_text(2).unwrap(); + assert_eq!(res, Some("null")); + + let res = rows[0].as_text(3).unwrap(); + assert_eq!(res, None); +} + #[tokio::test] async fn limit_max_backend_message_size() { let client = connect("user=postgres max_backend_message_size=10000").await; From d3ad00dbf23badffaca183283da6a882d5513ed2 Mon Sep 17 00:00:00 2001 From: Stas Kelvich Date: Tue, 13 Jun 2023 01:11:08 +0300 Subject: [PATCH 09/22] Return more RowDescription fields As we are trying to match client-side behaviour with node-postgres we need to return these fields as well because node-postgres returns them. 
--- tokio-postgres/src/client.rs | 4 ++- tokio-postgres/src/prepare.rs | 2 +- tokio-postgres/src/statement.rs | 58 +++++++++++++++++++++++++++++-- tokio-postgres/tests/test/main.rs | 25 +++++++++++++ 4 files changed, 84 insertions(+), 5 deletions(-) diff --git a/tokio-postgres/src/client.rs b/tokio-postgres/src/client.rs index ed7ec7913..e9f0bc128 100644 --- a/tokio-postgres/src/client.rs +++ b/tokio-postgres/src/client.rs @@ -445,8 +445,10 @@ impl Client { if let Some(row_description) = row_description { let mut it = row_description.fields(); while let Some(field) = it.next().map_err(Error::parse)? { + // NB: for some types that function may send a query to the server. At least in + // raw text mode we don't need that info and can skip this. let type_ = get_type(&self.inner, field.type_oid()).await?; - let column = Column::new(field.name().to_string(), type_); + let column = Column::new(field.name().to_string(), type_, field); columns.push(column); } } diff --git a/tokio-postgres/src/prepare.rs b/tokio-postgres/src/prepare.rs index ba8d5a43e..0abb8e453 100644 --- a/tokio-postgres/src/prepare.rs +++ b/tokio-postgres/src/prepare.rs @@ -95,7 +95,7 @@ pub async fn prepare( let mut it = row_description.fields(); while let Some(field) = it.next().map_err(Error::parse)? 
{ let type_ = get_type(client, field.type_oid()).await?; - let column = Column::new(field.name().to_string(), type_); + let column = Column::new(field.name().to_string(), type_, field); columns.push(column); } } diff --git a/tokio-postgres/src/statement.rs b/tokio-postgres/src/statement.rs index b7ab11866..8743f00f0 100644 --- a/tokio-postgres/src/statement.rs +++ b/tokio-postgres/src/statement.rs @@ -2,7 +2,10 @@ use crate::client::InnerClient; use crate::codec::FrontendMessage; use crate::connection::RequestMessages; use crate::types::Type; -use postgres_protocol::message::frontend; +use postgres_protocol::{ + message::{backend::Field, frontend}, + Oid, +}; use postgres_types::Format; use std::{ fmt, @@ -91,11 +94,30 @@ impl Statement { pub struct Column { name: String, type_: Type, + + // raw fields from RowDescription + table_oid: Oid, + column_id: i16, + format: i16, + + // that better be stored in self.type_, but that is more radical refactoring + type_oid: Oid, + type_size: i16, + type_modifier: i32, } impl Column { - pub(crate) fn new(name: String, type_: Type) -> Column { - Column { name, type_ } + pub(crate) fn new(name: String, type_: Type, raw_field: Field<'_>) -> Column { + Column { + name, + type_, + table_oid: raw_field.table_oid(), + column_id: raw_field.column_id(), + format: raw_field.format(), + type_oid: raw_field.type_oid(), + type_size: raw_field.type_size(), + type_modifier: raw_field.type_modifier(), + } } /// Returns the name of the column. @@ -107,6 +129,36 @@ impl Column { pub fn type_(&self) -> &Type { &self.type_ } + + /// Returns the table OID of the column. + pub fn table_oid(&self) -> Oid { + self.table_oid + } + + /// Returns the column ID of the column. + pub fn column_id(&self) -> i16 { + self.column_id + } + + /// Returns the format of the column. + pub fn format(&self) -> i16 { + self.format + } + + /// Returns the type OID of the column. 
+ pub fn type_oid(&self) -> Oid { + self.type_oid + } + + /// Returns the type size of the column. + pub fn type_size(&self) -> i16 { + self.type_size + } + + /// Returns the type modifier of the column. + pub fn type_modifier(&self) -> i32 { + self.type_modifier + } } impl fmt::Debug for Column { diff --git a/tokio-postgres/tests/test/main.rs b/tokio-postgres/tests/test/main.rs index b36c51a9f..20345c7b4 100644 --- a/tokio-postgres/tests/test/main.rs +++ b/tokio-postgres/tests/test/main.rs @@ -353,6 +353,31 @@ async fn command_tag() { assert_eq!(row_stream.command_tag(), Some("SELECT 3".to_string())); } +#[tokio::test] +async fn column_extras() { + let client = connect("user=postgres").await; + + let rows: Vec = client + .query_raw_txt("select relacl, relname from pg_class limit 1", []) + .await + .unwrap() + .try_collect() + .await + .unwrap(); + + let column = rows[0].columns().get(1).unwrap(); + assert_eq!(column.name(), "relname"); + assert_eq!(column.type_(), &Type::NAME); + + assert!(column.table_oid() > 0); + assert_eq!(column.column_id(), 2); + assert_eq!(column.format(), 0); + + assert_eq!(column.type_oid(), 19); + assert_eq!(column.type_size(), 64); + assert_eq!(column.type_modifier(), -1); +} + #[tokio::test] async fn custom_composite() { let client = connect("user=postgres").await; From 580ac24c37cccd5b9d271e2e439fa1c2df394193 Mon Sep 17 00:00:00 2001 From: Alex Chi Z Date: Mon, 24 Jul 2023 15:15:14 -0400 Subject: [PATCH 10/22] add query_raw_txt for transaction (#20) Signed-off-by: Alex Chi --- tokio-postgres/src/client.rs | 85 ++----------------------- tokio-postgres/src/generic_client.rs | 25 ++++++++ tokio-postgres/src/query.rs | 94 +++++++++++++++++++++++++++- tokio-postgres/src/transaction.rs | 10 +++ 4 files changed, 131 insertions(+), 83 deletions(-) diff --git a/tokio-postgres/src/client.rs b/tokio-postgres/src/client.rs index e9f0bc128..c12b7d8f2 100644 --- a/tokio-postgres/src/client.rs +++ b/tokio-postgres/src/client.rs @@ -5,10 +5,8 @@ use 
crate::copy_both::CopyBothDuplex; use crate::copy_out::CopyOutStream; #[cfg(feature = "runtime")] use crate::keepalive::KeepaliveConfig; -use crate::prepare::get_type; use crate::query::RowStream; use crate::simple_query::SimpleQueryStream; -use crate::statement::Column; #[cfg(feature = "runtime")] use crate::tls::MakeTlsConnect; use crate::tls::TlsConnect; @@ -20,7 +18,7 @@ use crate::{ CopyInSink, Error, Row, SimpleQueryMessage, Statement, ToStatement, Transaction, TransactionBuilder, }; -use bytes::{Buf, BufMut, BytesMut}; +use bytes::{Buf, BytesMut}; use fallible_iterator::FallibleIterator; use futures_channel::mpsc; use futures_util::{future, pin_mut, ready, StreamExt, TryStreamExt}; @@ -376,86 +374,11 @@ impl Client { /// to save a roundtrip pub async fn query_raw_txt<'a, S, I>(&self, query: S, params: I) -> Result where - S: AsRef, + S: AsRef + Sync + Send, I: IntoIterator>, - I::IntoIter: ExactSizeIterator, + I::IntoIter: ExactSizeIterator + Sync + Send, { - let params = params.into_iter(); - let params_len = params.len(); - - let buf = self.inner.with_buf(|buf| { - // Parse, anonymous portal - frontend::parse("", query.as_ref(), std::iter::empty(), buf).map_err(Error::encode)?; - // Bind, pass params as text, retrieve as binary - match frontend::bind( - "", // empty string selects the unnamed portal - "", // empty string selects the unnamed prepared statement - std::iter::empty(), // all parameters use the default format (text) - params, - |param, buf| match param { - Some(param) => { - buf.put_slice(param.as_ref().as_bytes()); - Ok(postgres_protocol::IsNull::No) - } - None => Ok(postgres_protocol::IsNull::Yes), - }, - Some(0), // all text - buf, - ) { - Ok(()) => Ok(()), - Err(frontend::BindError::Conversion(e)) => Err(Error::to_sql(e, 0)), - Err(frontend::BindError::Serialization(e)) => Err(Error::encode(e)), - }?; - - // Describe portal to typecast results - frontend::describe(b'P', "", buf).map_err(Error::encode)?; - // Execute - frontend::execute("", 
0, buf).map_err(Error::encode)?; - // Sync - frontend::sync(buf); - - Ok(buf.split().freeze()) - })?; - - let mut responses = self - .inner - .send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; - - // now read the responses - - match responses.next().await? { - Message::ParseComplete => {} - _ => return Err(Error::unexpected_message()), - } - match responses.next().await? { - Message::BindComplete => {} - _ => return Err(Error::unexpected_message()), - } - let row_description = match responses.next().await? { - Message::RowDescription(body) => Some(body), - Message::NoData => None, - _ => return Err(Error::unexpected_message()), - }; - - // construct statement object - - let parameters = vec![Type::UNKNOWN; params_len]; - - let mut columns = vec![]; - if let Some(row_description) = row_description { - let mut it = row_description.fields(); - while let Some(field) = it.next().map_err(Error::parse)? { - // NB: for some types that function may send a query to the server. At least in - // raw text mode we don't need that info and can skip this. - let type_ = get_type(&self.inner, field.type_oid()).await?; - let column = Column::new(field.name().to_string(), type_, field); - columns.push(column); - } - } - - let statement = Statement::new_text(&self.inner, "".to_owned(), parameters, columns); - - Ok(RowStream::new(statement, responses)) + query::query_txt(&self.inner, query, params).await } /// Executes a statement, returning the number of rows modified. diff --git a/tokio-postgres/src/generic_client.rs b/tokio-postgres/src/generic_client.rs index 50cff9712..a259532e5 100644 --- a/tokio-postgres/src/generic_client.rs +++ b/tokio-postgres/src/generic_client.rs @@ -56,6 +56,13 @@ pub trait GenericClient: private::Sealed { I: IntoIterator + Sync + Send, I::IntoIter: ExactSizeIterator; + /// Like `Client::query_raw_txt`. 
+ async fn query_raw_txt<'a, S, I>(&self, query: S, params: I) -> Result + where + S: AsRef + Sync + Send, + I: IntoIterator> + Sync + Send, + I::IntoIter: ExactSizeIterator + Sync + Send; + /// Like `Client::prepare`. async fn prepare(&self, query: &str) -> Result; @@ -136,6 +143,15 @@ impl GenericClient for Client { self.query_raw(statement, params).await } + async fn query_raw_txt<'a, S, I>(&self, query: S, params: I) -> Result + where + S: AsRef + Sync + Send, + I: IntoIterator> + Sync + Send, + I::IntoIter: ExactSizeIterator + Sync + Send, + { + self.query_raw_txt(query, params).await + } + async fn prepare(&self, query: &str) -> Result { self.prepare(query).await } @@ -222,6 +238,15 @@ impl GenericClient for Transaction<'_> { self.query_raw(statement, params).await } + async fn query_raw_txt<'a, S, I>(&self, query: S, params: I) -> Result + where + S: AsRef + Sync + Send, + I: IntoIterator> + Sync + Send, + I::IntoIter: ExactSizeIterator + Sync + Send, + { + self.query_raw_txt(query, params).await + } + async fn prepare(&self, query: &str) -> Result { self.prepare(query).await } diff --git a/tokio-postgres/src/query.rs b/tokio-postgres/src/query.rs index 661f802f6..0f4e2bdcb 100644 --- a/tokio-postgres/src/query.rs +++ b/tokio-postgres/src/query.rs @@ -1,17 +1,21 @@ use crate::client::{InnerClient, Responses}; use crate::codec::FrontendMessage; use crate::connection::RequestMessages; +use crate::prepare::get_type; use crate::types::{BorrowToSql, IsNull}; -use crate::{Error, Portal, Row, Statement}; -use bytes::{Bytes, BytesMut}; +use crate::{Column, Error, Portal, Row, Statement}; +use bytes::{BufMut, Bytes, BytesMut}; +use fallible_iterator::FallibleIterator; use futures_util::{ready, Stream}; use log::{debug, log_enabled, Level}; use pin_project_lite::pin_project; use postgres_protocol::message::backend::{CommandCompleteBody, Message}; use postgres_protocol::message::frontend; +use postgres_types::Type; use std::fmt; use std::marker::PhantomPinned; use 
std::pin::Pin; +use std::sync::Arc; use std::task::{Context, Poll}; struct BorrowToSqlParamsDebug<'a, T>(&'a [T]); @@ -57,6 +61,92 @@ where }) } +pub async fn query_txt( + client: &Arc, + query: S, + params: I, +) -> Result +where + S: AsRef + Sync + Send, + I: IntoIterator>, + I::IntoIter: ExactSizeIterator, +{ + let params = params.into_iter(); + let params_len = params.len(); + + let buf = client.with_buf(|buf| { + // Parse, anonymous portal + frontend::parse("", query.as_ref(), std::iter::empty(), buf).map_err(Error::encode)?; + // Bind, pass params as text, retrieve as binary + match frontend::bind( + "", // empty string selects the unnamed portal + "", // empty string selects the unnamed prepared statement + std::iter::empty(), // all parameters use the default format (text) + params, + |param, buf| match param { + Some(param) => { + buf.put_slice(param.as_ref().as_bytes()); + Ok(postgres_protocol::IsNull::No) + } + None => Ok(postgres_protocol::IsNull::Yes), + }, + Some(0), // all text + buf, + ) { + Ok(()) => Ok(()), + Err(frontend::BindError::Conversion(e)) => Err(Error::to_sql(e, 0)), + Err(frontend::BindError::Serialization(e)) => Err(Error::encode(e)), + }?; + + // Describe portal to typecast results + frontend::describe(b'P', "", buf).map_err(Error::encode)?; + // Execute + frontend::execute("", 0, buf).map_err(Error::encode)?; + // Sync + frontend::sync(buf); + + Ok(buf.split().freeze()) + })?; + + let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; + + // now read the responses + + match responses.next().await? { + Message::ParseComplete => {} + _ => return Err(Error::unexpected_message()), + } + match responses.next().await? { + Message::BindComplete => {} + _ => return Err(Error::unexpected_message()), + } + let row_description = match responses.next().await? 
{ + Message::RowDescription(body) => Some(body), + Message::NoData => None, + _ => return Err(Error::unexpected_message()), + }; + + // construct statement object + + let parameters = vec![Type::UNKNOWN; params_len]; + + let mut columns = vec![]; + if let Some(row_description) = row_description { + let mut it = row_description.fields(); + while let Some(field) = it.next().map_err(Error::parse)? { + // NB: for some types that function may send a query to the server. At least in + // raw text mode we don't need that info and can skip this. + let type_ = get_type(client, field.type_oid()).await?; + let column = Column::new(field.name().to_string(), type_, field); + columns.push(column); + } + } + + let statement = Statement::new_text(client, "".to_owned(), parameters, columns); + + Ok(RowStream::new(statement, responses)) +} + pub async fn query_portal( client: &InnerClient, portal: &Portal, diff --git a/tokio-postgres/src/transaction.rs b/tokio-postgres/src/transaction.rs index 96a324652..806196aa3 100644 --- a/tokio-postgres/src/transaction.rs +++ b/tokio-postgres/src/transaction.rs @@ -149,6 +149,16 @@ impl<'a> Transaction<'a> { self.client.query_raw(statement, params).await } + /// Like `Client::query_raw_txt`. + pub async fn query_raw_txt(&self, query: S, params: I) -> Result + where + S: AsRef + Sync + Send, + I: IntoIterator>, + I::IntoIter: ExactSizeIterator + Sync + Send, + { + self.client.query_raw_txt(query, params).await + } + /// Like `Client::execute`. 
pub async fn execute( &self, From 0d8c71600313c01c456b99da03c8ce70ca5904a2 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Fri, 11 Aug 2023 15:14:05 +0100 Subject: [PATCH 11/22] Connection changes (#21) * refactor query_raw_txt to use a pre-prepared statement * expose ready_status on RowStream --- .github/workflows/ci.yml | 2 +- tokio-postgres/src/client.rs | 14 ++-- tokio-postgres/src/generic_client.rs | 17 +++-- tokio-postgres/src/query.rs | 98 ++++++++++------------------ tokio-postgres/src/row.rs | 10 ++- tokio-postgres/src/statement.rs | 23 ------- tokio-postgres/src/transaction.rs | 9 +-- tokio-postgres/tests/test/main.rs | 36 ++++++++-- 8 files changed, 104 insertions(+), 105 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8e37690fa..937366a08 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -56,7 +56,7 @@ jobs: - uses: actions/checkout@v3 - uses: sfackler/actions/rustup@master with: - version: 1.65.0 + version: 1.67.0 - run: echo "::set-output name=version::$(rustc --version)" id: rust-version - run: rustup target add wasm32-unknown-unknown diff --git a/tokio-postgres/src/client.rs b/tokio-postgres/src/client.rs index c12b7d8f2..9a3786ec7 100644 --- a/tokio-postgres/src/client.rs +++ b/tokio-postgres/src/client.rs @@ -372,13 +372,19 @@ impl Client { /// Pass text directly to the Postgres backend to allow it to sort out typing itself and /// to save a roundtrip - pub async fn query_raw_txt<'a, S, I>(&self, query: S, params: I) -> Result + pub async fn query_raw_txt<'a, T, S, I>( + &self, + statement: &T, + params: I, + ) -> Result where - S: AsRef + Sync + Send, + T: ?Sized + ToStatement, + S: AsRef, I: IntoIterator>, - I::IntoIter: ExactSizeIterator + Sync + Send, + I::IntoIter: ExactSizeIterator, { - query::query_txt(&self.inner, query, params).await + let statement = statement.__convert().into_statement(self).await?; + query::query_txt(&self.inner, statement, params).await } /// Executes a 
statement, returning the number of rows modified. diff --git a/tokio-postgres/src/generic_client.rs b/tokio-postgres/src/generic_client.rs index a259532e5..a4ee4808b 100644 --- a/tokio-postgres/src/generic_client.rs +++ b/tokio-postgres/src/generic_client.rs @@ -57,8 +57,13 @@ pub trait GenericClient: private::Sealed { I::IntoIter: ExactSizeIterator; /// Like `Client::query_raw_txt`. - async fn query_raw_txt<'a, S, I>(&self, query: S, params: I) -> Result + async fn query_raw_txt<'a, T, S, I>( + &self, + statement: &T, + params: I, + ) -> Result where + T: ?Sized + ToStatement + Sync + Send, S: AsRef + Sync + Send, I: IntoIterator> + Sync + Send, I::IntoIter: ExactSizeIterator + Sync + Send; @@ -143,13 +148,14 @@ impl GenericClient for Client { self.query_raw(statement, params).await } - async fn query_raw_txt<'a, S, I>(&self, query: S, params: I) -> Result + async fn query_raw_txt<'a, T, S, I>(&self, statement: &T, params: I) -> Result where + T: ?Sized + ToStatement + Sync + Send, S: AsRef + Sync + Send, I: IntoIterator> + Sync + Send, I::IntoIter: ExactSizeIterator + Sync + Send, { - self.query_raw_txt(query, params).await + self.query_raw_txt(statement, params).await } async fn prepare(&self, query: &str) -> Result { @@ -238,13 +244,14 @@ impl GenericClient for Transaction<'_> { self.query_raw(statement, params).await } - async fn query_raw_txt<'a, S, I>(&self, query: S, params: I) -> Result + async fn query_raw_txt<'a, T, S, I>(&self, statement: &T, params: I) -> Result where + T: ?Sized + ToStatement + Sync + Send, S: AsRef + Sync + Send, I: IntoIterator> + Sync + Send, I::IntoIter: ExactSizeIterator + Sync + Send, { - self.query_raw_txt(query, params).await + self.query_raw_txt(statement, params).await } async fn prepare(&self, query: &str) -> Result { diff --git a/tokio-postgres/src/query.rs b/tokio-postgres/src/query.rs index 0f4e2bdcb..7bac62fce 100644 --- a/tokio-postgres/src/query.rs +++ b/tokio-postgres/src/query.rs @@ -1,17 +1,15 @@ use 
crate::client::{InnerClient, Responses}; use crate::codec::FrontendMessage; use crate::connection::RequestMessages; -use crate::prepare::get_type; use crate::types::{BorrowToSql, IsNull}; -use crate::{Column, Error, Portal, Row, Statement}; +use crate::{Error, Portal, Row, Statement}; use bytes::{BufMut, Bytes, BytesMut}; -use fallible_iterator::FallibleIterator; use futures_util::{ready, Stream}; use log::{debug, log_enabled, Level}; use pin_project_lite::pin_project; use postgres_protocol::message::backend::{CommandCompleteBody, Message}; use postgres_protocol::message::frontend; -use postgres_types::Type; +use postgres_types::Format; use std::fmt; use std::marker::PhantomPinned; use std::pin::Pin; @@ -57,30 +55,29 @@ where statement, responses, command_tag: None, + status: None, + output_format: Format::Binary, _p: PhantomPinned, }) } pub async fn query_txt( client: &Arc, - query: S, + statement: Statement, params: I, ) -> Result where - S: AsRef + Sync + Send, + S: AsRef, I: IntoIterator>, I::IntoIter: ExactSizeIterator, { let params = params.into_iter(); - let params_len = params.len(); let buf = client.with_buf(|buf| { - // Parse, anonymous portal - frontend::parse("", query.as_ref(), std::iter::empty(), buf).map_err(Error::encode)?; // Bind, pass params as text, retrieve as binary match frontend::bind( "", // empty string selects the unnamed portal - "", // empty string selects the unnamed prepared statement + statement.name(), // named prepared statement std::iter::empty(), // all parameters use the default format (text) params, |param, buf| match param { @@ -98,8 +95,6 @@ where Err(frontend::BindError::Serialization(e)) => Err(Error::encode(e)), }?; - // Describe portal to typecast results - frontend::describe(b'P', "", buf).map_err(Error::encode)?; // Execute frontend::execute("", 0, buf).map_err(Error::encode)?; // Sync @@ -108,43 +103,16 @@ where Ok(buf.split().freeze()) })?; - let mut responses = 
client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; - // now read the responses - - match responses.next().await? { - Message::ParseComplete => {} - _ => return Err(Error::unexpected_message()), - } - match responses.next().await? { - Message::BindComplete => {} - _ => return Err(Error::unexpected_message()), - } - let row_description = match responses.next().await? { - Message::RowDescription(body) => Some(body), - Message::NoData => None, - _ => return Err(Error::unexpected_message()), - }; - - // construct statement object - - let parameters = vec![Type::UNKNOWN; params_len]; - - let mut columns = vec![]; - if let Some(row_description) = row_description { - let mut it = row_description.fields(); - while let Some(field) = it.next().map_err(Error::parse)? { - // NB: for some types that function may send a query to the server. At least in - // raw text mode we don't need that info and can skip this. - let type_ = get_type(client, field.type_oid()).await?; - let column = Column::new(field.name().to_string(), type_, field); - columns.push(column); - } - } - - let statement = Statement::new_text(client, "".to_owned(), parameters, columns); - - Ok(RowStream::new(statement, responses)) + let responses = start(client, buf).await?; + Ok(RowStream { + statement, + responses, + command_tag: None, + status: None, + output_format: Format::Text, + _p: PhantomPinned, + }) } pub async fn query_portal( @@ -164,6 +132,8 @@ pub async fn query_portal( statement: portal.statement().clone(), responses, command_tag: None, + status: None, + output_format: Format::Binary, _p: PhantomPinned, }) } @@ -298,23 +268,13 @@ pin_project! { statement: Statement, responses: Responses, command_tag: Option, + output_format: Format, + status: Option, #[pin] _p: PhantomPinned, } } -impl RowStream { - /// Creates a new `RowStream`. 
- pub fn new(statement: Statement, responses: Responses) -> Self { - RowStream { - statement, - responses, - command_tag: None, - _p: PhantomPinned, - } - } -} - impl Stream for RowStream { type Item = Result; @@ -323,7 +283,11 @@ impl Stream for RowStream { loop { match ready!(this.responses.poll_next(cx)?) { Message::DataRow(body) => { - return Poll::Ready(Some(Ok(Row::new(this.statement.clone(), body)?))) + return Poll::Ready(Some(Ok(Row::new( + this.statement.clone(), + body, + *this.output_format, + )?))) } Message::EmptyQueryResponse | Message::PortalSuspended => {} Message::CommandComplete(body) => { @@ -331,7 +295,10 @@ impl Stream for RowStream { *this.command_tag = Some(tag.to_string()); } } - Message::ReadyForQuery(_) => return Poll::Ready(None), + Message::ReadyForQuery(status) => { + *this.status = Some(status.status()); + return Poll::Ready(None); + } _ => return Poll::Ready(Some(Err(Error::unexpected_message()))), } } @@ -345,4 +312,11 @@ impl RowStream { pub fn command_tag(&self) -> Option { self.command_tag.clone() } + + /// Returns if the connection is ready for querying, with the status of the connection. + /// + /// This might be available only after the stream has been exhausted. + pub fn ready_status(&self) -> Option { + self.status + } } diff --git a/tokio-postgres/src/row.rs b/tokio-postgres/src/row.rs index ce4efed7e..754b5f28c 100644 --- a/tokio-postgres/src/row.rs +++ b/tokio-postgres/src/row.rs @@ -98,6 +98,7 @@ where /// A row of data returned from the database by a query. 
pub struct Row { statement: Statement, + output_format: Format, body: DataRowBody, ranges: Vec>>, } @@ -111,12 +112,17 @@ impl fmt::Debug for Row { } impl Row { - pub(crate) fn new(statement: Statement, body: DataRowBody) -> Result { + pub(crate) fn new( + statement: Statement, + body: DataRowBody, + output_format: Format, + ) -> Result { let ranges = body.ranges().collect().map_err(Error::parse)?; Ok(Row { statement, body, ranges, + output_format, }) } @@ -193,7 +199,7 @@ impl Row { /// /// Useful when using query_raw_txt() which sets text transfer mode pub fn as_text(&self, idx: usize) -> Result, Error> { - if self.statement.output_format() == Format::Text { + if self.output_format == Format::Text { match self.col_buffer(idx) { Some(raw) => { FromSql::from_sql(&Type::TEXT, raw).map_err(|e| Error::from_sql(e, idx)) diff --git a/tokio-postgres/src/statement.rs b/tokio-postgres/src/statement.rs index 8743f00f0..246d36a57 100644 --- a/tokio-postgres/src/statement.rs +++ b/tokio-postgres/src/statement.rs @@ -6,7 +6,6 @@ use postgres_protocol::{ message::{backend::Field, frontend}, Oid, }; -use postgres_types::Format; use std::{ fmt, sync::{Arc, Weak}, @@ -17,7 +16,6 @@ struct StatementInner { name: String, params: Vec, columns: Vec, - output_format: Format, } impl Drop for StatementInner { @@ -51,22 +49,6 @@ impl Statement { name, params, columns, - output_format: Format::Binary, - })) - } - - pub(crate) fn new_text( - inner: &Arc, - name: String, - params: Vec, - columns: Vec, - ) -> Statement { - Statement(Arc::new(StatementInner { - client: Arc::downgrade(inner), - name, - params, - columns, - output_format: Format::Text, })) } @@ -83,11 +65,6 @@ impl Statement { pub fn columns(&self) -> &[Column] { &self.0.columns } - - /// Returns output format for the statement. - pub fn output_format(&self) -> Format { - self.0.output_format - } } /// Information about a column of a query. 
diff --git a/tokio-postgres/src/transaction.rs b/tokio-postgres/src/transaction.rs index 806196aa3..ca386974e 100644 --- a/tokio-postgres/src/transaction.rs +++ b/tokio-postgres/src/transaction.rs @@ -150,13 +150,14 @@ impl<'a> Transaction<'a> { } /// Like `Client::query_raw_txt`. - pub async fn query_raw_txt(&self, query: S, params: I) -> Result + pub async fn query_raw_txt(&self, statement: &T, params: I) -> Result where - S: AsRef + Sync + Send, + T: ?Sized + ToStatement, + S: AsRef, I: IntoIterator>, - I::IntoIter: ExactSizeIterator + Sync + Send, + I::IntoIter: ExactSizeIterator, { - self.client.query_raw_txt(query, params).await + self.client.query_raw_txt(statement, params).await } /// Like `Client::execute`. diff --git a/tokio-postgres/tests/test/main.rs b/tokio-postgres/tests/test/main.rs index 20345c7b4..0a8aff1d7 100644 --- a/tokio-postgres/tests/test/main.rs +++ b/tokio-postgres/tests/test/main.rs @@ -314,7 +314,7 @@ async fn query_raw_txt_nulls() { async fn limit_max_backend_message_size() { let client = connect("user=postgres max_backend_message_size=10000").await; let small: Vec = client - .query_raw_txt("SELECT REPEAT('a', 20)", []) + .query_raw_txt("SELECT REPEAT('a', 20)", [] as [Option<&str>; 0]) .await .unwrap() .try_collect() @@ -325,7 +325,7 @@ async fn limit_max_backend_message_size() { assert_eq!(small[0].as_text(0).unwrap().unwrap().len(), 20); let large: Result, Error> = client - .query_raw_txt("SELECT REPEAT('a', 2000000)", []) + .query_raw_txt("SELECT REPEAT('a', 2000000)", [] as [Option<&str>; 0]) .await .unwrap() .try_collect() @@ -339,7 +339,7 @@ async fn command_tag() { let client = connect("user=postgres").await; let row_stream = client - .query_raw_txt("select unnest('{1,2,3}'::int[]);", []) + .query_raw_txt("select unnest('{1,2,3}'::int[]);", [] as [Option<&str>; 0]) .await .unwrap(); @@ -353,12 +353,40 @@ async fn command_tag() { assert_eq!(row_stream.command_tag(), Some("SELECT 3".to_string())); } +#[tokio::test] +async fn 
ready_for_query() { + let client = connect("user=postgres").await; + + let row_stream = client + .query_raw_txt("START TRANSACTION", [] as [Option<&str>; 0]) + .await + .unwrap(); + + pin_mut!(row_stream); + while row_stream.next().await.is_none() {} + + assert_eq!(row_stream.ready_status(), Some(b'T')); + + let row_stream = client + .query_raw_txt("ROLLBACK", [] as [Option<&str>; 0]) + .await + .unwrap(); + + pin_mut!(row_stream); + while row_stream.next().await.is_none() {} + + assert_eq!(row_stream.ready_status(), Some(b'I')); +} + #[tokio::test] async fn column_extras() { let client = connect("user=postgres").await; let rows: Vec = client - .query_raw_txt("select relacl, relname from pg_class limit 1", []) + .query_raw_txt( + "select relacl, relname from pg_class limit 1", + [] as [Option<&str>; 0], + ) .await .unwrap() .try_collect() From 78f5c02db646ae07a1cf3e6272dc3203ded443e0 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Thu, 24 Aug 2023 15:27:04 +0100 Subject: [PATCH 12/22] simple query ready for query (#22) * add ready_status on simple queries * add correct socket2 features --- postgres/src/client.rs | 4 +++- postgres/src/transaction.rs | 3 +++ tokio-postgres/Cargo.toml | 1 + tokio-postgres/src/client.rs | 6 +++--- tokio-postgres/src/generic_client.rs | 6 ++++-- tokio-postgres/src/lib.rs | 27 +++++++++++++++++++++++++++ tokio-postgres/src/query.rs | 14 +++++++------- tokio-postgres/src/simple_query.rs | 25 +++++++++++++++++++++---- tokio-postgres/src/transaction.rs | 10 +++++----- tokio-postgres/tests/test/main.rs | 7 ++++--- 10 files changed, 78 insertions(+), 25 deletions(-) diff --git a/postgres/src/client.rs b/postgres/src/client.rs index c8e14cf81..da41e9450 100644 --- a/postgres/src/client.rs +++ b/postgres/src/client.rs @@ -419,7 +419,9 @@ impl Client { /// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass /// them to this method! 
pub fn batch_execute(&mut self, query: &str) -> Result<(), Error> { - self.connection.block_on(self.client.batch_execute(query)) + self.connection + .block_on(self.client.batch_execute(query)) + .map(|_| ()) } /// Begins a new database transaction. diff --git a/postgres/src/transaction.rs b/postgres/src/transaction.rs index 17c49c406..3d147cd01 100644 --- a/postgres/src/transaction.rs +++ b/postgres/src/transaction.rs @@ -35,6 +35,7 @@ impl<'a> Transaction<'a> { pub fn commit(mut self) -> Result<(), Error> { self.connection .block_on(self.transaction.take().unwrap().commit()) + .map(|_| ()) } /// Rolls the transaction back, discarding all changes made within it. @@ -43,6 +44,7 @@ impl<'a> Transaction<'a> { pub fn rollback(mut self) -> Result<(), Error> { self.connection .block_on(self.transaction.take().unwrap().rollback()) + .map(|_| ()) } /// Like `Client::prepare`. @@ -193,6 +195,7 @@ impl<'a> Transaction<'a> { pub fn batch_execute(&mut self, query: &str) -> Result<(), Error> { self.connection .block_on(self.transaction.as_ref().unwrap().batch_execute(query)) + .map(|_| ()) } /// Like `Client::cancel_token`. 
diff --git a/tokio-postgres/Cargo.toml b/tokio-postgres/Cargo.toml index bb58eb2d9..35450572f 100644 --- a/tokio-postgres/Cargo.toml +++ b/tokio-postgres/Cargo.toml @@ -58,6 +58,7 @@ postgres-protocol = { version = "0.6.6", path = "../postgres-protocol" } postgres-types = { version = "0.2.5", path = "../postgres-types" } tokio = { version = "1.27", features = ["io-util"] } tokio-util = { version = "0.7", features = ["codec"] } +socket2 = { version = "0.5", features = ["all"] } rand = "0.8.5" whoami = "1.4.1" diff --git a/tokio-postgres/src/client.rs b/tokio-postgres/src/client.rs index 9a3786ec7..e763ee85f 100644 --- a/tokio-postgres/src/client.rs +++ b/tokio-postgres/src/client.rs @@ -15,8 +15,8 @@ use crate::types::{Oid, ToSql, Type}; use crate::Socket; use crate::{ copy_both, copy_in, copy_out, prepare, query, simple_query, slice_iter, CancelToken, - CopyInSink, Error, Row, SimpleQueryMessage, Statement, ToStatement, Transaction, - TransactionBuilder, + CopyInSink, Error, ReadyForQueryStatus, Row, SimpleQueryMessage, Statement, ToStatement, + Transaction, TransactionBuilder, }; use bytes::{Buf, BytesMut}; use fallible_iterator::FallibleIterator; @@ -506,7 +506,7 @@ impl Client { /// Prepared statements should be use for any query which contains user-specified data, as they provided the /// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass /// them to this method! 
- pub async fn batch_execute(&self, query: &str) -> Result<(), Error> { + pub async fn batch_execute(&self, query: &str) -> Result { simple_query::batch_execute(self.inner(), query).await } diff --git a/tokio-postgres/src/generic_client.rs b/tokio-postgres/src/generic_client.rs index a4ee4808b..a22318c3f 100644 --- a/tokio-postgres/src/generic_client.rs +++ b/tokio-postgres/src/generic_client.rs @@ -175,7 +175,8 @@ impl GenericClient for Client { } async fn batch_execute(&self, query: &str) -> Result<(), Error> { - self.batch_execute(query).await + self.batch_execute(query).await?; + Ok(()) } fn client(&self) -> &Client { @@ -272,7 +273,8 @@ impl GenericClient for Transaction<'_> { } async fn batch_execute(&self, query: &str) -> Result<(), Error> { - self.batch_execute(query).await + self.batch_execute(query).await?; + Ok(()) } fn client(&self) -> &Client { diff --git a/tokio-postgres/src/lib.rs b/tokio-postgres/src/lib.rs index 19c4f80a1..94993f361 100644 --- a/tokio-postgres/src/lib.rs +++ b/tokio-postgres/src/lib.rs @@ -118,6 +118,8 @@ //! | `with-time-0_3` | Enable support for the 0.3 version of the `time` crate. 
| [time](https://crates.io/crates/time/0.3.0) 0.3 | no | #![warn(rust_2018_idioms, clippy::all, missing_docs)] +use postgres_protocol::message::backend::ReadyForQueryBody; + pub use crate::cancel_token::CancelToken; pub use crate::client::Client; pub use crate::config::Config; @@ -142,6 +144,31 @@ pub use crate::transaction::Transaction; pub use crate::transaction_builder::{IsolationLevel, TransactionBuilder}; use crate::types::ToSql; +/// After executing a query, the connection will be in one of these states +#[derive(Clone, Copy, Debug, PartialEq)] +#[repr(u8)] +pub enum ReadyForQueryStatus { + /// Connection state is unknown + Unknown, + /// Connection is idle (no transactions) + Idle = b'I', + /// Connection is in a transaction block + Transaction = b'T', + /// Connection is in a failed transaction block + FailedTransaction = b'E', +} + +impl From for ReadyForQueryStatus { + fn from(value: ReadyForQueryBody) -> Self { + match value.status() { + b'I' => Self::Idle, + b'T' => Self::Transaction, + b'E' => Self::FailedTransaction, + _ => Self::Unknown, + } + } +} + pub mod binary_copy; mod bind; #[cfg(feature = "runtime")] diff --git a/tokio-postgres/src/query.rs b/tokio-postgres/src/query.rs index 7bac62fce..9fcb530eb 100644 --- a/tokio-postgres/src/query.rs +++ b/tokio-postgres/src/query.rs @@ -2,7 +2,7 @@ use crate::client::{InnerClient, Responses}; use crate::codec::FrontendMessage; use crate::connection::RequestMessages; use crate::types::{BorrowToSql, IsNull}; -use crate::{Error, Portal, Row, Statement}; +use crate::{Error, Portal, ReadyForQueryStatus, Row, Statement}; use bytes::{BufMut, Bytes, BytesMut}; use futures_util::{ready, Stream}; use log::{debug, log_enabled, Level}; @@ -55,7 +55,7 @@ where statement, responses, command_tag: None, - status: None, + status: ReadyForQueryStatus::Unknown, output_format: Format::Binary, _p: PhantomPinned, }) @@ -109,7 +109,7 @@ where statement, responses, command_tag: None, - status: None, + status: 
ReadyForQueryStatus::Unknown, output_format: Format::Text, _p: PhantomPinned, }) @@ -132,7 +132,7 @@ pub async fn query_portal( statement: portal.statement().clone(), responses, command_tag: None, - status: None, + status: ReadyForQueryStatus::Unknown, output_format: Format::Binary, _p: PhantomPinned, }) @@ -269,7 +269,7 @@ pin_project! { responses: Responses, command_tag: Option, output_format: Format, - status: Option, + status: ReadyForQueryStatus, #[pin] _p: PhantomPinned, } @@ -296,7 +296,7 @@ impl Stream for RowStream { } } Message::ReadyForQuery(status) => { - *this.status = Some(status.status()); + *this.status = status.into(); return Poll::Ready(None); } _ => return Poll::Ready(Some(Err(Error::unexpected_message()))), @@ -316,7 +316,7 @@ impl RowStream { /// Returns if the connection is ready for querying, with the status of the connection. /// /// This might be available only after the stream has been exhausted. - pub fn ready_status(&self) -> Option { + pub fn ready_status(&self) -> ReadyForQueryStatus { self.status } } diff --git a/tokio-postgres/src/simple_query.rs b/tokio-postgres/src/simple_query.rs index a97ee126c..0bd186cd9 100644 --- a/tokio-postgres/src/simple_query.rs +++ b/tokio-postgres/src/simple_query.rs @@ -2,7 +2,7 @@ use crate::client::{InnerClient, Responses}; use crate::codec::FrontendMessage; use crate::connection::RequestMessages; use crate::query::extract_row_affected; -use crate::{Error, SimpleQueryMessage, SimpleQueryRow}; +use crate::{Error, ReadyForQueryStatus, SimpleQueryMessage, SimpleQueryRow}; use bytes::Bytes; use fallible_iterator::FallibleIterator; use futures_util::{ready, Stream}; @@ -41,11 +41,15 @@ pub async fn simple_query(client: &InnerClient, query: &str) -> Result Result<(), Error> { +pub async fn batch_execute( + client: &InnerClient, + query: &str, +) -> Result { debug!("executing statement batch: {}", query); let buf = encode(client, query)?; @@ -53,7 +57,7 @@ pub async fn batch_execute(client: &InnerClient, 
query: &str) -> Result<(), Erro loop { match responses.next().await? { - Message::ReadyForQuery(_) => return Ok(()), + Message::ReadyForQuery(status) => return Ok(status.into()), Message::CommandComplete(_) | Message::EmptyQueryResponse | Message::RowDescription(_) @@ -75,11 +79,21 @@ pin_project! { pub struct SimpleQueryStream { responses: Responses, columns: Option>, + status: ReadyForQueryStatus, #[pin] _p: PhantomPinned, } } +impl SimpleQueryStream { + /// Returns if the connection is ready for querying, with the status of the connection. + /// + /// This might be available only after the stream has been exhausted. + pub fn ready_status(&self) -> ReadyForQueryStatus { + self.status + } +} + impl Stream for SimpleQueryStream { type Item = Result; @@ -111,7 +125,10 @@ impl Stream for SimpleQueryStream { }; return Poll::Ready(Some(Ok(SimpleQueryMessage::Row(row)))); } - Message::ReadyForQuery(_) => return Poll::Ready(None), + Message::ReadyForQuery(s) => { + *this.status = s.into(); + return Poll::Ready(None); + } _ => return Poll::Ready(Some(Err(Error::unexpected_message()))), } } diff --git a/tokio-postgres/src/transaction.rs b/tokio-postgres/src/transaction.rs index ca386974e..c6a42dd7d 100644 --- a/tokio-postgres/src/transaction.rs +++ b/tokio-postgres/src/transaction.rs @@ -9,8 +9,8 @@ use crate::types::{BorrowToSql, ToSql, Type}; #[cfg(feature = "runtime")] use crate::Socket; use crate::{ - bind, query, slice_iter, CancelToken, Client, CopyInSink, Error, Portal, Row, - SimpleQueryMessage, Statement, ToStatement, + bind, query, slice_iter, CancelToken, Client, CopyInSink, Error, Portal, ReadyForQueryStatus, + Row, SimpleQueryMessage, Statement, ToStatement, }; use bytes::Buf; use futures_util::TryStreamExt; @@ -65,7 +65,7 @@ impl<'a> Transaction<'a> { } /// Consumes the transaction, committing all changes made within it. 
- pub async fn commit(mut self) -> Result<(), Error> { + pub async fn commit(mut self) -> Result { self.done = true; let query = if let Some(sp) = self.savepoint.as_ref() { format!("RELEASE {}", sp.name) @@ -78,7 +78,7 @@ impl<'a> Transaction<'a> { /// Rolls the transaction back, discarding all changes made within it. /// /// This is equivalent to `Transaction`'s `Drop` implementation, but provides any error encountered to the caller. - pub async fn rollback(mut self) -> Result<(), Error> { + pub async fn rollback(mut self) -> Result { self.done = true; let query = if let Some(sp) = self.savepoint.as_ref() { format!("ROLLBACK TO {}", sp.name) @@ -261,7 +261,7 @@ impl<'a> Transaction<'a> { } /// Like `Client::batch_execute`. - pub async fn batch_execute(&self, query: &str) -> Result<(), Error> { + pub async fn batch_execute(&self, query: &str) -> Result { self.client.batch_execute(query).await } diff --git a/tokio-postgres/tests/test/main.rs b/tokio-postgres/tests/test/main.rs index 0a8aff1d7..e21e2379c 100644 --- a/tokio-postgres/tests/test/main.rs +++ b/tokio-postgres/tests/test/main.rs @@ -16,7 +16,8 @@ use tokio_postgres::error::SqlState; use tokio_postgres::tls::{NoTls, NoTlsStream}; use tokio_postgres::types::{Kind, Type}; use tokio_postgres::{ - AsyncMessage, Client, Config, Connection, Error, IsolationLevel, SimpleQueryMessage, + AsyncMessage, Client, Config, Connection, Error, IsolationLevel, ReadyForQueryStatus, + SimpleQueryMessage, }; mod binary_copy; @@ -365,7 +366,7 @@ async fn ready_for_query() { pin_mut!(row_stream); while row_stream.next().await.is_none() {} - assert_eq!(row_stream.ready_status(), Some(b'T')); + assert_eq!(row_stream.ready_status(), ReadyForQueryStatus::Transaction); let row_stream = client .query_raw_txt("ROLLBACK", [] as [Option<&str>; 0]) @@ -375,7 +376,7 @@ async fn ready_for_query() { pin_mut!(row_stream); while row_stream.next().await.is_none() {} - assert_eq!(row_stream.ready_status(), Some(b'I')); + 
assert_eq!(row_stream.ready_status(), ReadyForQueryStatus::Idle); } #[tokio::test] From a23d2d5011fd3f7e9f67c5e1a07fde365ee13720 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Thu, 19 Oct 2023 10:18:35 +0100 Subject: [PATCH 13/22] fix panic in try_get --- tokio-postgres/src/row.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tokio-postgres/src/row.rs b/tokio-postgres/src/row.rs index 754b5f28c..bbedf0bda 100644 --- a/tokio-postgres/src/row.rs +++ b/tokio-postgres/src/row.rs @@ -191,7 +191,7 @@ impl Row { /// Get the raw bytes for the column at the given index. fn col_buffer(&self, idx: usize) -> Option<&[u8]> { - let range = self.ranges[idx].to_owned()?; + let range = self.ranges.get(idx)?.to_owned()?; Some(&self.body.buffer()[range]) } From 9e05e2ad0e28f4686380f1a36fc14b9bb2e69afa Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Thu, 19 Oct 2023 10:29:53 +0100 Subject: [PATCH 14/22] lints --- .github/workflows/ci.yml | 9 +++++---- postgres-types/Cargo.toml | 6 +----- postgres-types/src/eui48_04.rs | 27 --------------------------- postgres-types/src/lib.rs | 2 -- postgres/Cargo.toml | 1 - postgres/src/lib.rs | 1 - tokio-postgres/Cargo.toml | 3 +-- tokio-postgres/src/lib.rs | 1 - 8 files changed, 7 insertions(+), 43 deletions(-) delete mode 100644 postgres-types/src/eui48_04.rs diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 937366a08..80ecc8cce 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -19,7 +19,7 @@ jobs: name: rustfmt runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: sfackler/actions/rustup@master - uses: sfackler/actions/rustfmt@master @@ -27,7 +27,7 @@ jobs: name: clippy runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: sfackler/actions/rustup@master - run: echo "version=$(rustc --version)" >> $GITHUB_OUTPUT id: rust-version @@ -47,13 +47,14 @@ jobs: with: path: target key: 
clippy-target-${{ runner.os }}-${{ steps.rust-version.outputs.version }}-${{ hashFiles('Cargo.lock') }}y - - run: cargo clippy --all --all-targets + - run: cargo clippy --workspace --all-targets check-wasm32: name: check-wasm32 runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 + - run: docker compose up -d - uses: sfackler/actions/rustup@master with: version: 1.67.0 diff --git a/postgres-types/Cargo.toml b/postgres-types/Cargo.toml index cfd083637..ac62976a3 100644 --- a/postgres-types/Cargo.toml +++ b/postgres-types/Cargo.toml @@ -16,7 +16,6 @@ array-impls = ["array-init"] with-bit-vec-0_6 = ["bit-vec-06"] with-cidr-0_2 = ["cidr-02"] with-chrono-0_4 = ["chrono-04"] -with-eui48-0_4 = ["eui48-04"] with-eui48-1 = ["eui48-1"] with-geo-types-0_6 = ["geo-types-06"] with-geo-types-0_7 = ["geo-types-0_7"] @@ -39,10 +38,7 @@ chrono-04 = { version = "0.4.16", package = "chrono", default-features = false, "clock", ], optional = true } cidr-02 = { version = "0.2", package = "cidr", optional = true } -# eui48-04 will stop compiling and support will be removed -# See https://github.com/sfackler/rust-postgres/issues/1073 -eui48-04 = { version = "0.4", package = "eui48", optional = true } -eui48-1 = { version = "1.0", package = "eui48", optional = true, default-features = false } +eui48-1 = { version = "1.0", package = "eui48", optional = true } geo-types-06 = { version = "0.6", package = "geo-types", optional = true } geo-types-0_7 = { version = "0.7", package = "geo-types", optional = true } serde-1 = { version = "1.0", package = "serde", optional = true } diff --git a/postgres-types/src/eui48_04.rs b/postgres-types/src/eui48_04.rs deleted file mode 100644 index 45df89a84..000000000 --- a/postgres-types/src/eui48_04.rs +++ /dev/null @@ -1,27 +0,0 @@ -use bytes::BytesMut; -use eui48_04::MacAddress; -use postgres_protocol::types; -use std::error::Error; - -use crate::{FromSql, IsNull, ToSql, Type}; - -impl<'a> FromSql<'a> for MacAddress { - 
fn from_sql(_: &Type, raw: &[u8]) -> Result> { - let bytes = types::macaddr_from_sql(raw)?; - Ok(MacAddress::new(bytes)) - } - - accepts!(MACADDR); -} - -impl ToSql for MacAddress { - fn to_sql(&self, _: &Type, w: &mut BytesMut) -> Result> { - let mut bytes = [0; 6]; - bytes.copy_from_slice(self.as_bytes()); - types::macaddr_to_sql(bytes, w); - Ok(IsNull::No) - } - - accepts!(MACADDR); - to_sql_checked!(); -} diff --git a/postgres-types/src/lib.rs b/postgres-types/src/lib.rs index 4f3ae1f11..4ba711100 100644 --- a/postgres-types/src/lib.rs +++ b/postgres-types/src/lib.rs @@ -268,8 +268,6 @@ mod bit_vec_06; mod chrono_04; #[cfg(feature = "with-cidr-0_2")] mod cidr_02; -#[cfg(feature = "with-eui48-0_4")] -mod eui48_04; #[cfg(feature = "with-eui48-1")] mod eui48_1; #[cfg(feature = "with-geo-types-0_6")] diff --git a/postgres/Cargo.toml b/postgres/Cargo.toml index 18406da9f..0e1b4be3e 100644 --- a/postgres/Cargo.toml +++ b/postgres/Cargo.toml @@ -24,7 +24,6 @@ circle-ci = { repository = "sfackler/rust-postgres" } array-impls = ["tokio-postgres/array-impls"] with-bit-vec-0_6 = ["tokio-postgres/with-bit-vec-0_6"] with-chrono-0_4 = ["tokio-postgres/with-chrono-0_4"] -with-eui48-0_4 = ["tokio-postgres/with-eui48-0_4"] with-eui48-1 = ["tokio-postgres/with-eui48-1"] with-geo-types-0_6 = ["tokio-postgres/with-geo-types-0_6"] with-geo-types-0_7 = ["tokio-postgres/with-geo-types-0_7"] diff --git a/postgres/src/lib.rs b/postgres/src/lib.rs index ddf1609ad..7bbe50465 100644 --- a/postgres/src/lib.rs +++ b/postgres/src/lib.rs @@ -55,7 +55,6 @@ //! | ------- | ----------- | ------------------ | ------- | //! | `with-bit-vec-0_6` | Enable support for the `bit-vec` crate. | [bit-vec](https://crates.io/crates/bit-vec) 0.6 | no | //! | `with-chrono-0_4` | Enable support for the `chrono` crate. | [chrono](https://crates.io/crates/chrono) 0.4 | no | -//! | `with-eui48-0_4` | Enable support for the 0.4 version of the `eui48` crate. This is deprecated and will be removed. 
| [eui48](https://crates.io/crates/eui48) 0.4 | no | //! | `with-eui48-1` | Enable support for the 1.0 version of the `eui48` crate. | [eui48](https://crates.io/crates/eui48) 1.0 | no | //! | `with-geo-types-0_6` | Enable support for the 0.6 version of the `geo-types` crate. | [geo-types](https://crates.io/crates/geo-types/0.6.0) 0.6 | no | //! | `with-geo-types-0_7` | Enable support for the 0.7 version of the `geo-types` crate. | [geo-types](https://crates.io/crates/geo-types/0.7.0) 0.7 | no | diff --git a/tokio-postgres/Cargo.toml b/tokio-postgres/Cargo.toml index 35450572f..2f029835a 100644 --- a/tokio-postgres/Cargo.toml +++ b/tokio-postgres/Cargo.toml @@ -30,7 +30,6 @@ runtime = ["tokio/net", "tokio/time"] array-impls = ["postgres-types/array-impls"] with-bit-vec-0_6 = ["postgres-types/with-bit-vec-0_6"] with-chrono-0_4 = ["postgres-types/with-chrono-0_4"] -with-eui48-0_4 = ["postgres-types/with-eui48-0_4"] with-eui48-1 = ["postgres-types/with-eui48-1"] with-geo-types-0_6 = ["postgres-types/with-geo-types-0_6"] with-geo-types-0_7 = ["postgres-types/with-geo-types-0_7"] @@ -79,7 +78,7 @@ tokio = { version = "1.0", features = [ bit-vec-06 = { version = "0.6", package = "bit-vec" } chrono-04 = { version = "0.4", package = "chrono", default-features = false } -eui48-1 = { version = "1.0", package = "eui48", default-features = false } +eui48-1 = { version = "1.0", package = "eui48" } geo-types-06 = { version = "0.6", package = "geo-types" } geo-types-07 = { version = "0.7", package = "geo-types" } serde-1 = { version = "1.0", package = "serde" } diff --git a/tokio-postgres/src/lib.rs b/tokio-postgres/src/lib.rs index 94993f361..c16842840 100644 --- a/tokio-postgres/src/lib.rs +++ b/tokio-postgres/src/lib.rs @@ -107,7 +107,6 @@ //! | `array-impls` | Enables `ToSql` and `FromSql` trait impls for arrays | - | no | //! | `with-bit-vec-0_6` | Enable support for the `bit-vec` crate. | [bit-vec](https://crates.io/crates/bit-vec) 0.6 | no | //! 
| `with-chrono-0_4` | Enable support for the `chrono` crate. | [chrono](https://crates.io/crates/chrono) 0.4 | no | -//! | `with-eui48-0_4` | Enable support for the 0.4 version of the `eui48` crate. This is deprecated and will be removed. | [eui48](https://crates.io/crates/eui48) 0.4 | no | //! | `with-eui48-1` | Enable support for the 1.0 version of the `eui48` crate. | [eui48](https://crates.io/crates/eui48) 1.0 | no | //! | `with-geo-types-0_6` | Enable support for the 0.6 version of the `geo-types` crate. | [geo-types](https://crates.io/crates/geo-types/0.6.0) 0.6 | no | //! | `with-geo-types-0_7` | Enable support for the 0.7 version of the `geo-types` crate. | [geo-types](https://crates.io/crates/geo-types/0.7.0) 0.7 | no | From 583d153c90782758644849c136dcca4422907a5f Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Thu, 19 Oct 2023 10:43:52 +0100 Subject: [PATCH 15/22] deprecated --- tokio-postgres/tests/test/types/chrono_04.rs | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/tokio-postgres/tests/test/types/chrono_04.rs b/tokio-postgres/tests/test/types/chrono_04.rs index b010055ba..21dcd48c6 100644 --- a/tokio-postgres/tests/test/types/chrono_04.rs +++ b/tokio-postgres/tests/test/types/chrono_04.rs @@ -1,4 +1,4 @@ -use chrono_04::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, TimeZone, Utc}; +use chrono_04::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, Utc}; use std::fmt; use tokio_postgres::types::{Date, FromSqlOwned, Timestamp}; use tokio_postgres::Client; @@ -53,8 +53,14 @@ async fn test_with_special_naive_date_time_params() { async fn test_date_time_params() { fn make_check(time: &str) -> (Option>, &str) { ( +<<<<<<< HEAD Some(Utc.from_utc_datetime( &NaiveDateTime::parse_from_str(time, "'%Y-%m-%d %H:%M:%S.%f'").unwrap(), +======= + Some(DateTime::from_naive_utc_and_offset( + NaiveDateTime::parse_from_str(time, "'%Y-%m-%d %H:%M:%S.%f'").unwrap(), + Utc, +>>>>>>> 7434d938 (deprecated) )), time, ) @@ -75,8 +81,14 @@ 
async fn test_date_time_params() { async fn test_with_special_date_time_params() { fn make_check(time: &str) -> (Timestamp>, &str) { ( +<<<<<<< HEAD Timestamp::Value(Utc.from_utc_datetime( &NaiveDateTime::parse_from_str(time, "'%Y-%m-%d %H:%M:%S.%f'").unwrap(), +======= + Timestamp::Value(DateTime::from_naive_utc_and_offset( + NaiveDateTime::parse_from_str(time, "'%Y-%m-%d %H:%M:%S.%f'").unwrap(), + Utc, +>>>>>>> 7434d938 (deprecated) )), time, ) From 65d1df70c88f340f8871774a7b86003606d0ac93 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Tue, 31 Oct 2023 15:04:16 +0000 Subject: [PATCH 16/22] make raw_txt not prepare statements --- tokio-postgres/src/client.rs | 6 +-- tokio-postgres/src/generic_client.rs | 13 ++---- tokio-postgres/src/query.rs | 60 +++++++++++++++++++++++++--- tokio-postgres/src/statement.rs | 9 +++++ tokio-postgres/src/transaction.rs | 3 +- 5 files changed, 69 insertions(+), 22 deletions(-) diff --git a/tokio-postgres/src/client.rs b/tokio-postgres/src/client.rs index e763ee85f..606865b79 100644 --- a/tokio-postgres/src/client.rs +++ b/tokio-postgres/src/client.rs @@ -372,18 +372,16 @@ impl Client { /// Pass text directly to the Postgres backend to allow it to sort out typing itself and /// to save a roundtrip - pub async fn query_raw_txt<'a, T, S, I>( + pub async fn query_raw_txt( &self, - statement: &T, + statement: &str, params: I, ) -> Result where - T: ?Sized + ToStatement, S: AsRef, I: IntoIterator>, I::IntoIter: ExactSizeIterator, { - let statement = statement.__convert().into_statement(self).await?; query::query_txt(&self.inner, statement, params).await } diff --git a/tokio-postgres/src/generic_client.rs b/tokio-postgres/src/generic_client.rs index a22318c3f..f68c3aa71 100644 --- a/tokio-postgres/src/generic_client.rs +++ b/tokio-postgres/src/generic_client.rs @@ -57,13 +57,8 @@ pub trait GenericClient: private::Sealed { I::IntoIter: ExactSizeIterator; /// Like `Client::query_raw_txt`. 
- async fn query_raw_txt<'a, T, S, I>( - &self, - statement: &T, - params: I, - ) -> Result + async fn query_raw_txt(&self, statement: &str, params: I) -> Result where - T: ?Sized + ToStatement + Sync + Send, S: AsRef + Sync + Send, I: IntoIterator> + Sync + Send, I::IntoIter: ExactSizeIterator + Sync + Send; @@ -148,9 +143,8 @@ impl GenericClient for Client { self.query_raw(statement, params).await } - async fn query_raw_txt<'a, T, S, I>(&self, statement: &T, params: I) -> Result + async fn query_raw_txt(&self, statement: &str, params: I) -> Result where - T: ?Sized + ToStatement + Sync + Send, S: AsRef + Sync + Send, I: IntoIterator> + Sync + Send, I::IntoIter: ExactSizeIterator + Sync + Send, @@ -245,9 +239,8 @@ impl GenericClient for Transaction<'_> { self.query_raw(statement, params).await } - async fn query_raw_txt<'a, T, S, I>(&self, statement: &T, params: I) -> Result + async fn query_raw_txt(&self, statement: &str, params: I) -> Result where - T: ?Sized + ToStatement + Sync + Send, S: AsRef + Sync + Send, I: IntoIterator> + Sync + Send, I::IntoIter: ExactSizeIterator + Sync + Send, diff --git a/tokio-postgres/src/query.rs b/tokio-postgres/src/query.rs index 9fcb530eb..18b3f326a 100644 --- a/tokio-postgres/src/query.rs +++ b/tokio-postgres/src/query.rs @@ -2,14 +2,15 @@ use crate::client::{InnerClient, Responses}; use crate::codec::FrontendMessage; use crate::connection::RequestMessages; use crate::types::{BorrowToSql, IsNull}; -use crate::{Error, Portal, ReadyForQueryStatus, Row, Statement}; +use crate::{Column, Error, Portal, ReadyForQueryStatus, Row, Statement}; use bytes::{BufMut, Bytes, BytesMut}; +use fallible_iterator::FallibleIterator; use futures_util::{ready, Stream}; use log::{debug, log_enabled, Level}; use pin_project_lite::pin_project; use postgres_protocol::message::backend::{CommandCompleteBody, Message}; use postgres_protocol::message::frontend; -use postgres_types::Format; +use postgres_types::{Format, Type}; use std::fmt; use 
std::marker::PhantomPinned; use std::pin::Pin; @@ -63,7 +64,7 @@ where pub async fn query_txt( client: &Arc, - statement: Statement, + query: &str, params: I, ) -> Result where @@ -74,10 +75,18 @@ where let params = params.into_iter(); let buf = client.with_buf(|buf| { + frontend::parse( + "", // unnamed prepared statement + query, // query to parse + std::iter::empty(), // give no type info + buf, + ) + .map_err(Error::encode)?; + frontend::describe(b'S', "", buf).map_err(Error::encode)?; // Bind, pass params as text, retrieve as binary match frontend::bind( "", // empty string selects the unnamed portal - statement.name(), // named prepared statement + "", // unnamed prepared statement std::iter::empty(), // all parameters use the default format (text) params, |param, buf| match param { @@ -104,9 +113,48 @@ where })?; // now read the responses - let responses = start(client, buf).await?; + let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; + + match responses.next().await? { + Message::ParseComplete => {} + _ => return Err(Error::unexpected_message()), + } + + let parameter_description = match responses.next().await? { + Message::ParameterDescription(body) => body, + _ => return Err(Error::unexpected_message()), + }; + + let row_description = match responses.next().await? { + Message::RowDescription(body) => Some(body), + Message::NoData => None, + _ => return Err(Error::unexpected_message()), + }; + + match responses.next().await? { + Message::BindComplete => {} + _ => return Err(Error::unexpected_message()), + } + + let mut parameters = vec![]; + let mut it = parameter_description.parameters(); + while let Some(oid) = it.next().map_err(Error::parse)? { + let type_ = Type::from_oid(oid).unwrap_or(Type::UNKNOWN); + parameters.push(type_); + } + + let mut columns = vec![]; + if let Some(row_description) = row_description { + let mut it = row_description.fields(); + while let Some(field) = it.next().map_err(Error::parse)? 
{ + let type_ = Type::from_oid(field.type_oid()).unwrap_or(Type::UNKNOWN); + let column = Column::new(field.name().to_string(), type_, field); + columns.push(column); + } + } + Ok(RowStream { - statement, + statement: Statement::new_anonymous(parameters, columns), responses, command_tag: None, status: ReadyForQueryStatus::Unknown, diff --git a/tokio-postgres/src/statement.rs b/tokio-postgres/src/statement.rs index 246d36a57..b2f4c87cb 100644 --- a/tokio-postgres/src/statement.rs +++ b/tokio-postgres/src/statement.rs @@ -52,6 +52,15 @@ impl Statement { })) } + pub(crate) fn new_anonymous(params: Vec, columns: Vec) -> Statement { + Statement(Arc::new(StatementInner { + client: Weak::new(), + name: String::new(), + params, + columns, + })) + } + pub(crate) fn name(&self) -> &str { &self.0.name } diff --git a/tokio-postgres/src/transaction.rs b/tokio-postgres/src/transaction.rs index c6a42dd7d..4e071172a 100644 --- a/tokio-postgres/src/transaction.rs +++ b/tokio-postgres/src/transaction.rs @@ -150,9 +150,8 @@ impl<'a> Transaction<'a> { } /// Like `Client::query_raw_txt`. 
- pub async fn query_raw_txt(&self, statement: &T, params: I) -> Result + pub async fn query_raw_txt(&self, statement: &str, params: I) -> Result where - T: ?Sized + ToStatement, S: AsRef, I: IntoIterator>, I::IntoIter: ExactSizeIterator, From c849e8a373ece0f4b2212563b098e3c98fdaa9dc Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Tue, 31 Oct 2023 15:29:22 +0000 Subject: [PATCH 17/22] fmt --- tokio-postgres/src/client.rs | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/tokio-postgres/src/client.rs b/tokio-postgres/src/client.rs index 606865b79..f251920c7 100644 --- a/tokio-postgres/src/client.rs +++ b/tokio-postgres/src/client.rs @@ -372,11 +372,7 @@ impl Client { /// Pass text directly to the Postgres backend to allow it to sort out typing itself and /// to save a roundtrip - pub async fn query_raw_txt( - &self, - statement: &str, - params: I, - ) -> Result + pub async fn query_raw_txt(&self, statement: &str, params: I) -> Result where S: AsRef, I: IntoIterator>, From 0582732a241eebf4a5a3ef908a9349ef801e5ecd Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Tue, 31 Oct 2023 16:23:01 +0000 Subject: [PATCH 18/22] offer get_type api --- tokio-postgres/src/client.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tokio-postgres/src/client.rs b/tokio-postgres/src/client.rs index f251920c7..5242e98c2 100644 --- a/tokio-postgres/src/client.rs +++ b/tokio-postgres/src/client.rs @@ -602,6 +602,11 @@ impl Client { self.inner().clear_type_cache(); } + /// Query for type information + pub async fn get_type(&self, oid: Oid) -> Result { + crate::prepare::get_type(&self.inner, oid).await + } + /// Determines if the connection to the server has already closed. /// /// In that case, all future queries will fail. 
From 5316117ed484e59553ad8af50b26b1abbe0acaeb Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Tue, 31 Oct 2023 16:28:30 +0000 Subject: [PATCH 19/22] add columns to rowstream --- tokio-postgres/src/query.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tokio-postgres/src/query.rs b/tokio-postgres/src/query.rs index 18b3f326a..0677ba648 100644 --- a/tokio-postgres/src/query.rs +++ b/tokio-postgres/src/query.rs @@ -354,6 +354,11 @@ impl Stream for RowStream { } impl RowStream { + /// Returns information about the columns of data in the row. + pub fn columns(&self) -> &[Column] { + self.statement.columns() + } + /// Returns the command tag of this query. /// /// This is only available after the stream has been exhausted. From 693cf06d9d68b63dd1b136871e0d0ccbe5995103 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Tue, 31 Oct 2023 16:33:51 +0000 Subject: [PATCH 20/22] get type generic client --- tokio-postgres/src/generic_client.rs | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tokio-postgres/src/generic_client.rs b/tokio-postgres/src/generic_client.rs index f68c3aa71..386749711 100644 --- a/tokio-postgres/src/generic_client.rs +++ b/tokio-postgres/src/generic_client.rs @@ -2,6 +2,7 @@ use crate::query::RowStream; use crate::types::{BorrowToSql, ToSql, Type}; use crate::{Client, Error, Row, Statement, ToStatement, Transaction}; use async_trait::async_trait; +use postgres_protocol::Oid; mod private { pub trait Sealed {} @@ -79,6 +80,9 @@ pub trait GenericClient: private::Sealed { /// Like `Client::batch_execute`. async fn batch_execute(&self, query: &str) -> Result<(), Error>; + /// Query for type information + async fn get_type(&self, oid: Oid) -> Result; + /// Returns a reference to the underlying `Client`. 
fn client(&self) -> &Client; } @@ -173,6 +177,11 @@ impl GenericClient for Client { Ok(()) } + /// Query for type information + async fn get_type(&self, oid: Oid) -> Result { + self.get_type(oid).await + } + fn client(&self) -> &Client { self } @@ -270,6 +279,11 @@ impl GenericClient for Transaction<'_> { Ok(()) } + /// Query for type information + async fn get_type(&self, oid: Oid) -> Result { + self.client().get_type(oid).await + } + fn client(&self) -> &Client { self.client() } From 16489a955eeab9cf0fba470078bb7cc64ecd6578 Mon Sep 17 00:00:00 2001 From: Christian Schwarz Date: Fri, 3 Nov 2023 11:28:13 +0100 Subject: [PATCH 21/22] make CopyBothDuplex struct `pub` (#25) This is useful / needed to build a Rust client for the Pageserver's GetPage@LSN API, which uses CopyBoth mode. --- tokio-postgres/src/lib.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/tokio-postgres/src/lib.rs b/tokio-postgres/src/lib.rs index c16842840..120d9718c 100644 --- a/tokio-postgres/src/lib.rs +++ b/tokio-postgres/src/lib.rs @@ -123,6 +123,7 @@ pub use crate::cancel_token::CancelToken; pub use crate::client::Client; pub use crate::config::Config; pub use crate::connection::Connection; +pub use crate::copy_both::CopyBothDuplex; pub use crate::copy_in::CopyInSink; pub use crate::copy_out::CopyOutStream; use crate::error::DbError; From fc7ae2393b776e4897a442189fc86e805d4589a8 Mon Sep 17 00:00:00 2001 From: khanova <32508607+khanova@users.noreply.github.com> Date: Thu, 16 Nov 2023 20:53:47 +0100 Subject: [PATCH 22/22] Getter for process id (#26) Added getter for process_id --- tokio-postgres/Cargo.toml | 1 - tokio-postgres/src/client.rs | 5 +++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/tokio-postgres/Cargo.toml b/tokio-postgres/Cargo.toml index 2f029835a..407cb91a0 100644 --- a/tokio-postgres/Cargo.toml +++ b/tokio-postgres/Cargo.toml @@ -57,7 +57,6 @@ postgres-protocol = { version = "0.6.6", path = "../postgres-protocol" } postgres-types = { version = "0.2.5", path 
= "../postgres-types" } tokio = { version = "1.27", features = ["io-util"] } tokio-util = { version = "0.7", features = ["codec"] } -socket2 = { version = "0.5", features = ["all"] } rand = "0.8.5" whoami = "1.4.1" diff --git a/tokio-postgres/src/client.rs b/tokio-postgres/src/client.rs index 5242e98c2..eb0d44824 100644 --- a/tokio-postgres/src/client.rs +++ b/tokio-postgres/src/client.rs @@ -207,6 +207,11 @@ impl Client { } } + /// Returns process_id. + pub fn get_process_id(&self) -> i32 { + self.process_id + } + pub(crate) fn inner(&self) -> &Arc { &self.inner }