Skip to content

Commit e71898f

Browse files
committed
feat(derive): add #[postgres(allow_mismatch)]
1 parent 8b9b5d0 commit e71898f

File tree

8 files changed

+239
-7
lines changed

8 files changed

+239
-7
lines changed
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
use postgres_types::{FromSql, ToSql};
2+
3+
#[derive(ToSql, Debug)]
4+
#[postgres(allow_mismatch)]
5+
struct ToSqlAllowMismatchStruct {
6+
a: i32,
7+
}
8+
9+
#[derive(FromSql, Debug)]
10+
#[postgres(allow_mismatch)]
11+
struct FromSqlAllowMismatchStruct {
12+
a: i32,
13+
}
14+
15+
#[derive(ToSql, Debug)]
16+
#[postgres(allow_mismatch)]
17+
struct ToSqlAllowMismatchTupleStruct(i32, i32);
18+
19+
#[derive(FromSql, Debug)]
20+
#[postgres(allow_mismatch)]
21+
struct FromSqlAllowMismatchTupleStruct(i32, i32);
22+
23+
#[derive(FromSql, Debug)]
24+
#[postgres(transparent, allow_mismatch)]
25+
struct TransparentFromSqlAllowMismatchStruct(i32);
26+
27+
#[derive(FromSql, Debug)]
28+
#[postgres(allow_mismatch, transparent)]
29+
struct AllowMismatchFromSqlTransparentStruct(i32);
30+
31+
fn main() {}
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
error: #[postgres(allow_mismatch)] may only be applied to enums
2+
--> src/compile-fail/invalid-allow-mismatch.rs:4:1
3+
|
4+
4 | / #[postgres(allow_mismatch)]
5+
5 | | struct ToSqlAllowMismatchStruct {
6+
6 | | a: i32,
7+
7 | | }
8+
| |_^
9+
10+
error: #[postgres(allow_mismatch)] may only be applied to enums
11+
--> src/compile-fail/invalid-allow-mismatch.rs:10:1
12+
|
13+
10 | / #[postgres(allow_mismatch)]
14+
11 | | struct FromSqlAllowMismatchStruct {
15+
12 | | a: i32,
16+
13 | | }
17+
| |_^
18+
19+
error: #[postgres(allow_mismatch)] may only be applied to enums
20+
--> src/compile-fail/invalid-allow-mismatch.rs:16:1
21+
|
22+
16 | / #[postgres(allow_mismatch)]
23+
17 | | struct ToSqlAllowMismatchTupleStruct(i32, i32);
24+
| |_______________________________________________^
25+
26+
error: #[postgres(allow_mismatch)] may only be applied to enums
27+
--> src/compile-fail/invalid-allow-mismatch.rs:20:1
28+
|
29+
20 | / #[postgres(allow_mismatch)]
30+
21 | | struct FromSqlAllowMismatchTupleStruct(i32, i32);
31+
| |_________________________________________________^
32+
33+
error: #[postgres(transparent)] is not allowed with #[postgres(allow_mismatch)]
34+
--> src/compile-fail/invalid-allow-mismatch.rs:24:25
35+
|
36+
24 | #[postgres(transparent, allow_mismatch)]
37+
| ^^^^^^^^^^^^^^
38+
39+
error: #[postgres(allow_mismatch)] is not allowed with #[postgres(transparent)]
40+
--> src/compile-fail/invalid-allow-mismatch.rs:28:28
41+
|
42+
28 | #[postgres(allow_mismatch, transparent)]
43+
| ^^^^^^^^^^^

postgres-derive-test/src/enums.rs

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use crate::test_type;
2-
use postgres::{Client, NoTls};
2+
use postgres::{error::DbError, Client, NoTls};
33
use postgres_types::{FromSql, ToSql, WrongType};
44
use std::error::Error;
55

@@ -102,3 +102,73 @@ fn missing_variant() {
102102
let err = conn.execute("SELECT $1::foo", &[&Foo::Bar]).unwrap_err();
103103
assert!(err.source().unwrap().is::<WrongType>());
104104
}
105+
106+
#[test]
107+
fn allow_mismatch_enums() {
108+
#[derive(Debug, ToSql, FromSql, PartialEq)]
109+
#[postgres(allow_mismatch)]
110+
enum Foo {
111+
Bar,
112+
}
113+
114+
let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
115+
conn.execute("CREATE TYPE pg_temp.\"Foo\" AS ENUM ('Bar', 'Baz')", &[])
116+
.unwrap();
117+
118+
let row = conn.query_one("SELECT $1::\"Foo\"", &[&Foo::Bar]).unwrap();
119+
assert_eq!(row.get::<_, Foo>(0), Foo::Bar);
120+
}
121+
122+
#[test]
123+
fn missing_enum_variant() {
124+
#[derive(Debug, ToSql, FromSql, PartialEq)]
125+
#[postgres(allow_mismatch)]
126+
enum Foo {
127+
Bar,
128+
Buz,
129+
}
130+
131+
let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
132+
conn.execute("CREATE TYPE pg_temp.\"Foo\" AS ENUM ('Bar', 'Baz')", &[])
133+
.unwrap();
134+
135+
let err = conn
136+
.query_one("SELECT $1::\"Foo\"", &[&Foo::Buz])
137+
.unwrap_err();
138+
assert!(err.source().unwrap().is::<DbError>());
139+
}
140+
141+
#[test]
142+
fn allow_mismatch_and_renaming() {
143+
#[derive(Debug, ToSql, FromSql, PartialEq)]
144+
#[postgres(name = "foo", allow_mismatch)]
145+
enum Foo {
146+
#[postgres(name = "bar")]
147+
Bar,
148+
#[postgres(name = "buz")]
149+
Buz,
150+
}
151+
152+
let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
153+
conn.execute("CREATE TYPE pg_temp.foo AS ENUM ('bar', 'baz', 'buz')", &[])
154+
.unwrap();
155+
156+
let row = conn.query_one("SELECT $1::foo", &[&Foo::Buz]).unwrap();
157+
assert_eq!(row.get::<_, Foo>(0), Foo::Buz);
158+
}
159+
160+
#[test]
161+
fn wrong_name_and_allow_mismatch() {
162+
#[derive(Debug, ToSql, FromSql, PartialEq)]
163+
#[postgres(allow_mismatch)]
164+
enum Foo {
165+
Bar,
166+
}
167+
168+
let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
169+
conn.execute("CREATE TYPE pg_temp.foo AS ENUM ('Bar', 'Baz')", &[])
170+
.unwrap();
171+
172+
let err = conn.query_one("SELECT $1::foo", &[&Foo::Bar]).unwrap_err();
173+
assert!(err.source().unwrap().is::<WrongType>());
174+
}

postgres-derive/src/accepts.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ pub fn domain_body(name: &str, field: &syn::Field) -> TokenStream {
3131
}
3232
}
3333

34-
pub fn enum_body(name: &str, variants: &[Variant]) -> TokenStream {
34+
pub fn enum_body(name: &str, variants: &[Variant], allow_mismatch: bool) -> TokenStream {
3535
let num_variants = variants.len();
3636
let variant_names = variants.iter().map(|v| &v.name);
3737

@@ -40,6 +40,10 @@ pub fn enum_body(name: &str, variants: &[Variant]) -> TokenStream {
4040
return false;
4141
}
4242

43+
if #allow_mismatch {
44+
return true;
45+
}
46+
4347
match *type_.kind() {
4448
::postgres_types::Kind::Enum(ref variants) => {
4549
if variants.len() != #num_variants {

postgres-derive/src/fromsql.rs

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,26 @@ pub fn expand_derive_fromsql(input: DeriveInput) -> Result<TokenStream, Error> {
4545
))
4646
}
4747
}
48+
} else if overrides.allow_mismatch {
49+
match input.data {
50+
Data::Enum(ref data) => {
51+
let variants = data
52+
.variants
53+
.iter()
54+
.map(Variant::parse)
55+
.collect::<Result<Vec<_>, _>>()?;
56+
(
57+
accepts::enum_body(&name, &variants, overrides.allow_mismatch),
58+
enum_body(&input.ident, &variants),
59+
)
60+
}
61+
_ => {
62+
return Err(Error::new_spanned(
63+
input,
64+
"#[postgres(allow_mismatch)] may only be applied to enums",
65+
));
66+
}
67+
}
4868
} else {
4969
match input.data {
5070
Data::Enum(ref data) => {
@@ -54,7 +74,7 @@ pub fn expand_derive_fromsql(input: DeriveInput) -> Result<TokenStream, Error> {
5474
.map(Variant::parse)
5575
.collect::<Result<Vec<_>, _>>()?;
5676
(
57-
accepts::enum_body(&name, &variants),
77+
accepts::enum_body(&name, &variants, overrides.allow_mismatch),
5878
enum_body(&input.ident, &variants),
5979
)
6080
}

postgres-derive/src/overrides.rs

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,15 @@ use syn::{Attribute, Error, Expr, ExprLit, Lit, Meta, Token};
44
pub struct Overrides {
55
pub name: Option<String>,
66
pub transparent: bool,
7+
pub allow_mismatch: bool,
78
}
89

910
impl Overrides {
1011
pub fn extract(attrs: &[Attribute]) -> Result<Overrides, Error> {
1112
let mut overrides = Overrides {
1213
name: None,
1314
transparent: false,
15+
allow_mismatch: false,
1416
};
1517

1618
for attr in attrs {
@@ -44,11 +46,25 @@ impl Overrides {
4446
overrides.name = Some(value);
4547
}
4648
Meta::Path(path) => {
47-
if !path.is_ident("transparent") {
49+
if path.is_ident("transparent") {
50+
if overrides.allow_mismatch {
51+
return Err(Error::new_spanned(
52+
path,
53+
"#[postgres(allow_mismatch)] is not allowed with #[postgres(transparent)]",
54+
));
55+
}
56+
overrides.transparent = true;
57+
} else if path.is_ident("allow_mismatch") {
58+
if overrides.transparent {
59+
return Err(Error::new_spanned(
60+
path,
61+
"#[postgres(transparent)] is not allowed with #[postgres(allow_mismatch)]",
62+
));
63+
}
64+
overrides.allow_mismatch = true;
65+
} else {
4866
return Err(Error::new_spanned(path, "unknown override"));
4967
}
50-
51-
overrides.transparent = true;
5268
}
5369
bad => return Err(Error::new_spanned(bad, "unknown attribute")),
5470
}

postgres-derive/src/tosql.rs

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,26 @@ pub fn expand_derive_tosql(input: DeriveInput) -> Result<TokenStream, Error> {
4141
));
4242
}
4343
}
44+
} else if overrides.allow_mismatch {
45+
match input.data {
46+
Data::Enum(ref data) => {
47+
let variants = data
48+
.variants
49+
.iter()
50+
.map(Variant::parse)
51+
.collect::<Result<Vec<_>, _>>()?;
52+
(
53+
accepts::enum_body(&name, &variants, overrides.allow_mismatch),
54+
enum_body(&input.ident, &variants),
55+
)
56+
}
57+
_ => {
58+
return Err(Error::new_spanned(
59+
input,
60+
"#[postgres(allow_mismatch)] may only be applied to enums",
61+
));
62+
}
63+
}
4464
} else {
4565
match input.data {
4666
Data::Enum(ref data) => {
@@ -50,7 +70,7 @@ pub fn expand_derive_tosql(input: DeriveInput) -> Result<TokenStream, Error> {
5070
.map(Variant::parse)
5171
.collect::<Result<Vec<_>, _>>()?;
5272
(
53-
accepts::enum_body(&name, &variants),
73+
accepts::enum_body(&name, &variants, overrides.allow_mismatch),
5474
enum_body(&input.ident, &variants),
5575
)
5676
}

postgres-types/src/lib.rs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,34 @@
125125
//! Happy,
126126
//! }
127127
//! ```
128+
//!
129+
//! ## Allowing Enum Mismatches
130+
//!
131+
//! By default the generated implementation of [`ToSql`] & [`FromSql`] for enums will require an exact match of the enum
132+
//! variants between the Rust and Postgres types.
133+
//! To allow mismatches, the `#[postgres(allow_mismatch)]` attribute can be used on the enum definition:
134+
//!
135+
//! ```sql
136+
//! CREATE TYPE mood AS ENUM (
137+
//! 'Sad',
138+
//! 'Ok',
139+
//! 'Happy'
140+
//! );
141+
//! ```
142+
//!
143+
//! ```rust
144+
//! # #[cfg(feature = "derive")]
145+
//! use postgres_types::{ToSql, FromSql};
146+
//!
147+
//! # #[cfg(feature = "derive")]
148+
//! #[derive(Debug, ToSql, FromSql)]
149+
//! #[postgres(allow_mismatch)]
150+
//! enum Mood {
151+
//! Sad,
152+
//! Happy,
153+
//! Meh,
154+
//! }
155+
//! ```
128156
#![doc(html_root_url = "https://docs.rs/postgres-types/0.2")]
129157
#![warn(clippy::all, rust_2018_idioms, missing_docs)]
130158

0 commit comments

Comments
 (0)