Skip to content

Commit 5bdf2e6

Browse files
authored
Add support for release and rollback to savepoint syntax (#1045)
1 parent c905ee0 commit 5bdf2e6

File tree

4 files changed

+100
-19
lines changed

4 files changed

+100
-19
lines changed

src/ast/mod.rs

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1839,8 +1839,11 @@ pub enum Statement {
18391839
},
18401840
/// `COMMIT [ TRANSACTION | WORK ] [ AND [ NO ] CHAIN ]`
18411841
Commit { chain: bool },
1842-
/// `ROLLBACK [ TRANSACTION | WORK ] [ AND [ NO ] CHAIN ]`
1843-
Rollback { chain: bool },
1842+
/// `ROLLBACK [ TRANSACTION | WORK ] [ AND [ NO ] CHAIN ] [ TO [ SAVEPOINT ] savepoint_name ]`
1843+
Rollback {
1844+
chain: bool,
1845+
savepoint: Option<Ident>,
1846+
},
18441847
/// CREATE SCHEMA
18451848
CreateSchema {
18461849
/// `<schema name> | AUTHORIZATION <schema authorization identifier> | <schema name> AUTHORIZATION <schema authorization identifier>`
@@ -1977,6 +1980,8 @@ pub enum Statement {
19771980
},
19781981
/// SAVEPOINT -- define a new savepoint within the current transaction
19791982
Savepoint { name: Ident },
1983+
/// RELEASE \[ SAVEPOINT \] savepoint_name
1984+
ReleaseSavepoint { name: Ident },
19801985
// MERGE INTO statement, based on Snowflake. See <https://docs.snowflake.com/en/sql-reference/sql/merge.html>
19811986
Merge {
19821987
// optional INTO keyword
@@ -3127,8 +3132,18 @@ impl fmt::Display for Statement {
31273132
Statement::Commit { chain } => {
31283133
write!(f, "COMMIT{}", if *chain { " AND CHAIN" } else { "" },)
31293134
}
3130-
Statement::Rollback { chain } => {
3131-
write!(f, "ROLLBACK{}", if *chain { " AND CHAIN" } else { "" },)
3135+
Statement::Rollback { chain, savepoint } => {
3136+
write!(f, "ROLLBACK")?;
3137+
3138+
if *chain {
3139+
write!(f, " AND CHAIN")?;
3140+
}
3141+
3142+
if let Some(savepoint) = savepoint {
3143+
write!(f, " TO SAVEPOINT {savepoint}")?;
3144+
}
3145+
3146+
Ok(())
31323147
}
31333148
Statement::CreateSchema {
31343149
schema_name,
@@ -3225,6 +3240,9 @@ impl fmt::Display for Statement {
32253240
write!(f, "SAVEPOINT ")?;
32263241
write!(f, "{name}")
32273242
}
3243+
Statement::ReleaseSavepoint { name } => {
3244+
write!(f, "RELEASE SAVEPOINT {name}")
3245+
}
32283246
Statement::Merge {
32293247
into,
32303248
table,

src/parser/mod.rs

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -502,6 +502,7 @@ impl<'a> Parser<'a> {
502502
// by at least PostgreSQL and MySQL.
503503
Keyword::BEGIN => Ok(self.parse_begin()?),
504504
Keyword::SAVEPOINT => Ok(self.parse_savepoint()?),
505+
Keyword::RELEASE => Ok(self.parse_release()?),
505506
Keyword::COMMIT => Ok(self.parse_commit()?),
506507
Keyword::ROLLBACK => Ok(self.parse_rollback()?),
507508
Keyword::ASSERT => Ok(self.parse_assert()?),
@@ -747,6 +748,13 @@ impl<'a> Parser<'a> {
747748
Ok(Statement::Savepoint { name })
748749
}
749750

751+
pub fn parse_release(&mut self) -> Result<Statement, ParserError> {
752+
let _ = self.parse_keyword(Keyword::SAVEPOINT);
753+
let name = self.parse_identifier()?;
754+
755+
Ok(Statement::ReleaseSavepoint { name })
756+
}
757+
750758
/// Parse an expression prefix
751759
pub fn parse_prefix(&mut self) -> Result<Expr, ParserError> {
752760
// allow the dialect to override prefix parsing
@@ -7843,9 +7851,10 @@ impl<'a> Parser<'a> {
78437851
}
78447852

78457853
pub fn parse_rollback(&mut self) -> Result<Statement, ParserError> {
7846-
Ok(Statement::Rollback {
7847-
chain: self.parse_commit_rollback_chain()?,
7848-
})
7854+
let chain = self.parse_commit_rollback_chain()?;
7855+
let savepoint = self.parse_rollback_savepoint()?;
7856+
7857+
Ok(Statement::Rollback { chain, savepoint })
78497858
}
78507859

78517860
pub fn parse_commit_rollback_chain(&mut self) -> Result<bool, ParserError> {
@@ -7859,6 +7868,17 @@ impl<'a> Parser<'a> {
78597868
}
78607869
}
78617870

7871+
pub fn parse_rollback_savepoint(&mut self) -> Result<Option<Ident>, ParserError> {
7872+
if self.parse_keyword(Keyword::TO) {
7873+
let _ = self.parse_keyword(Keyword::SAVEPOINT);
7874+
let savepoint = self.parse_identifier()?;
7875+
7876+
Ok(Some(savepoint))
7877+
} else {
7878+
Ok(None)
7879+
}
7880+
}
7881+
78627882
pub fn parse_deallocate(&mut self) -> Result<Statement, ParserError> {
78637883
let prepare = self.parse_keyword(Keyword::PREPARE);
78647884
let name = self.parse_identifier()?;

tests/sqlparser_common.rs

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6234,12 +6234,38 @@ fn parse_commit() {
62346234
#[test]
62356235
fn parse_rollback() {
62366236
match verified_stmt("ROLLBACK") {
6237-
Statement::Rollback { chain: false } => (),
6237+
Statement::Rollback {
6238+
chain: false,
6239+
savepoint: None,
6240+
} => (),
62386241
_ => unreachable!(),
62396242
}
62406243

62416244
match verified_stmt("ROLLBACK AND CHAIN") {
6242-
Statement::Rollback { chain: true } => (),
6245+
Statement::Rollback {
6246+
chain: true,
6247+
savepoint: None,
6248+
} => (),
6249+
_ => unreachable!(),
6250+
}
6251+
6252+
match verified_stmt("ROLLBACK TO SAVEPOINT test1") {
6253+
Statement::Rollback {
6254+
chain: false,
6255+
savepoint,
6256+
} => {
6257+
assert_eq!(savepoint, Some(Ident::new("test1")));
6258+
}
6259+
_ => unreachable!(),
6260+
}
6261+
6262+
match verified_stmt("ROLLBACK AND CHAIN TO SAVEPOINT test1") {
6263+
Statement::Rollback {
6264+
chain: true,
6265+
savepoint,
6266+
} => {
6267+
assert_eq!(savepoint, Some(Ident::new("test1")));
6268+
}
62436269
_ => unreachable!(),
62446270
}
62456271

@@ -6250,6 +6276,11 @@ fn parse_rollback() {
62506276
one_statement_parses_to("ROLLBACK TRANSACTION AND CHAIN", "ROLLBACK AND CHAIN");
62516277
one_statement_parses_to("ROLLBACK WORK", "ROLLBACK");
62526278
one_statement_parses_to("ROLLBACK TRANSACTION", "ROLLBACK");
6279+
one_statement_parses_to("ROLLBACK TO test1", "ROLLBACK TO SAVEPOINT test1");
6280+
one_statement_parses_to(
6281+
"ROLLBACK AND CHAIN TO test1",
6282+
"ROLLBACK AND CHAIN TO SAVEPOINT test1",
6283+
);
62536284
}
62546285

62556286
#[test]
@@ -7864,3 +7895,25 @@ fn parse_binary_operators_without_whitespace() {
78647895
"SELECT tbl1.field % tbl2.field FROM tbl1 JOIN tbl2 ON tbl1.id = tbl2.entity_id",
78657896
);
78667897
}
7898+
7899+
#[test]
7900+
fn test_savepoint() {
7901+
match verified_stmt("SAVEPOINT test1") {
7902+
Statement::Savepoint { name } => {
7903+
assert_eq!(Ident::new("test1"), name);
7904+
}
7905+
_ => unreachable!(),
7906+
}
7907+
}
7908+
7909+
#[test]
7910+
fn test_release_savepoint() {
7911+
match verified_stmt("RELEASE SAVEPOINT test1") {
7912+
Statement::ReleaseSavepoint { name } => {
7913+
assert_eq!(Ident::new("test1"), name);
7914+
}
7915+
_ => unreachable!(),
7916+
}
7917+
7918+
one_statement_parses_to("RELEASE test1", "RELEASE SAVEPOINT test1");
7919+
}

tests/sqlparser_postgres.rs

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2093,16 +2093,6 @@ fn test_transaction_statement() {
20932093
);
20942094
}
20952095

2096-
#[test]
2097-
fn test_savepoint() {
2098-
match pg().verified_stmt("SAVEPOINT test1") {
2099-
Statement::Savepoint { name } => {
2100-
assert_eq!(Ident::new("test1"), name);
2101-
}
2102-
_ => unreachable!(),
2103-
}
2104-
}
2105-
21062096
#[test]
21072097
fn test_json() {
21082098
let sql = "SELECT params ->> 'name' FROM events";

0 commit comments

Comments
 (0)