From 97e261c8b4844d1c63228d68a8321707a0b66baa Mon Sep 17 00:00:00 2001 From: Yoav Cohen Date: Tue, 7 Jan 2025 15:30:02 +0100 Subject: [PATCH] Add support for MS-SQL BEGIN/END TRY/CATCH --- src/ast/helpers/stmt_create_table.rs | 6 +++- src/ast/mod.rs | 41 ++++++++++++++++++++++---- src/dialect/mod.rs | 7 ++++- src/dialect/mssql.rs | 7 +++++ src/keywords.rs | 2 ++ src/parser/mod.rs | 17 +++++++++++ tests/sqlparser_common.rs | 43 ++++++++++++++++++++++------ tests/sqlparser_custom_dialect.rs | 6 +++- tests/sqlparser_sqlite.rs | 8 +----- 9 files changed, 112 insertions(+), 25 deletions(-) diff --git a/src/ast/helpers/stmt_create_table.rs b/src/ast/helpers/stmt_create_table.rs index 364969c40..a3be57986 100644 --- a/src/ast/helpers/stmt_create_table.rs +++ b/src/ast/helpers/stmt_create_table.rs @@ -548,7 +548,11 @@ mod tests { #[test] pub fn test_from_invalid_statement() { - let stmt = Statement::Commit { chain: false }; + let stmt = Statement::Commit { + chain: false, + end: false, + modifier: None, + }; assert_eq!( CreateTableBuilder::try_from(stmt).unwrap_err(), diff --git a/src/ast/mod.rs b/src/ast/mod.rs index f46438b3e..9a6035c43 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -2958,7 +2958,6 @@ pub enum Statement { modes: Vec, begin: bool, transaction: Option, - /// Only for SQLite modifier: Option, }, /// ```sql @@ -2985,7 +2984,17 @@ pub enum Statement { /// ```sql /// COMMIT [ TRANSACTION | WORK ] [ AND [ NO ] CHAIN ] /// ``` - Commit { chain: bool }, + /// If `end` is false + /// + /// ```sql + /// END [ TRY | CATCH ] + /// ``` + /// If `end` is true + Commit { + chain: bool, + end: bool, + modifier: Option, + }, /// ```sql /// ROLLBACK [ TRANSACTION | WORK ] [ AND [ NO ] CHAIN ] [ TO [ SAVEPOINT ] savepoint_name ] /// ``` @@ -4614,8 +4623,23 @@ impl fmt::Display for Statement { } Ok(()) } - Statement::Commit { chain } => { - write!(f, "COMMIT{}", if *chain { " AND CHAIN" } else { "" },) + Statement::Commit { + chain, + end: end_syntax, + modifier, + } => { + if *end_syntax { + write!(f, "END")?; + if let Some(modifier) = *modifier { + write!(f, " {}", modifier)?; + } + if *chain { + write!(f, " AND CHAIN")?; + } + } else { + write!(f, "COMMIT{}", if *chain { " AND CHAIN" } else { "" })?; + } + Ok(()) } Statement::Rollback { chain, savepoint } => { write!(f, "ROLLBACK")?; @@ -6388,9 +6412,10 @@ impl fmt::Display for TransactionIsolationLevel { } } -/// SQLite specific syntax +/// Modifier for the transaction in the `BEGIN` syntax /// -/// +/// SQLite: +/// MS-SQL: #[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] @@ -6398,6 +6423,8 @@ pub enum TransactionModifier { Deferred, Immediate, Exclusive, + Try, + Catch, } impl fmt::Display for TransactionModifier { @@ -6407,6 +6434,8 @@ impl fmt::Display for TransactionModifier { Deferred => "DEFERRED", Immediate => "IMMEDIATE", Exclusive => "EXCLUSIVE", + Try => "TRY", + Catch => "CATCH", }) } } diff --git a/src/dialect/mod.rs b/src/dialect/mod.rs index 7b14f2db5..4c3f0b4b2 100644 --- a/src/dialect/mod.rs +++ b/src/dialect/mod.rs @@ -260,11 +260,16 @@ pub trait Dialect: Debug + Any { false } - /// Returns true if the dialect supports `BEGIN {DEFERRED | IMMEDIATE | EXCLUSIVE} [TRANSACTION]` statements + /// Returns true if the dialect supports `BEGIN {DEFERRED | IMMEDIATE | EXCLUSIVE | TRY | CATCH} [TRANSACTION]` statements fn supports_start_transaction_modifier(&self) -> bool { false } + /// Returns true if the dialect supports `END {TRY | CATCH}` statements + fn supports_end_transaction_modifier(&self) -> bool { + false + } + /// Returns true if the dialect supports named arguments of the form `FUN(a = '1', b = '2')`. fn supports_named_fn_args_with_eq_operator(&self) -> bool { false diff --git a/src/dialect/mssql.rs b/src/dialect/mssql.rs index 2d0ef027f..fa77bdc1e 100644 --- a/src/dialect/mssql.rs +++ b/src/dialect/mssql.rs @@ -78,4 +78,11 @@ impl Dialect for MsSqlDialect { fn supports_named_fn_args_with_rarrow_operator(&self) -> bool { false } + + fn supports_start_transaction_modifier(&self) -> bool { + true + } + fn supports_end_transaction_modifier(&self) -> bool { + true + } } diff --git a/src/keywords.rs b/src/keywords.rs index b7ff39e04..0a4d84a40 100644 --- a/src/keywords.rs +++ b/src/keywords.rs @@ -151,6 +151,7 @@ define_keywords!( CASE, CAST, CATALOG, + CATCH, CEIL, CEILING, CENTURY, @@ -808,6 +809,7 @@ define_keywords!( TRIM_ARRAY, TRUE, TRUNCATE, + TRY, TRY_CAST, TRY_CONVERT, TUPLE, diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 85ae66399..1cf422fdf 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -12786,6 +12786,10 @@ impl<'a> Parser<'a> { Some(TransactionModifier::Immediate) } else if self.parse_keyword(Keyword::EXCLUSIVE) { Some(TransactionModifier::Exclusive) + } else if self.parse_keyword(Keyword::TRY) { + Some(TransactionModifier::Try) + } else if self.parse_keyword(Keyword::CATCH) { + Some(TransactionModifier::Catch) } else { None }; @@ -12803,8 +12807,19 @@ impl<'a> Parser<'a> { } pub fn parse_end(&mut self) -> Result { + let modifier = if !self.dialect.supports_end_transaction_modifier() { + None + } else if self.parse_keyword(Keyword::TRY) { + Some(TransactionModifier::Try) + } else if self.parse_keyword(Keyword::CATCH) { + Some(TransactionModifier::Catch) + } else { + None + }; Ok(Statement::Commit { chain: self.parse_commit_rollback_chain()?, + end: true, + modifier, }) } @@ -12847,6 +12862,8 @@ impl<'a> Parser<'a> { pub fn parse_commit(&mut self) -> Result { Ok(Statement::Commit { chain: self.parse_commit_rollback_chain()?, + end: false, + modifier: None, }) } diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index 7c8fd05a8..0fe11f139 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -7854,6 +7854,27 @@ fn parse_start_transaction() { ParserError::ParserError("Expected: transaction mode, found: EOF".to_string()), res.unwrap_err() ); + + // MS-SQL syntax + let dialects = all_dialects_where(|d| d.supports_start_transaction_modifier()); + dialects.verified_stmt("BEGIN TRY"); + dialects.verified_stmt("BEGIN CATCH"); + + let dialects = all_dialects_where(|d| { + d.supports_start_transaction_modifier() && d.supports_end_transaction_modifier() + }); + dialects + .parse_sql_statements( + r#" + BEGIN TRY; + SELECT 1/0; + END TRY; + BEGIN CATCH; + EXECUTE foo; + END CATCH; + "#, + ) + .unwrap(); } #[test] @@ -8069,12 +8090,12 @@ fn parse_set_time_zone_alias() { #[test] fn parse_commit() { match verified_stmt("COMMIT") { - Statement::Commit { chain: false } => (), + Statement::Commit { chain: false, .. } => (), _ => unreachable!(), } match verified_stmt("COMMIT AND CHAIN") { - Statement::Commit { chain: true } => (), + Statement::Commit { chain: true, .. } => (), _ => unreachable!(), } @@ -8089,13 +8110,17 @@ fn parse_commit() { #[test] fn parse_end() { - one_statement_parses_to("END AND NO CHAIN", "COMMIT"); - one_statement_parses_to("END WORK AND NO CHAIN", "COMMIT"); - one_statement_parses_to("END TRANSACTION AND NO CHAIN", "COMMIT"); - one_statement_parses_to("END WORK AND CHAIN", "COMMIT AND CHAIN"); - one_statement_parses_to("END TRANSACTION AND CHAIN", "COMMIT AND CHAIN"); - one_statement_parses_to("END WORK", "COMMIT"); - one_statement_parses_to("END TRANSACTION", "COMMIT"); + one_statement_parses_to("END AND NO CHAIN", "END"); + one_statement_parses_to("END WORK AND NO CHAIN", "END"); + one_statement_parses_to("END TRANSACTION AND NO CHAIN", "END"); + one_statement_parses_to("END WORK AND CHAIN", "END AND CHAIN"); + one_statement_parses_to("END TRANSACTION AND CHAIN", "END AND CHAIN"); + one_statement_parses_to("END WORK", "END"); + one_statement_parses_to("END TRANSACTION", "END"); + // MS-SQL syntax + let dialects = all_dialects_where(|d| d.supports_end_transaction_modifier()); + dialects.verified_stmt("END TRY"); + dialects.verified_stmt("END CATCH"); } #[test] diff --git a/tests/sqlparser_custom_dialect.rs b/tests/sqlparser_custom_dialect.rs index e9ca82aba..61874fc27 100644 --- a/tests/sqlparser_custom_dialect.rs +++ b/tests/sqlparser_custom_dialect.rs @@ -115,7 +115,11 @@ fn custom_statement_parser() -> Result<(), ParserError> { for _ in 0..3 { let _ = parser.next_token(); } - Some(Ok(Statement::Commit { chain: false })) + Some(Ok(Statement::Commit { + chain: false, + end: false, + modifier: None, + })) } else { None } diff --git a/tests/sqlparser_sqlite.rs b/tests/sqlparser_sqlite.rs index 0adf7f755..edd1365f4 100644 --- a/tests/sqlparser_sqlite.rs +++ b/tests/sqlparser_sqlite.rs @@ -523,13 +523,7 @@ fn parse_start_transaction_with_modifier() { sqlite_and_generic().verified_stmt("BEGIN IMMEDIATE"); sqlite_and_generic().verified_stmt("BEGIN EXCLUSIVE"); - let unsupported_dialects = TestedDialects::new( - all_dialects() - .dialects - .into_iter() - .filter(|x| !(x.is::() || x.is::())) - .collect(), - ); + let unsupported_dialects = all_dialects_except(|d| d.supports_start_transaction_modifier()); let res = unsupported_dialects.parse_sql_statements("BEGIN DEFERRED"); assert_eq!( ParserError::ParserError("Expected: end of statement, found: DEFERRED".to_string()),