Skip to content

Implement TRY_CAST #299

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 21, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,12 @@ pub enum Expr {
expr: Box<Expr>,
data_type: DataType,
},
/// TRY_CAST an expression to a different data type e.g. `TRY_CAST(foo AS VARCHAR(123))`
// This differs from CAST in how invalid conversions are handled (e.g. SQL Server's TRY_CAST returns NULL instead of raising an error)
TryCast {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am wondering, why the choice TryCast / try_cast?

I know BigQuery uses SAFE_CAST for this: https://cloud.google.com/bigquery/docs/reference/standard-sql/functions-and-operators#safe_casting

It would make sense to support both maybe?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I could easily change this to add SAFE_CAST too. There was some rationale discussed here: apache/arrow#9682 (comment)

Basically, it was just that I think SQL Server got there first with TRY_CAST.

expr: Box<Expr>,
data_type: DataType,
},
/// EXTRACT(DateTimeField FROM <expr>)
Extract {
field: DateTimeField,
Expand Down Expand Up @@ -309,6 +315,7 @@ impl fmt::Display for Expr {
}
}
Expr::Cast { expr, data_type } => write!(f, "CAST({} AS {})", expr, data_type),
Expr::TryCast { expr, data_type } => write!(f, "TRY_CAST({} AS {})", expr, data_type),
Expr::Extract { field, expr } => write!(f, "EXTRACT({} FROM {})", field, expr),
Expr::Collate { expr, collation } => write!(f, "{} COLLATE {}", expr, collation),
Expr::Nested(ast) => write!(f, "({})", ast),
Expand Down
1 change: 1 addition & 0 deletions src/dialect/keywords.rs
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,7 @@ define_keywords!(
TRIM_ARRAY,
TRUE,
TRUNCATE,
TRY_CAST,
UESCAPE,
UNBOUNDED,
UNCOMMITTED,
Expand Down
27 changes: 20 additions & 7 deletions src/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,7 @@ impl<'a> Parser<'a> {
}
Keyword::CASE => self.parse_case_expr(),
Keyword::CAST => self.parse_cast_expr(),
Keyword::TRY_CAST => self.parse_try_cast_expr(),
Keyword::EXISTS => self.parse_exists_expr(),
Keyword::EXTRACT => self.parse_extract_expr(),
Keyword::SUBSTRING => self.parse_substring_expr(),
Expand Down Expand Up @@ -591,6 +592,19 @@ impl<'a> Parser<'a> {
})
}

/// Parse the parenthesized part of a SQL TRY_CAST call, e.g. the
/// `(expr AS FLOAT)` in `TRY_CAST(expr AS FLOAT)`.
///
/// Expects, in order: `(`, an expression, the `AS` keyword, a data type and
/// `)`. Returns an [`Expr::TryCast`] wrapping the parsed expression and target
/// type, or the first `ParserError` raised by any of those steps.
pub fn parse_try_cast_expr(&mut self) -> Result<Expr, ParserError> {
    self.expect_token(&Token::LParen)?;
    let operand = self.parse_expr()?;
    self.expect_keyword(Keyword::AS)?;
    let target_type = self.parse_data_type()?;
    self.expect_token(&Token::RParen)?;

    let try_cast = Expr::TryCast {
        expr: Box::new(operand),
        data_type: target_type,
    };
    Ok(try_cast)
}

/// Parse a SQL EXISTS expression e.g. `WHERE EXISTS(SELECT ...)`.
pub fn parse_exists_expr(&mut self) -> Result<Expr, ParserError> {
self.expect_token(&Token::LParen)?;
Expand Down Expand Up @@ -1806,7 +1820,7 @@ impl<'a> Parser<'a> {
let columns = self.parse_parenthesized_column_list(Optional)?;
self.expect_keywords(&[Keyword::FROM, Keyword::STDIN])?;
self.expect_token(&Token::SemiColon)?;
let values = self.parse_tsv()?;
let values = self.parse_tsv();
Ok(Statement::Copy {
table_name,
columns,
Expand All @@ -1816,12 +1830,11 @@ impl<'a> Parser<'a> {

/// Parse tab-separated values in a
/// COPY payload
fn parse_tsv(&mut self) -> Result<Vec<Option<String>>, ParserError> {
let values = self.parse_tab_value()?;
Ok(values)
fn parse_tsv(&mut self) -> Vec<Option<String>> {
self.parse_tab_value()
}

fn parse_tab_value(&mut self) -> Result<Vec<Option<String>>, ParserError> {
fn parse_tab_value(&mut self) -> Vec<Option<String>> {
let mut values = vec![];
let mut content = String::from("");
while let Some(t) = self.next_token_no_skip() {
Expand All @@ -1836,7 +1849,7 @@ impl<'a> Parser<'a> {
}
Token::Backslash => {
if self.consume_token(&Token::Period) {
return Ok(values);
return values;
}
if let Token::Word(w) = self.next_token() {
if w.value == "N" {
Expand All @@ -1849,7 +1862,7 @@ impl<'a> Parser<'a> {
}
}
}
Ok(values)
values
}

/// Parse a literal value (numbers, strings, date/time, booleans)
Expand Down
1 change: 1 addition & 0 deletions src/tokenizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -626,6 +626,7 @@ impl<'a> Tokenizer<'a> {
}
}

#[allow(clippy::unnecessary_wraps)]
fn consume_and_return(
&self,
chars: &mut Peekable<Chars<'_>>,
Expand Down
30 changes: 30 additions & 0 deletions tests/sqlparser_common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -981,6 +981,35 @@ fn parse_cast() {
);
}

#[test]
fn parse_try_cast() {
    // The canonical BIGINT form must round-trip and produce a `TryCast` node
    // as the only projected expression.
    let sql = "SELECT TRY_CAST(id AS BIGINT) FROM customer";
    let select = verified_only_select(sql);
    let expected = Expr::TryCast {
        expr: Box::new(Expr::Identifier(Ident::new("id"))),
        data_type: DataType::BigInt,
    };
    assert_eq!(&expected, expr_from_projection(only(&select.projection)));
    one_statement_parses_to(
        "SELECT TRY_CAST(id AS BIGINT) FROM customer",
        "SELECT TRY_CAST(id AS BIGINT) FROM customer",
    );

    // NUMERIC round-trips unchanged...
    let numeric_sql = "SELECT TRY_CAST(id AS NUMERIC) FROM customer";
    verified_stmt(numeric_sql);

    // ...while the DEC and DECIMAL spellings are expected to serialize as NUMERIC.
    let numeric_aliases = [
        "SELECT TRY_CAST(id AS DEC) FROM customer",
        "SELECT TRY_CAST(id AS DECIMAL) FROM customer",
    ];
    for aliased_sql in numeric_aliases.iter() {
        one_statement_parses_to(aliased_sql, numeric_sql);
    }
}

#[test]
fn parse_extract() {
let sql = "SELECT EXTRACT(YEAR FROM d)";
Expand Down Expand Up @@ -1224,6 +1253,7 @@ fn parse_assert() {
}

#[test]
#[allow(clippy::collapsible_match)]
fn parse_assert_message() {
let sql = "ASSERT (SELECT COUNT(*) FROM my_table) > 0 AS 'No rows in my_table'";
let ast = one_statement_parses_to(
Expand Down