diff --git a/src/dialect/mod.rs b/src/dialect/mod.rs index 159e14717..b622c1da3 100644 --- a/src/dialect/mod.rs +++ b/src/dialect/mod.rs @@ -681,6 +681,12 @@ pub trait Dialect: Debug + Any { fn supports_partiql(&self) -> bool { false } + + /// Returns true if the specified keyword is reserved and cannot be + /// used as an identifier without special handling like quoting. + fn is_reserved_for_identifier(&self, kw: Keyword) -> bool { + keywords::RESERVED_FOR_IDENTIFIER.contains(&kw) + } } /// This represents the operators for which precedence must be defined diff --git a/src/dialect/snowflake.rs b/src/dialect/snowflake.rs index b584ed9b4..56919fb31 100644 --- a/src/dialect/snowflake.rs +++ b/src/dialect/snowflake.rs @@ -38,6 +38,8 @@ use alloc::vec::Vec; #[cfg(not(feature = "std"))] use alloc::{format, vec}; +use super::keywords::RESERVED_FOR_IDENTIFIER; + /// A [`Dialect`] for [Snowflake](https://www.snowflake.com/) #[derive(Debug, Default)] pub struct SnowflakeDialect; @@ -214,6 +216,16 @@ impl Dialect for SnowflakeDialect { fn supports_show_like_before_in(&self) -> bool { true } + + fn is_reserved_for_identifier(&self, kw: Keyword) -> bool { + // Unreserve some keywords that Snowflake accepts as identifiers + // See: https://docs.snowflake.com/en/sql-reference/reserved-keywords + if matches!(kw, Keyword::INTERVAL) { + false + } else { + RESERVED_FOR_IDENTIFIER.contains(&kw) + } + } } /// Parse snowflake create table statement. diff --git a/src/keywords.rs b/src/keywords.rs index fc2a2927c..8c0ed588f 100644 --- a/src/keywords.rs +++ b/src/keywords.rs @@ -948,3 +948,13 @@ pub const RESERVED_FOR_COLUMN_ALIAS: &[Keyword] = &[ Keyword::INTO, Keyword::END, ]; + +/// Global list of reserved keywords that cannot be parsed as identifiers +/// without special handling like quoting. Parser should call `Dialect::is_reserved_for_identifier` +/// to allow for each dialect to customize the list. +pub const RESERVED_FOR_IDENTIFIER: &[Keyword] = &[ + Keyword::EXISTS, + Keyword::INTERVAL, + Keyword::STRUCT, + Keyword::TRIM, +]; diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 1bf173169..6767f358a 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -1025,6 +1025,183 @@ impl<'a> Parser<'a> { Ok(Statement::NOTIFY { channel, payload }) } + // Tries to parse an expression by matching the specified word to known keywords that have a special meaning in the dialect. + // Returns `None if no match is found. + fn parse_expr_prefix_by_reserved_word( + &mut self, + w: &Word, + ) -> Result, ParserError> { + match w.keyword { + Keyword::TRUE | Keyword::FALSE if self.dialect.supports_boolean_literals() => { + self.prev_token(); + Ok(Some(Expr::Value(self.parse_value()?))) + } + Keyword::NULL => { + self.prev_token(); + Ok(Some(Expr::Value(self.parse_value()?))) + } + Keyword::CURRENT_CATALOG + | Keyword::CURRENT_USER + | Keyword::SESSION_USER + | Keyword::USER + if dialect_of!(self is PostgreSqlDialect | GenericDialect) => + { + Ok(Some(Expr::Function(Function { + name: ObjectName(vec![w.to_ident()]), + parameters: FunctionArguments::None, + args: FunctionArguments::None, + null_treatment: None, + filter: None, + over: None, + within_group: vec![], + }))) + } + Keyword::CURRENT_TIMESTAMP + | Keyword::CURRENT_TIME + | Keyword::CURRENT_DATE + | Keyword::LOCALTIME + | Keyword::LOCALTIMESTAMP => { + Ok(Some(self.parse_time_functions(ObjectName(vec![w.to_ident()]))?)) + } + Keyword::CASE => Ok(Some(self.parse_case_expr()?)), + Keyword::CONVERT => Ok(Some(self.parse_convert_expr(false)?)), + Keyword::TRY_CONVERT if self.dialect.supports_try_convert() => Ok(Some(self.parse_convert_expr(true)?)), + Keyword::CAST => Ok(Some(self.parse_cast_expr(CastKind::Cast)?)), + Keyword::TRY_CAST => Ok(Some(self.parse_cast_expr(CastKind::TryCast)?)), + Keyword::SAFE_CAST => Ok(Some(self.parse_cast_expr(CastKind::SafeCast)?)), + Keyword::EXISTS + // Support parsing Databricks has a function named `exists`. + if !dialect_of!(self is DatabricksDialect) + || matches!( + self.peek_nth_token(1).token, + Token::Word(Word { + keyword: Keyword::SELECT | Keyword::WITH, + .. + }) + ) => + { + Ok(Some(self.parse_exists_expr(false)?)) + } + Keyword::EXTRACT => Ok(Some(self.parse_extract_expr()?)), + Keyword::CEIL => Ok(Some(self.parse_ceil_floor_expr(true)?)), + Keyword::FLOOR => Ok(Some(self.parse_ceil_floor_expr(false)?)), + Keyword::POSITION if self.peek_token().token == Token::LParen => { + Ok(Some(self.parse_position_expr(w.to_ident())?)) + } + Keyword::SUBSTRING => Ok(Some(self.parse_substring_expr()?)), + Keyword::OVERLAY => Ok(Some(self.parse_overlay_expr()?)), + Keyword::TRIM => Ok(Some(self.parse_trim_expr()?)), + Keyword::INTERVAL => Ok(Some(self.parse_interval()?)), + // Treat ARRAY[1,2,3] as an array [1,2,3], otherwise try as subquery or a function call + Keyword::ARRAY if self.peek_token() == Token::LBracket => { + self.expect_token(&Token::LBracket)?; + Ok(Some(self.parse_array_expr(true)?)) + } + Keyword::ARRAY + if self.peek_token() == Token::LParen + && !dialect_of!(self is ClickHouseDialect | DatabricksDialect) => + { + self.expect_token(&Token::LParen)?; + let query = self.parse_query()?; + self.expect_token(&Token::RParen)?; + Ok(Some(Expr::Function(Function { + name: ObjectName(vec![w.to_ident()]), + parameters: FunctionArguments::None, + args: FunctionArguments::Subquery(query), + filter: None, + null_treatment: None, + over: None, + within_group: vec![], + }))) + } + Keyword::NOT => Ok(Some(self.parse_not()?)), + Keyword::MATCH if dialect_of!(self is MySqlDialect | GenericDialect) => { + Ok(Some(self.parse_match_against()?)) + } + Keyword::STRUCT if dialect_of!(self is BigQueryDialect | GenericDialect) => { + self.prev_token(); + Ok(Some(self.parse_bigquery_struct_literal()?)) + } + Keyword::PRIOR if matches!(self.state, ParserState::ConnectBy) => { + let expr = self.parse_subexpr(self.dialect.prec_value(Precedence::PlusMinus))?; + Ok(Some(Expr::Prior(Box::new(expr)))) + } + Keyword::MAP if self.peek_token() == Token::LBrace && self.dialect.support_map_literal_syntax() => { + Ok(Some(self.parse_duckdb_map_literal()?)) + } + _ => Ok(None) + } + } + + // Tries to parse an expression by a word that is not known to have a special meaning in the dialect. + fn parse_expr_prefix_by_unreserved_word(&mut self, w: &Word) -> Result { + match self.peek_token().token { + Token::LParen | Token::Period => { + let mut id_parts: Vec = vec![w.to_ident()]; + let mut ends_with_wildcard = false; + while self.consume_token(&Token::Period) { + let next_token = self.next_token(); + match next_token.token { + Token::Word(w) => id_parts.push(w.to_ident()), + Token::Mul => { + // Postgres explicitly allows funcnm(tablenm.*) and the + // function array_agg traverses this control flow + if dialect_of!(self is PostgreSqlDialect) { + ends_with_wildcard = true; + break; + } else { + return self.expected("an identifier after '.'", next_token); + } + } + Token::SingleQuotedString(s) => id_parts.push(Ident::with_quote('\'', s)), + _ => { + return self.expected("an identifier or a '*' after '.'", next_token); + } + } + } + + if ends_with_wildcard { + Ok(Expr::QualifiedWildcard(ObjectName(id_parts))) + } else if self.consume_token(&Token::LParen) { + if dialect_of!(self is SnowflakeDialect | MsSqlDialect) + && self.consume_tokens(&[Token::Plus, Token::RParen]) + { + Ok(Expr::OuterJoin(Box::new( + match <[Ident; 1]>::try_from(id_parts) { + Ok([ident]) => Expr::Identifier(ident), + Err(parts) => Expr::CompoundIdentifier(parts), + }, + ))) + } else { + self.prev_token(); + self.parse_function(ObjectName(id_parts)) + } + } else { + Ok(Expr::CompoundIdentifier(id_parts)) + } + } + // string introducer https://dev.mysql.com/doc/refman/8.0/en/charset-introducer.html + Token::SingleQuotedString(_) + | Token::DoubleQuotedString(_) + | Token::HexStringLiteral(_) + if w.value.starts_with('_') => + { + Ok(Expr::IntroducedString { + introducer: w.value.clone(), + value: self.parse_introduced_string_value()?, + }) + } + Token::Arrow if self.dialect.supports_lambda_functions() => { + self.expect_token(&Token::Arrow)?; + Ok(Expr::Lambda(LambdaFunction { + params: OneOrManyWithParens::One(w.to_ident()), + body: Box::new(self.parse_expr()?), + })) + } + _ => Ok(Expr::Identifier(w.to_ident())), + } + } + /// Parse an expression prefix. pub fn parse_prefix(&mut self) -> Result { // allow the dialect to override prefix parsing @@ -1073,176 +1250,40 @@ impl<'a> Parser<'a> { let next_token = self.next_token(); let expr = match next_token.token { - Token::Word(w) => match w.keyword { - Keyword::TRUE | Keyword::FALSE if self.dialect.supports_boolean_literals() => { - self.prev_token(); - Ok(Expr::Value(self.parse_value()?)) - } - Keyword::NULL => { - self.prev_token(); - Ok(Expr::Value(self.parse_value()?)) - } - Keyword::CURRENT_CATALOG - | Keyword::CURRENT_USER - | Keyword::SESSION_USER - | Keyword::USER - if dialect_of!(self is PostgreSqlDialect | GenericDialect) => - { - Ok(Expr::Function(Function { - name: ObjectName(vec![w.to_ident()]), - parameters: FunctionArguments::None, - args: FunctionArguments::None, - null_treatment: None, - filter: None, - over: None, - within_group: vec![], - })) - } - Keyword::CURRENT_TIMESTAMP - | Keyword::CURRENT_TIME - | Keyword::CURRENT_DATE - | Keyword::LOCALTIME - | Keyword::LOCALTIMESTAMP => { - self.parse_time_functions(ObjectName(vec![w.to_ident()])) - } - Keyword::CASE => self.parse_case_expr(), - Keyword::CONVERT => self.parse_convert_expr(false), - Keyword::TRY_CONVERT if self.dialect.supports_try_convert() => self.parse_convert_expr(true), - Keyword::CAST => self.parse_cast_expr(CastKind::Cast), - Keyword::TRY_CAST => self.parse_cast_expr(CastKind::TryCast), - Keyword::SAFE_CAST => self.parse_cast_expr(CastKind::SafeCast), - Keyword::EXISTS - // Support parsing Databricks has a function named `exists`. - if !dialect_of!(self is DatabricksDialect) - || matches!( - self.peek_nth_token(1).token, - Token::Word(Word { - keyword: Keyword::SELECT | Keyword::WITH, - .. - }) - ) => - { - self.parse_exists_expr(false) - } - Keyword::EXTRACT => self.parse_extract_expr(), - Keyword::CEIL => self.parse_ceil_floor_expr(true), - Keyword::FLOOR => self.parse_ceil_floor_expr(false), - Keyword::POSITION if self.peek_token().token == Token::LParen => { - self.parse_position_expr(w.to_ident()) - } - Keyword::SUBSTRING => self.parse_substring_expr(), - Keyword::OVERLAY => self.parse_overlay_expr(), - Keyword::TRIM => self.parse_trim_expr(), - Keyword::INTERVAL => self.parse_interval(), - // Treat ARRAY[1,2,3] as an array [1,2,3], otherwise try as subquery or a function call - Keyword::ARRAY if self.peek_token() == Token::LBracket => { - self.expect_token(&Token::LBracket)?; - self.parse_array_expr(true) - } - Keyword::ARRAY - if self.peek_token() == Token::LParen - && !dialect_of!(self is ClickHouseDialect | DatabricksDialect) => - { - self.expect_token(&Token::LParen)?; - let query = self.parse_query()?; - self.expect_token(&Token::RParen)?; - Ok(Expr::Function(Function { - name: ObjectName(vec![w.to_ident()]), - parameters: FunctionArguments::None, - args: FunctionArguments::Subquery(query), - filter: None, - null_treatment: None, - over: None, - within_group: vec![], - })) - } - Keyword::NOT => self.parse_not(), - Keyword::MATCH if dialect_of!(self is MySqlDialect | GenericDialect) => { - self.parse_match_against() - } - Keyword::STRUCT if dialect_of!(self is BigQueryDialect | GenericDialect) => { - self.prev_token(); - self.parse_bigquery_struct_literal() - } - Keyword::PRIOR if matches!(self.state, ParserState::ConnectBy) => { - let expr = self.parse_subexpr(self.dialect.prec_value(Precedence::PlusMinus))?; - Ok(Expr::Prior(Box::new(expr))) - } - Keyword::MAP if self.peek_token() == Token::LBrace && self.dialect.support_map_literal_syntax() => { - self.parse_duckdb_map_literal() - } - // Here `w` is a word, check if it's a part of a multipart - // identifier, a function call, or a simple identifier: - _ => match self.peek_token().token { - Token::LParen | Token::Period => { - let mut id_parts: Vec = vec![w.to_ident()]; - let mut ends_with_wildcard = false; - while self.consume_token(&Token::Period) { - let next_token = self.next_token(); - match next_token.token { - Token::Word(w) => id_parts.push(w.to_ident()), - Token::Mul => { - // Postgres explicitly allows funcnm(tablenm.*) and the - // function array_agg traverses this control flow - if dialect_of!(self is PostgreSqlDialect) { - ends_with_wildcard = true; - break; - } else { - return self - .expected("an identifier after '.'", next_token); - } - } - Token::SingleQuotedString(s) => { - id_parts.push(Ident::with_quote('\'', s)) - } - _ => { - return self - .expected("an identifier or a '*' after '.'", next_token); - } - } - } - - if ends_with_wildcard { - Ok(Expr::QualifiedWildcard(ObjectName(id_parts))) - } else if self.consume_token(&Token::LParen) { - if dialect_of!(self is SnowflakeDialect | MsSqlDialect) - && self.consume_tokens(&[Token::Plus, Token::RParen]) - { - Ok(Expr::OuterJoin(Box::new( - match <[Ident; 1]>::try_from(id_parts) { - Ok([ident]) => Expr::Identifier(ident), - Err(parts) => Expr::CompoundIdentifier(parts), - }, - ))) - } else { - self.prev_token(); - self.parse_function(ObjectName(id_parts)) + Token::Word(w) => { + // The word we consumed may fall into one of two cases: it has a special meaning, or not. + // For example, in Snowflake, the word `interval` may have two meanings depending on the context: + // `SELECT CURRENT_DATE() + INTERVAL '1 DAY', MAX(interval) FROM tbl;` + // ^^^^^^^^^^^^^^^^ ^^^^^^^^ + // interval expression identifier + // + // We first try to parse the word and following tokens as a special expression, and if that fails, + // we rollback and try to parse it as an identifier. + match self.try_parse(|parser| parser.parse_expr_prefix_by_reserved_word(&w)) { + // This word indicated an expression prefix and parsing was successful + Ok(Some(expr)) => Ok(expr), + + // No expression prefix associated with this word + Ok(None) => Ok(self.parse_expr_prefix_by_unreserved_word(&w)?), + + // If parsing of the word as a special expression failed, we are facing two options: + // 1. The statement is malformed, e.g. `SELECT INTERVAL '1 DAI` (`DAI` instead of `DAY`) + // 2. The word is used as an identifier, e.g. `SELECT MAX(interval) FROM tbl` + // We first try to parse the word as an identifier and if that fails + // we rollback and return the parsing error we got from trying to parse a + // special expression (to maintain backwards compatibility of parsing errors). + Err(e) => { + if !self.dialect.is_reserved_for_identifier(w.keyword) { + if let Ok(Some(expr)) = self.maybe_parse(|parser| { + parser.parse_expr_prefix_by_unreserved_word(&w) + }) { + return Ok(expr); } - } else { - Ok(Expr::CompoundIdentifier(id_parts)) } + return Err(e); } - // string introducer https://dev.mysql.com/doc/refman/8.0/en/charset-introducer.html - Token::SingleQuotedString(_) - | Token::DoubleQuotedString(_) - | Token::HexStringLiteral(_) - if w.value.starts_with('_') => - { - Ok(Expr::IntroducedString { - introducer: w.value, - value: self.parse_introduced_string_value()?, - }) - } - Token::Arrow if self.dialect.supports_lambda_functions() => { - self.expect_token(&Token::Arrow)?; - return Ok(Expr::Lambda(LambdaFunction { - params: OneOrManyWithParens::One(w.to_ident()), - body: Box::new(self.parse_expr()?), - })); - } - _ => Ok(Expr::Identifier(w.to_ident())), - }, - }, // End of Token::Word + } + } // End of Token::Word // array `[1, 2, 3]` Token::LBracket => self.parse_array_expr(false), tok @ Token::Minus | tok @ Token::Plus => { @@ -3677,18 +3718,30 @@ impl<'a> Parser<'a> { } /// Run a parser method `f`, reverting back to the current position if unsuccessful. - pub fn maybe_parse(&mut self, mut f: F) -> Result, ParserError> + /// Returns `None` if `f` returns an error + pub fn maybe_parse(&mut self, f: F) -> Result, ParserError> where F: FnMut(&mut Parser) -> Result, { - let index = self.index; - match f(self) { + match self.try_parse(f) { Ok(t) => Ok(Some(t)), - // Unwind stack if limit exceeded Err(ParserError::RecursionLimitExceeded) => Err(ParserError::RecursionLimitExceeded), - Err(_) => { + _ => Ok(None), + } + } + + /// Run a parser method `f`, reverting back to the current position if unsuccessful. + pub fn try_parse(&mut self, mut f: F) -> Result + where + F: FnMut(&mut Parser) -> Result, + { + let index = self.index; + match f(self) { + Ok(t) => Ok(t), + Err(e) => { + // Unwind stack if limit exceeded self.index = index; - Ok(None) + Err(e) } } } diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index b41063859..c03370892 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -34,7 +34,7 @@ use sqlparser::dialect::{ GenericDialect, HiveDialect, MsSqlDialect, MySqlDialect, PostgreSqlDialect, RedshiftSqlDialect, SQLiteDialect, SnowflakeDialect, }; -use sqlparser::keywords::ALL_KEYWORDS; +use sqlparser::keywords::{Keyword, ALL_KEYWORDS}; use sqlparser::parser::{Parser, ParserError, ParserOptions}; use sqlparser::tokenizer::Tokenizer; use test_utils::{ @@ -5113,7 +5113,6 @@ fn parse_interval_dont_require_unit() { #[test] fn parse_interval_require_unit() { let dialects = all_dialects_where(|d| d.require_interval_qualifier()); - let sql = "SELECT INTERVAL '1 DAY'"; let err = dialects.parse_sql_statements(sql).unwrap_err(); assert_eq!( @@ -12198,3 +12197,21 @@ fn parse_create_table_select() { ); } } + +#[test] +fn test_reserved_keywords_for_identifiers() { + let dialects = all_dialects_where(|d| d.is_reserved_for_identifier(Keyword::INTERVAL)); + // Dialects that reserve the word INTERVAL will not allow it as an unquoted identifier + let sql = "SELECT MAX(interval) FROM tbl"; + assert_eq!( + dialects.parse_sql_statements(sql), + Err(ParserError::ParserError( + "Expected: an expression, found: )".to_string() + )) + ); + + // Dialects that do not reserve the word INTERVAL will allow it + let dialects = all_dialects_where(|d| !d.is_reserved_for_identifier(Keyword::INTERVAL)); + let sql = "SELECT MAX(interval) FROM tbl"; + dialects.parse_sql_statements(sql).unwrap(); +} diff --git a/tests/sqlparser_postgres.rs b/tests/sqlparser_postgres.rs index 098a3464c..d27569e03 100644 --- a/tests/sqlparser_postgres.rs +++ b/tests/sqlparser_postgres.rs @@ -1352,10 +1352,7 @@ fn parse_set() { local: false, hivevar: false, variables: OneOrManyWithParens::One(ObjectName(vec![Ident::new("a")])), - value: vec![Expr::Identifier(Ident { - value: "DEFAULT".into(), - quote_style: None - })], + value: vec![Expr::Identifier(Ident::new("DEFAULT"))], } ); @@ -4229,10 +4226,7 @@ fn test_simple_postgres_insert_with_alias() { body: Box::new(SetExpr::Values(Values { explicit_row: false, rows: vec![vec![ - Expr::Identifier(Ident { - value: "DEFAULT".to_string(), - quote_style: None - }), + Expr::Identifier(Ident::new("DEFAULT")), Expr::Value(Value::Number("123".to_string(), false)) ]] })), @@ -4295,10 +4289,7 @@ fn test_simple_postgres_insert_with_alias() { body: Box::new(SetExpr::Values(Values { explicit_row: false, rows: vec![vec![ - Expr::Identifier(Ident { - value: "DEFAULT".to_string(), - quote_style: None - }), + Expr::Identifier(Ident::new("DEFAULT")), Expr::Value(Value::Number( bigdecimal::BigDecimal::new(123.into(), 0), false @@ -4363,10 +4354,7 @@ fn test_simple_insert_with_quoted_alias() { body: Box::new(SetExpr::Values(Values { explicit_row: false, rows: vec![vec![ - Expr::Identifier(Ident { - value: "DEFAULT".to_string(), - quote_style: None - }), + Expr::Identifier(Ident::new("DEFAULT")), Expr::Value(Value::SingleQuotedString("0123".to_string())) ]] })),