diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 94951693c..5f8eec8bc 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -201,6 +201,12 @@ pub enum Expr { expr: Box, data_type: DataType, }, + /// TRY_CAST an expression to a different data type e.g. `TRY_CAST(foo AS VARCHAR(123))` + // this differs from CAST in the choice of how to implement invalid conversions + TryCast { + expr: Box, + data_type: DataType, + }, /// EXTRACT(DateTimeField FROM ) Extract { field: DateTimeField, @@ -309,6 +315,7 @@ impl fmt::Display for Expr { } } Expr::Cast { expr, data_type } => write!(f, "CAST({} AS {})", expr, data_type), + Expr::TryCast { expr, data_type } => write!(f, "TRY_CAST({} AS {})", expr, data_type), Expr::Extract { field, expr } => write!(f, "EXTRACT({} FROM {})", field, expr), Expr::Collate { expr, collation } => write!(f, "{} COLLATE {}", expr, collation), Expr::Nested(ast) => write!(f, "({})", ast), diff --git a/src/dialect/keywords.rs b/src/dialect/keywords.rs index 1d2690fc0..3371ff570 100644 --- a/src/dialect/keywords.rs +++ b/src/dialect/keywords.rs @@ -456,6 +456,7 @@ define_keywords!( TRIM_ARRAY, TRUE, TRUNCATE, + TRY_CAST, UESCAPE, UNBOUNDED, UNCOMMITTED, diff --git a/src/parser.rs b/src/parser.rs index eab2ece12..bacae7873 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -352,6 +352,7 @@ impl<'a> Parser<'a> { } Keyword::CASE => self.parse_case_expr(), Keyword::CAST => self.parse_cast_expr(), + Keyword::TRY_CAST => self.parse_try_cast_expr(), Keyword::EXISTS => self.parse_exists_expr(), Keyword::EXTRACT => self.parse_extract_expr(), Keyword::SUBSTRING => self.parse_substring_expr(), @@ -591,6 +592,19 @@ impl<'a> Parser<'a> { }) } + /// Parse a SQL TRY_CAST function e.g. `TRY_CAST(expr AS FLOAT)` + pub fn parse_try_cast_expr(&mut self) -> Result { + self.expect_token(&Token::LParen)?; + let expr = self.parse_expr()?; + self.expect_keyword(Keyword::AS)?; + let data_type = self.parse_data_type()?; + self.expect_token(&Token::RParen)?; + Ok(Expr::TryCast { + expr: Box::new(expr), + data_type, + }) + } + /// Parse a SQL EXISTS expression e.g. `WHERE EXISTS(SELECT ...)`. pub fn parse_exists_expr(&mut self) -> Result { self.expect_token(&Token::LParen)?; @@ -1806,7 +1820,7 @@ impl<'a> Parser<'a> { let columns = self.parse_parenthesized_column_list(Optional)?; self.expect_keywords(&[Keyword::FROM, Keyword::STDIN])?; self.expect_token(&Token::SemiColon)?; - let values = self.parse_tsv()?; + let values = self.parse_tsv(); Ok(Statement::Copy { table_name, columns, @@ -1816,12 +1830,11 @@ impl<'a> Parser<'a> { /// Parse a tab separated values in /// COPY payload - fn parse_tsv(&mut self) -> Result>, ParserError> { - let values = self.parse_tab_value()?; - Ok(values) + fn parse_tsv(&mut self) -> Vec> { + self.parse_tab_value() } - fn parse_tab_value(&mut self) -> Result>, ParserError> { + fn parse_tab_value(&mut self) -> Vec> { let mut values = vec![]; let mut content = String::from(""); while let Some(t) = self.next_token_no_skip() { @@ -1836,7 +1849,7 @@ impl<'a> Parser<'a> { } Token::Backslash => { if self.consume_token(&Token::Period) { - return Ok(values); + return values; } if let Token::Word(w) = self.next_token() { if w.value == "N" { @@ -1849,7 +1862,7 @@ impl<'a> Parser<'a> { } } } - Ok(values) + values } /// Parse a literal value (numbers, strings, date/time, booleans) diff --git a/src/tokenizer.rs b/src/tokenizer.rs index d82810528..ce48fa018 100644 --- a/src/tokenizer.rs +++ b/src/tokenizer.rs @@ -626,6 +626,7 @@ impl<'a> Tokenizer<'a> { } } + #[allow(clippy::unnecessary_wraps)] fn consume_and_return( &self, chars: &mut Peekable>, diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index fbf2faf9b..10ac79d84 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -981,6 +981,35 @@ fn parse_cast() { ); } +#[test] +fn parse_try_cast() { + let sql = "SELECT TRY_CAST(id AS BIGINT) FROM customer"; + let select = verified_only_select(sql); + assert_eq!( + &Expr::TryCast { + expr: Box::new(Expr::Identifier(Ident::new("id"))), + data_type: DataType::BigInt + }, + expr_from_projection(only(&select.projection)) + ); + one_statement_parses_to( + "SELECT TRY_CAST(id AS BIGINT) FROM customer", + "SELECT TRY_CAST(id AS BIGINT) FROM customer", + ); + + verified_stmt("SELECT TRY_CAST(id AS NUMERIC) FROM customer"); + + one_statement_parses_to( + "SELECT TRY_CAST(id AS DEC) FROM customer", + "SELECT TRY_CAST(id AS NUMERIC) FROM customer", + ); + + one_statement_parses_to( + "SELECT TRY_CAST(id AS DECIMAL) FROM customer", + "SELECT TRY_CAST(id AS NUMERIC) FROM customer", + ); +} + #[test] fn parse_extract() { let sql = "SELECT EXTRACT(YEAR FROM d)"; @@ -1224,6 +1253,7 @@ fn parse_assert() { } #[test] +#[allow(clippy::collapsible_match)] fn parse_assert_message() { let sql = "ASSERT (SELECT COUNT(*) FROM my_table) > 0 AS 'No rows in my_table'"; let ast = one_statement_parses_to(