From 1a45316253752bc7c6961ecd9910fdab3f891f46 Mon Sep 17 00:00:00 2001 From: Elia Perantoni Date: Thu, 5 Jun 2025 10:32:55 +0200 Subject: [PATCH 1/3] fix: case expression spans to include leading and trailing keywords --- src/ast/mod.rs | 4 ++++ src/ast/spans.rs | 22 ++++++++++++++-------- src/parser/mod.rs | 5 ++++- tests/sqlparser_common.rs | 16 ++++++++++++++++ tests/sqlparser_databricks.rs | 3 +++ 5 files changed, 41 insertions(+), 9 deletions(-) diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 711e580df..f1fc4760a 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -967,6 +967,8 @@ pub enum Expr { /// not `< 0` nor `1, 2, 3` as allowed in a `` per /// Case { + case_token: AttachedToken, + end_token: AttachedToken, operand: Option>, conditions: Vec, else_result: Option>, @@ -1675,6 +1677,8 @@ impl fmt::Display for Expr { } Expr::Function(fun) => fun.fmt(f), Expr::Case { + case_token: _, + end_token: _, operand, conditions, else_result, diff --git a/src/ast/spans.rs b/src/ast/spans.rs index dd918c34b..c53ff59d5 100644 --- a/src/ast/spans.rs +++ b/src/ast/spans.rs @@ -1566,18 +1566,24 @@ impl Spanned for Expr { ), Expr::Prefixed { value, .. } => value.span(), Expr::Case { + case_token, + end_token, operand, conditions, else_result, } => union_spans( - operand - .as_ref() - .map(|i| i.span()) - .into_iter() - .chain(conditions.iter().flat_map(|case_when| { - [case_when.condition.span(), case_when.result.span()] - })) - .chain(else_result.as_ref().map(|i| i.span())), + iter::once(case_token.0.span) + .chain( + operand + .as_ref() + .map(|i| i.span()) + .into_iter() + .chain(conditions.iter().flat_map(|case_when| { + [case_when.condition.span(), case_when.result.span()] + })) + .chain(else_result.as_ref().map(|i| i.span())), + ) + .chain(iter::once(end_token.0.span)), ), Expr::Exists { subquery, .. } => subquery.span(), Expr::Subquery(query) => query.span(), diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 3e721072b..18324f3b3 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -2274,6 +2274,7 @@ impl<'a> Parser<'a> { } pub fn parse_case_expr(&mut self) -> Result { + let case_token = AttachedToken(self.get_previous_token().clone()); let mut operand = None; if !self.parse_keyword(Keyword::WHEN) { operand = Some(Box::new(self.parse_expr()?)); @@ -2294,8 +2295,10 @@ impl<'a> Parser<'a> { } else { None }; - self.expect_keyword_is(Keyword::END)?; + let end_token = AttachedToken(self.expect_keyword(Keyword::END)?); Ok(Expr::Case { + case_token, + end_token, operand, conditions, else_result, diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index a1a8fc3b3..a7f6af294 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -6861,6 +6861,8 @@ fn parse_searched_case_expr() { let select = verified_only_select(sql); assert_eq!( &Case { + case_token: AttachedToken::empty(), + end_token: AttachedToken::empty(), operand: None, conditions: vec![ CaseWhen { @@ -6900,6 +6902,8 @@ fn parse_simple_case_expr() { use self::Expr::{Case, Identifier}; assert_eq!( &Case { + case_token: AttachedToken::empty(), + end_token: AttachedToken::empty(), operand: Some(Box::new(Identifier(Ident::new("foo")))), conditions: vec![CaseWhen { condition: Expr::value(number("1")), @@ -14464,6 +14468,16 @@ fn test_case_statement_span() { ); } +#[test] +fn test_case_expr_span() { + let sql = "CASE 1 WHEN 2 THEN 3 ELSE 4 END"; + let mut parser = Parser::new(&GenericDialect {}).try_with_sql(sql).unwrap(); + assert_eq!( + parser.parse_expr().unwrap().span(), + Span::new(Location::new(1, 1), Location::new(1, sql.len() as u64 + 1)) + ); +} + #[test] fn parse_if_statement() { let dialects = all_dialects_except(|d| d.is::()); @@ -14642,6 +14656,8 @@ fn test_lambdas() { Expr::Lambda(LambdaFunction { params: OneOrManyWithParens::Many(vec![Ident::new("p1"), Ident::new("p2")]), body: Box::new(Expr::Case { + case_token: AttachedToken::empty(), + end_token: AttachedToken::empty(), operand: None, conditions: vec![ CaseWhen { diff --git a/tests/sqlparser_databricks.rs b/tests/sqlparser_databricks.rs index 88aae499a..99b7eecde 100644 --- a/tests/sqlparser_databricks.rs +++ b/tests/sqlparser_databricks.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use sqlparser::ast::helpers::attached_token::AttachedToken; use sqlparser::ast::*; use sqlparser::dialect::{DatabricksDialect, GenericDialect}; use sqlparser::parser::ParserError; @@ -108,6 +109,8 @@ fn test_databricks_lambdas() { Expr::Lambda(LambdaFunction { params: OneOrManyWithParens::Many(vec![Ident::new("p1"), Ident::new("p2")]), body: Box::new(Expr::Case { + case_token: AttachedToken::empty(), + end_token: AttachedToken::empty(), operand: None, conditions: vec![ CaseWhen { From 6cf531d4951545708a82b208bed8ea596f91a186 Mon Sep 17 00:00:00 2001 From: Elia Perantoni Date: Thu, 5 Jun 2025 11:24:52 +0200 Subject: [PATCH 2/3] fix: incorrect usage of previous token --- src/parser/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 18324f3b3..6784abae6 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -2274,7 +2274,7 @@ impl<'a> Parser<'a> { } pub fn parse_case_expr(&mut self) -> Result { - let case_token = AttachedToken(self.get_previous_token().clone()); + let case_token = AttachedToken(self.get_current_token().clone()); let mut operand = None; if !self.parse_keyword(Keyword::WHEN) { operand = Some(Box::new(self.parse_expr()?)); From d1d7afec0b64c3a046b1995edad2f814ee6f724c Mon Sep 17 00:00:00 2001 From: Elia Perantoni Date: Fri, 6 Jun 2025 14:46:31 +0200 Subject: [PATCH 3/3] chore: move test --- src/ast/spans.rs | 12 ++++++++++++ tests/sqlparser_common.rs | 10 ---------- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/src/ast/spans.rs b/src/ast/spans.rs index c53ff59d5..db928ecef 100644 --- a/src/ast/spans.rs +++ b/src/ast/spans.rs @@ -2469,4 +2469,16 @@ pub mod tests { assert_eq!(test.get_source(body_span), "SELECT cte.* FROM cte"); } + + #[test] + fn test_case_expr_span() { + let dialect = &GenericDialect; + let mut test = SpanTest::new(dialect, "CASE 1 WHEN 2 THEN 3 ELSE 4 END"); + let expr = test.0.parse_expr().unwrap(); + let expr_span = expr.span(); + assert_eq!( + test.get_source(expr_span), + "CASE 1 WHEN 2 THEN 3 ELSE 4 END" + ); + } } diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index a7f6af294..5105f644f 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -14468,16 +14468,6 @@ fn test_case_statement_span() { ); } -#[test] -fn test_case_expr_span() { - let sql = "CASE 1 WHEN 2 THEN 3 ELSE 4 END"; - let mut parser = Parser::new(&GenericDialect {}).try_with_sql(sql).unwrap(); - assert_eq!( - parser.parse_expr().unwrap().span(), - Span::new(Location::new(1, 1), Location::new(1, sql.len() as u64 + 1)) - ); -} - #[test] fn parse_if_statement() { let dialects = all_dialects_except(|d| d.is::());