Skip to content

Commit 9616c36

Browse files
committed
Parse signed/unsigned integer data types correctly in MySQL CAST
MySQL doesn't have the same set of possible CAST types as for e.g. column definitions. For example, it raises a syntax error for `CAST(1 AS INTEGER SIGNED)` and instead expects `CAST(1 AS SIGNED INTEGER)`. This patch takes a somewhat unfortunate route of 1) adding a boolean flag that modifies the `parse_data_type` match expression, and 2) storing whether it was parsed this way in the `Expr::Cast` AST node to be used during formatting. Feedback and ideas for alternative approaches are very welcome. Closes #1589
1 parent 3ace97c commit 9616c36

File tree

8 files changed

+150
-64
lines changed

8 files changed

+150
-64
lines changed

src/ast/mod.rs

Lines changed: 40 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -798,9 +798,15 @@ pub enum Expr {
798798
kind: CastKind,
799799
expr: Box<Expr>,
800800
data_type: DataType,
801-
// Optional CAST(string_expression AS type FORMAT format_string_expression) as used by BigQuery
802-
// https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#formatting_syntax
801+
/// Optional CAST(string_expression AS type FORMAT format_string_expression) as used by [BigQuery]
802+
///
803+
/// [BigQuery]: https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#formatting_syntax
803804
format: Option<CastFormat>,
805+
/// Whether this was parsed as a [MySQL]-style cast, which has a different syntax for
806+
/// `SIGNED/UNSIGNED INTEGER` casts.
807+
///
808+
/// [MySQL]: https://dev.mysql.com/doc/refman/8.4/en/cast-functions.html#function_cast
809+
mysql_style_int: bool,
804810
},
805811
/// AT a timestamp to a different timezone e.g. `FROM_UNIXTIME(0) AT TIME ZONE 'UTC-06:00'`
806812
AtTimeZone {
@@ -1560,32 +1566,42 @@ impl fmt::Display for Expr {
15601566
expr,
15611567
data_type,
15621568
format,
1563-
} => match kind {
1564-
CastKind::Cast => {
1565-
if let Some(format) = format {
1566-
write!(f, "CAST({expr} AS {data_type} FORMAT {format})")
1567-
} else {
1568-
write!(f, "CAST({expr} AS {data_type})")
1569+
mysql_style_int,
1570+
} => {
1571+
let data_type = if *mysql_style_int && data_type == &DataType::BigInt(None) {
1572+
"SIGNED".to_string()
1573+
} else if *mysql_style_int && data_type == &DataType::UnsignedBigInt(None) {
1574+
"UNSIGNED".to_string()
1575+
} else {
1576+
data_type.to_string()
1577+
};
1578+
match kind {
1579+
CastKind::Cast => {
1580+
if let Some(format) = format {
1581+
write!(f, "CAST({expr} AS {data_type} FORMAT {format})")
1582+
} else {
1583+
write!(f, "CAST({expr} AS {data_type})")
1584+
}
15691585
}
1570-
}
1571-
CastKind::TryCast => {
1572-
if let Some(format) = format {
1573-
write!(f, "TRY_CAST({expr} AS {data_type} FORMAT {format})")
1574-
} else {
1575-
write!(f, "TRY_CAST({expr} AS {data_type})")
1586+
CastKind::TryCast => {
1587+
if let Some(format) = format {
1588+
write!(f, "TRY_CAST({expr} AS {data_type} FORMAT {format})")
1589+
} else {
1590+
write!(f, "TRY_CAST({expr} AS {data_type})")
1591+
}
15761592
}
1577-
}
1578-
CastKind::SafeCast => {
1579-
if let Some(format) = format {
1580-
write!(f, "SAFE_CAST({expr} AS {data_type} FORMAT {format})")
1581-
} else {
1582-
write!(f, "SAFE_CAST({expr} AS {data_type})")
1593+
CastKind::SafeCast => {
1594+
if let Some(format) = format {
1595+
write!(f, "SAFE_CAST({expr} AS {data_type} FORMAT {format})")
1596+
} else {
1597+
write!(f, "SAFE_CAST({expr} AS {data_type})")
1598+
}
1599+
}
1600+
CastKind::DoubleColon => {
1601+
write!(f, "{expr}::{data_type}")
15831602
}
15841603
}
1585-
CastKind::DoubleColon => {
1586-
write!(f, "{expr}::{data_type}")
1587-
}
1588-
},
1604+
}
15891605
Expr::Extract {
15901606
field,
15911607
syntax,

src/ast/spans.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1407,6 +1407,7 @@ impl Spanned for Expr {
14071407
expr,
14081408
data_type: _,
14091409
format: _,
1410+
mysql_style_int: _,
14101411
} => expr.span(),
14111412
Expr::AtTimeZone {
14121413
timestamp,

src/keywords.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -790,6 +790,7 @@ define_keywords!(
790790
SHARE,
791791
SHARING,
792792
SHOW,
793+
SIGNED,
793794
SIMILAR,
794795
SKIP,
795796
SLOW,

src/parser/mod.rs

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2175,14 +2175,19 @@ impl<'a> Parser<'a> {
21752175
self.expect_token(&Token::LParen)?;
21762176
let expr = self.parse_expr()?;
21772177
self.expect_keyword_is(Keyword::AS)?;
2178-
let data_type = self.parse_data_type()?;
2178+
let dialect = self.dialect;
2179+
let mysql_style_int = dialect_is!(dialect is MySqlDialect);
2180+
let data_type = self.parse_data_type_for_cast(mysql_style_int)?;
2181+
let mysql_style_int = mysql_style_int
2182+
&& (data_type == DataType::BigInt(None) || data_type == DataType::UnsignedBigInt(None));
21792183
let format = self.parse_optional_cast_format()?;
21802184
self.expect_token(&Token::RParen)?;
21812185
Ok(Expr::Cast {
21822186
kind,
21832187
expr: Box::new(expr),
21842188
data_type,
21852189
format,
2190+
mysql_style_int,
21862191
})
21872192
}
21882193

@@ -2884,7 +2889,7 @@ impl<'a> Parser<'a> {
28842889
Some(self.parse_identifier()?)
28852890
};
28862891

2887-
let (field_type, trailing_bracket) = self.parse_data_type_helper()?;
2892+
let (field_type, trailing_bracket) = self.parse_data_type_helper(false)?;
28882893

28892894
Ok((
28902895
StructField {
@@ -3401,6 +3406,7 @@ impl<'a> Parser<'a> {
34013406
expr: Box::new(expr),
34023407
data_type: self.parse_data_type()?,
34033408
format: None,
3409+
mysql_style_int: false,
34043410
})
34053411
} else if Token::ExclamationMark == *tok && self.dialect.supports_factorial_operator() {
34063412
Ok(Expr::UnaryOp {
@@ -3641,6 +3647,7 @@ impl<'a> Parser<'a> {
36413647
expr: Box::new(expr),
36423648
data_type: self.parse_data_type()?,
36433649
format: None,
3650+
mysql_style_int: false,
36443651
})
36453652
}
36463653

@@ -8793,7 +8800,16 @@ impl<'a> Parser<'a> {
87938800

87948801
/// Parse a SQL datatype (in the context of a CREATE TABLE statement for example)
87958802
pub fn parse_data_type(&mut self) -> Result<DataType, ParserError> {
8796-
let (ty, trailing_bracket) = self.parse_data_type_helper()?;
8803+
self.parse_data_type_for_cast(false)
8804+
}
8805+
8806+
/// Parse a SQL datatype, possibly in the context of a CAST expression, for which some dialects
8807+
/// (i.e. MySQL) have different syntax. See [`Expr::Cast::mysql_style_int`]
8808+
pub fn parse_data_type_for_cast(
8809+
&mut self,
8810+
mysql_style_int: bool,
8811+
) -> Result<DataType, ParserError> {
8812+
let (ty, trailing_bracket) = self.parse_data_type_helper(mysql_style_int)?;
87978813
if trailing_bracket.0 {
87988814
return parser_err!(
87998815
format!("unmatched > after parsing data type {ty}"),
@@ -8806,6 +8822,7 @@ impl<'a> Parser<'a> {
88068822

88078823
fn parse_data_type_helper(
88088824
&mut self,
8825+
mysql_style_int: bool,
88098826
) -> Result<(DataType, MatchedTrailingBracket), ParserError> {
88108827
let dialect = self.dialect;
88118828
self.advance_token();
@@ -8832,7 +8849,7 @@ impl<'a> Parser<'a> {
88328849
))
88338850
}
88348851
}
8835-
Keyword::TINYINT => {
8852+
Keyword::TINYINT if !mysql_style_int => {
88368853
let optional_precision = self.parse_optional_precision();
88378854
if self.parse_keyword(Keyword::UNSIGNED) {
88388855
Ok(DataType::UnsignedTinyInt(optional_precision?))
@@ -8848,23 +8865,23 @@ impl<'a> Parser<'a> {
88488865
Ok(DataType::Int2(optional_precision?))
88498866
}
88508867
}
8851-
Keyword::SMALLINT => {
8868+
Keyword::SMALLINT if !mysql_style_int => {
88528869
let optional_precision = self.parse_optional_precision();
88538870
if self.parse_keyword(Keyword::UNSIGNED) {
88548871
Ok(DataType::UnsignedSmallInt(optional_precision?))
88558872
} else {
88568873
Ok(DataType::SmallInt(optional_precision?))
88578874
}
88588875
}
8859-
Keyword::MEDIUMINT => {
8876+
Keyword::MEDIUMINT if !mysql_style_int => {
88608877
let optional_precision = self.parse_optional_precision();
88618878
if self.parse_keyword(Keyword::UNSIGNED) {
88628879
Ok(DataType::UnsignedMediumInt(optional_precision?))
88638880
} else {
88648881
Ok(DataType::MediumInt(optional_precision?))
88658882
}
88668883
}
8867-
Keyword::INT => {
8884+
Keyword::INT if !mysql_style_int => {
88688885
let optional_precision = self.parse_optional_precision();
88698886
if self.parse_keyword(Keyword::UNSIGNED) {
88708887
Ok(DataType::UnsignedInt(optional_precision?))
@@ -8893,15 +8910,15 @@ impl<'a> Parser<'a> {
88938910
Keyword::INT64 => Ok(DataType::Int64),
88948911
Keyword::INT128 => Ok(DataType::Int128),
88958912
Keyword::INT256 => Ok(DataType::Int256),
8896-
Keyword::INTEGER => {
8913+
Keyword::INTEGER if !mysql_style_int => {
88978914
let optional_precision = self.parse_optional_precision();
88988915
if self.parse_keyword(Keyword::UNSIGNED) {
88998916
Ok(DataType::UnsignedInteger(optional_precision?))
89008917
} else {
89018918
Ok(DataType::Integer(optional_precision?))
89028919
}
89038920
}
8904-
Keyword::BIGINT => {
8921+
Keyword::BIGINT if !mysql_style_int => {
89058922
let optional_precision = self.parse_optional_precision();
89068923
if self.parse_keyword(Keyword::UNSIGNED) {
89078924
Ok(DataType::UnsignedBigInt(optional_precision?))
@@ -9049,7 +9066,8 @@ impl<'a> Parser<'a> {
90499066
})?)
90509067
} else {
90519068
self.expect_token(&Token::Lt)?;
9052-
let (inside_type, _trailing_bracket) = self.parse_data_type_helper()?;
9069+
let (inside_type, _trailing_bracket) =
9070+
self.parse_data_type_helper(mysql_style_int)?;
90539071
trailing_bracket = self.expect_closing_angle_bracket(_trailing_bracket)?;
90549072
Ok(DataType::Array(ArrayElemTypeDef::AngleBracket(Box::new(
90559073
inside_type,
@@ -9110,6 +9128,14 @@ impl<'a> Parser<'a> {
91109128
let columns = self.parse_returns_table_columns()?;
91119129
Ok(DataType::Table(columns))
91129130
}
9131+
Keyword::UNSIGNED if mysql_style_int => {
9132+
let _integer = self.parse_keyword(Keyword::INTEGER);
9133+
Ok(DataType::UnsignedBigInt(None))
9134+
}
9135+
Keyword::SIGNED if mysql_style_int => {
9136+
let _integer = self.parse_keyword(Keyword::INTEGER);
9137+
Ok(DataType::BigInt(None))
9138+
}
91139139
_ => {
91149140
self.prev_token();
91159141
let type_name = self.parse_object_name(false)?;

0 commit comments

Comments
 (0)