diff --git a/docker/sql_setup.sh b/docker/sql_setup.sh index 422dcbda9..0315ac805 100755 --- a/docker/sql_setup.sh +++ b/docker/sql_setup.sh @@ -96,4 +96,5 @@ psql -v ON_ERROR_STOP=1 --username "$POSTGRES_USER" <<-EOSQL CREATE ROLE ssl_user LOGIN; CREATE EXTENSION hstore; CREATE EXTENSION citext; + CREATE EXTENSION ltree; EOSQL diff --git a/postgres-protocol/Cargo.toml b/postgres-protocol/Cargo.toml index 2010e88ad..a4716907b 100644 --- a/postgres-protocol/Cargo.toml +++ b/postgres-protocol/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "postgres-protocol" -version = "0.6.3" +version = "0.6.4" authors = ["Steven Fackler "] edition = "2018" description = "Low level Postgres protocol APIs" diff --git a/postgres-protocol/src/types/mod.rs b/postgres-protocol/src/types/mod.rs index a595f5a30..05f515f76 100644 --- a/postgres-protocol/src/types/mod.rs +++ b/postgres-protocol/src/types/mod.rs @@ -1059,3 +1059,60 @@ impl Inet { self.netmask } } + +/// Serializes a Postgres ltree string +#[inline] +pub fn ltree_to_sql(v: &str, buf: &mut BytesMut) { + // A version number is prepended to an ltree string per spec + buf.put_u8(1); + // Append the rest of the query + buf.put_slice(v.as_bytes()); +} + +/// Deserialize a Postgres ltree string +#[inline] +pub fn ltree_from_sql(buf: &[u8]) -> Result<&str, StdBox> { + match buf { + // Remove the version number from the front of the ltree per spec + [1u8, rest @ ..] => Ok(str::from_utf8(rest)?), + _ => Err("ltree version 1 only supported".into()), + } +} + +/// Serializes a Postgres lquery string +#[inline] +pub fn lquery_to_sql(v: &str, buf: &mut BytesMut) { + // A version number is prepended to an lquery string per spec + buf.put_u8(1); + // Append the rest of the query + buf.put_slice(v.as_bytes()); +} + +/// Deserialize a Postgres lquery string +#[inline] +pub fn lquery_from_sql(buf: &[u8]) -> Result<&str, StdBox> { + match buf { + // Remove the version number from the front of the lquery per spec + [1u8, rest @ ..] => Ok(str::from_utf8(rest)?), + _ => Err("lquery version 1 only supported".into()), + } +} + +/// Serializes a Postgres ltxtquery string +#[inline] +pub fn ltxtquery_to_sql(v: &str, buf: &mut BytesMut) { + // A version number is prepended to an ltxtquery string per spec + buf.put_u8(1); + // Append the rest of the query + buf.put_slice(v.as_bytes()); +} + +/// Deserialize a Postgres ltxtquery string +#[inline] +pub fn ltxtquery_from_sql(buf: &[u8]) -> Result<&str, StdBox> { + match buf { + // Remove the version number from the front of the ltxtquery per spec + [1u8, rest @ ..] => Ok(str::from_utf8(rest)?), + _ => Err("ltxtquery version 1 only supported".into()), + } +} diff --git a/postgres-protocol/src/types/test.rs b/postgres-protocol/src/types/test.rs index 7c20cf3ed..6f1851fc2 100644 --- a/postgres-protocol/src/types/test.rs +++ b/postgres-protocol/src/types/test.rs @@ -1,4 +1,4 @@ -use bytes::BytesMut; +use bytes::{Buf, BytesMut}; use fallible_iterator::FallibleIterator; use std::collections::HashMap; @@ -156,3 +156,87 @@ fn non_null_array() { assert_eq!(array.dimensions().collect::>().unwrap(), dimensions); assert_eq!(array.values().collect::>().unwrap(), values); } + +#[test] +fn ltree_sql() { + let mut query = vec![1u8]; + query.extend_from_slice("A.B.C".as_bytes()); + + let mut buf = BytesMut::new(); + + ltree_to_sql("A.B.C", &mut buf); + + assert_eq!(query.as_slice(), buf.chunk()); +} + +#[test] +fn ltree_str() { + let mut query = vec![1u8]; + query.extend_from_slice("A.B.C".as_bytes()); + + assert!(matches!(ltree_from_sql(query.as_slice()), Ok(_))) +} + +#[test] +fn ltree_wrong_version() { + let mut query = vec![2u8]; + query.extend_from_slice("A.B.C".as_bytes()); + + assert!(matches!(ltree_from_sql(query.as_slice()), Err(_))) +} + +#[test] +fn lquery_sql() { + let mut query = vec![1u8]; + query.extend_from_slice("A.B.C".as_bytes()); + + let mut buf = BytesMut::new(); + + lquery_to_sql("A.B.C", &mut buf); + + assert_eq!(query.as_slice(), buf.chunk()); +} + +#[test] +fn lquery_str() { + let mut query = vec![1u8]; + query.extend_from_slice("A.B.C".as_bytes()); + + assert!(matches!(lquery_from_sql(query.as_slice()), Ok(_))) +} + +#[test] +fn lquery_wrong_version() { + let mut query = vec![2u8]; + query.extend_from_slice("A.B.C".as_bytes()); + + assert!(matches!(lquery_from_sql(query.as_slice()), Err(_))) +} + +#[test] +fn ltxtquery_sql() { + let mut query = vec![1u8]; + query.extend_from_slice("a & b*".as_bytes()); + + let mut buf = BytesMut::new(); + + ltree_to_sql("a & b*", &mut buf); + + assert_eq!(query.as_slice(), buf.chunk()); +} + +#[test] +fn ltxtquery_str() { + let mut query = vec![1u8]; + query.extend_from_slice("a & b*".as_bytes()); + + assert!(matches!(ltree_from_sql(query.as_slice()), Ok(_))) +} + +#[test] +fn ltxtquery_wrong_version() { + let mut query = vec![2u8]; + query.extend_from_slice("a & b*".as_bytes()); + + assert!(matches!(ltree_from_sql(query.as_slice()), Err(_))) +} diff --git a/postgres-types/Cargo.toml b/postgres-types/Cargo.toml index 9d470f37b..000d71ea0 100644 --- a/postgres-types/Cargo.toml +++ b/postgres-types/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "postgres-types" -version = "0.2.2" +version = "0.2.3" authors = ["Steven Fackler "] edition = "2018" license = "MIT/Apache-2.0" @@ -28,7 +28,7 @@ with-time-0_3 = ["time-03"] [dependencies] bytes = "1.0" fallible-iterator = "0.2" -postgres-protocol = { version = "0.6.1", path = "../postgres-protocol" } +postgres-protocol = { version = "0.6.4", path = "../postgres-protocol" } postgres-derive = { version = "0.4.0", optional = true, path = "../postgres-derive" } array-init = { version = "2", optional = true } diff --git a/postgres-types/src/lib.rs b/postgres-types/src/lib.rs index 394f938ff..d029d3948 100644 --- a/postgres-types/src/lib.rs +++ b/postgres-types/src/lib.rs @@ -407,6 +407,7 @@ impl WrongType { /// | `f32` | REAL | /// | `f64` | DOUBLE PRECISION | /// | `&str`/`String` | VARCHAR, CHAR(n), TEXT, CITEXT, NAME, UNKNOWN | +/// | | LTREE, LQUERY, LTXTQUERY | /// | `&[u8]`/`Vec` | BYTEA | /// | `HashMap>` | HSTORE | /// | `SystemTime` | TIMESTAMP, TIMESTAMP WITH TIME ZONE | @@ -594,8 +595,8 @@ impl<'a> FromSql<'a> for &'a [u8] { } impl<'a> FromSql<'a> for String { - fn from_sql(_: &Type, raw: &'a [u8]) -> Result> { - types::text_from_sql(raw).map(ToString::to_string) + fn from_sql(ty: &Type, raw: &'a [u8]) -> Result> { + <&str as FromSql>::from_sql(ty, raw).map(ToString::to_string) } fn accepts(ty: &Type) -> bool { @@ -604,8 +605,8 @@ impl<'a> FromSql<'a> for String { } impl<'a> FromSql<'a> for Box { - fn from_sql(_: &Type, raw: &'a [u8]) -> Result, Box> { - types::text_from_sql(raw) + fn from_sql(ty: &Type, raw: &'a [u8]) -> Result, Box> { + <&str as FromSql>::from_sql(ty, raw) .map(ToString::to_string) .map(String::into_boxed_str) } @@ -616,14 +617,26 @@ impl<'a> FromSql<'a> for Box { } impl<'a> FromSql<'a> for &'a str { - fn from_sql(_: &Type, raw: &'a [u8]) -> Result<&'a str, Box> { - types::text_from_sql(raw) + fn from_sql(ty: &Type, raw: &'a [u8]) -> Result<&'a str, Box> { + match *ty { + ref ty if ty.name() == "ltree" => types::ltree_from_sql(raw), + ref ty if ty.name() == "lquery" => types::lquery_from_sql(raw), + ref ty if ty.name() == "ltxtquery" => types::ltxtquery_from_sql(raw), + _ => types::text_from_sql(raw), + } } fn accepts(ty: &Type) -> bool { match *ty { Type::VARCHAR | Type::TEXT | Type::BPCHAR | Type::NAME | Type::UNKNOWN => true, - ref ty if ty.name() == "citext" => true, + ref ty + if (ty.name() == "citext" + || ty.name() == "ltree" + || ty.name() == "lquery" + || ty.name() == "ltxtquery") => + { + true + } _ => false, } } @@ -727,6 +740,7 @@ pub enum IsNull { /// | `f32` | REAL | /// | `f64` | DOUBLE PRECISION | /// | `&str`/`String` | VARCHAR, CHAR(n), TEXT, CITEXT, NAME | +/// | | LTREE, LQUERY, LTXTQUERY | /// | `&[u8]`/`Vec` | BYTEA | /// | `HashMap>` | HSTORE | /// | `SystemTime` | TIMESTAMP, TIMESTAMP WITH TIME ZONE | @@ -924,15 +938,27 @@ impl ToSql for Vec { } impl<'a> ToSql for &'a str { - fn to_sql(&self, _: &Type, w: &mut BytesMut) -> Result> { - types::text_to_sql(*self, w); + fn to_sql(&self, ty: &Type, w: &mut BytesMut) -> Result> { + match *ty { + ref ty if ty.name() == "ltree" => types::ltree_to_sql(*self, w), + ref ty if ty.name() == "lquery" => types::lquery_to_sql(*self, w), + ref ty if ty.name() == "ltxtquery" => types::ltxtquery_to_sql(*self, w), + _ => types::text_to_sql(*self, w), + } Ok(IsNull::No) } fn accepts(ty: &Type) -> bool { match *ty { Type::VARCHAR | Type::TEXT | Type::BPCHAR | Type::NAME | Type::UNKNOWN => true, - ref ty if ty.name() == "citext" => true, + ref ty + if (ty.name() == "citext" + || ty.name() == "ltree" + || ty.name() == "lquery" + || ty.name() == "ltxtquery") => + { + true + } _ => false, } } diff --git a/tokio-postgres/Cargo.toml b/tokio-postgres/Cargo.toml index 94371af51..82e71fb1c 100644 --- a/tokio-postgres/Cargo.toml +++ b/tokio-postgres/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tokio-postgres" -version = "0.7.5" +version = "0.7.6" authors = ["Steven Fackler "] edition = "2018" license = "MIT/Apache-2.0" @@ -50,8 +50,8 @@ parking_lot = "0.12" percent-encoding = "2.0" pin-project-lite = "0.2" phf = "0.10" -postgres-protocol = { version = "0.6.1", path = "../postgres-protocol" } -postgres-types = { version = "0.2.2", path = "../postgres-types" } +postgres-protocol = { version = "0.6.4", path = "../postgres-protocol" } +postgres-types = { version = "0.2.3", path = "../postgres-types" } socket2 = "0.4" tokio = { version = "1.0", features = ["io-util"] } tokio-util = { version = "0.7", features = ["codec"] } diff --git a/tokio-postgres/tests/test/types/mod.rs b/tokio-postgres/tests/test/types/mod.rs index 604e2de32..de700d791 100644 --- a/tokio-postgres/tests/test/types/mod.rs +++ b/tokio-postgres/tests/test/types/mod.rs @@ -648,3 +648,90 @@ async fn inet() { ) .await; } + +#[tokio::test] +async fn ltree() { + test_type( + "ltree", + &[(Some("b.c.d".to_owned()), "'b.c.d'"), (None, "NULL")], + ) + .await; +} + +#[tokio::test] +async fn ltree_any() { + test_type( + "ltree[]", + &[ + (Some(vec![]), "ARRAY[]"), + (Some(vec!["a.b.c".to_string()]), "ARRAY['a.b.c']"), + ( + Some(vec!["a.b.c".to_string(), "e.f.g".to_string()]), + "ARRAY['a.b.c','e.f.g']", + ), + (None, "NULL"), + ], + ) + .await; +} + +#[tokio::test] +async fn lquery() { + test_type( + "lquery", + &[ + (Some("b.c.d".to_owned()), "'b.c.d'"), + (Some("b.c.*".to_owned()), "'b.c.*'"), + (Some("b.*{1,2}.d|e".to_owned()), "'b.*{1,2}.d|e'"), + (None, "NULL"), + ], + ) + .await; +} + +#[tokio::test] +async fn lquery_any() { + test_type( + "lquery[]", + &[ + (Some(vec![]), "ARRAY[]"), + (Some(vec!["b.c.*".to_string()]), "ARRAY['b.c.*']"), + ( + Some(vec!["b.c.*".to_string(), "b.*{1,2}.d|e".to_string()]), + "ARRAY['b.c.*','b.*{1,2}.d|e']", + ), + (None, "NULL"), + ], + ) + .await; +} + +#[tokio::test] +async fn ltxtquery() { + test_type( + "ltxtquery", + &[ + (Some("b & c & d".to_owned()), "'b & c & d'"), + (Some("b@* & !c".to_owned()), "'b@* & !c'"), + (None, "NULL"), + ], + ) + .await; +} + +#[tokio::test] +async fn ltxtquery_any() { + test_type( + "ltxtquery[]", + &[ + (Some(vec![]), "ARRAY[]"), + (Some(vec!["b & c & d".to_string()]), "ARRAY['b & c & d']"), + ( + Some(vec!["b & c & d".to_string(), "b@* & !c".to_string()]), + "ARRAY['b & c & d','b@* & !c']", + ), + (None, "NULL"), + ], + ) + .await; +}