Skip to content

Commit 2b395e8

Browse files
committed
Merge branch 'count-distinct' into visitor
2 parents a6e9efa + cf2c20b commit 2b395e8

File tree

3 files changed

+60
-4
lines changed

3 files changed

+60
-4
lines changed

src/sqlast/mod.rs

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,9 @@ pub enum ASTNode {
112112
name: SQLObjectName,
113113
args: Vec<ASTNode>,
114114
over: Option<SQLWindowSpec>,
115+
// aggregate functions may specify eg `COUNT(DISTINCT x)`
116+
all: bool,
117+
distinct: bool,
115118
},
116119
/// CASE [<operand>] WHEN <condition> THEN <result> ... [ELSE <result>] END
117120
/// Note we only recognize a complete single expression as <condition>, not
@@ -190,8 +193,20 @@ impl ToString for ASTNode {
190193
format!("{} {}", operator.to_string(), expr.as_ref().to_string())
191194
}
192195
ASTNode::SQLValue(v) => v.to_string(),
193-
ASTNode::SQLFunction { name, args, over } => {
194-
let mut s = format!("{}({})", name.to_string(), comma_separated_string(args));
196+
ASTNode::SQLFunction {
197+
name,
198+
args,
199+
over,
200+
all,
201+
distinct,
202+
} => {
203+
let mut s = format!(
204+
"{}({}{}{})",
205+
name.to_string(),
206+
if *all { "ALL " } else { "" },
207+
if *distinct { "DISTINCT " } else { "" },
208+
comma_separated_string(args)
209+
);
195210
if let Some(o) = over {
196211
s += &format!(" OVER ({})", o.to_string())
197212
}

src/sqlparser.rs

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,14 @@ impl Parser {
266266

267267
pub fn parse_function(&mut self, name: SQLObjectName) -> Result<ASTNode, ParserError> {
268268
self.expect_token(&Token::LParen)?;
269+
let all = self.parse_keyword("ALL");
270+
let distinct = self.parse_keyword("DISTINCT");
271+
if all && distinct {
272+
return parser_err!(format!(
273+
"Cannot specify both ALL and DISTINCT in function: {:?}",
274+
name
275+
));
276+
}
269277
let args = self.parse_optional_args()?;
270278
let over = if self.parse_keyword("OVER") {
271279
// TBD: support window names (`OVER mywin`) in place of inline specification
@@ -292,7 +300,13 @@ impl Parser {
292300
None
293301
};
294302

295-
Ok(ASTNode::SQLFunction { name, args, over })
303+
Ok(ASTNode::SQLFunction {
304+
name,
305+
args,
306+
over,
307+
all,
308+
distinct,
309+
})
296310
}
297311

298312
pub fn parse_window_frame(&mut self) -> Result<Option<SQLWindowFrame>, ParserError> {

tests/sqlparser_common.rs

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,27 @@ fn parse_select_count_wildcard() {
217217
name: SQLObjectName(vec!["COUNT".to_string()]),
218218
args: vec![ASTNode::SQLWildcard],
219219
over: None,
220+
distinct: false,
221+
all: false,
222+
},
223+
expr_from_projection(only(&select.projection))
224+
);
225+
}
226+
227+
#[test]
228+
fn parse_select_count_distinct() {
229+
let sql = "SELECT COUNT(DISTINCT + x) FROM customer";
230+
let select = verified_only_select(sql);
231+
assert_eq!(
232+
&ASTNode::SQLFunction {
233+
name: SQLObjectName(vec!["COUNT".to_string()]),
234+
args: vec![ASTNode::SQLUnary {
235+
operator: SQLOperator::Plus,
236+
expr: Box::new(ASTNode::SQLIdentifier("x".to_string()))
237+
}],
238+
over: None,
239+
distinct: true,
240+
all: false,
220241
},
221242
expr_from_projection(only(&select.projection))
222243
);
@@ -704,6 +725,8 @@ fn parse_scalar_function_in_projection() {
704725
name: SQLObjectName(vec!["sqrt".to_string()]),
705726
args: vec![ASTNode::SQLIdentifier("id".to_string())],
706727
over: None,
728+
all: false,
729+
distinct: false,
707730
},
708731
expr_from_projection(only(&select.projection))
709732
);
@@ -732,7 +755,9 @@ fn parse_window_functions() {
732755
asc: Some(false)
733756
}],
734757
window_frame: None,
735-
})
758+
}),
759+
all: false,
760+
distinct: false,
736761
},
737762
expr_from_projection(&select.projection[0])
738763
);
@@ -804,6 +829,8 @@ fn parse_delimited_identifiers() {
804829
name: SQLObjectName(vec![r#""myfun""#.to_string()]),
805830
args: vec![],
806831
over: None,
832+
all: false,
833+
distinct: false,
807834
},
808835
expr_from_projection(&select.projection[1]),
809836
);

0 commit comments

Comments
 (0)