Skip to content

Commit 72ced4b

Browse files
jamiibenesch
authored andcommitted
Support COUNT(DISTINCT x) and similar
1 parent 4f944dd commit 72ced4b

File tree

3 files changed

+65
-4
lines changed

3 files changed

+65
-4
lines changed

src/sqlast/mod.rs

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,8 @@ pub enum ASTNode {
112112
name: SQLObjectName,
113113
args: Vec<ASTNode>,
114114
over: Option<SQLWindowSpec>,
115+
// aggregate functions may specify eg `COUNT(DISTINCT x)`
116+
distinct: bool,
115117
},
116118
/// CASE [<operand>] WHEN <condition> THEN <result> ... [ELSE <result>] END
117119
/// Note we only recognize a complete single expression as <condition>, not
@@ -190,8 +192,18 @@ impl ToString for ASTNode {
190192
format!("{} {}", operator.to_string(), expr.as_ref().to_string())
191193
}
192194
ASTNode::SQLValue(v) => v.to_string(),
193-
ASTNode::SQLFunction { name, args, over } => {
194-
let mut s = format!("{}({})", name.to_string(), comma_separated_string(args));
195+
ASTNode::SQLFunction {
196+
name,
197+
args,
198+
over,
199+
distinct,
200+
} => {
201+
let mut s = format!(
202+
"{}({}{})",
203+
name.to_string(),
204+
if *distinct { "DISTINCT " } else { "" },
205+
comma_separated_string(args)
206+
);
195207
if let Some(o) = over {
196208
s += &format!(" OVER ({})", o.to_string())
197209
}

src/sqlparser.rs

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,14 @@ impl Parser {
253253

254254
pub fn parse_function(&mut self, name: SQLObjectName) -> Result<ASTNode, ParserError> {
255255
self.expect_token(&Token::LParen)?;
256+
let all = self.parse_keyword("ALL");
257+
let distinct = self.parse_keyword("DISTINCT");
258+
if all && distinct {
259+
return parser_err!(format!(
260+
"Cannot specify both ALL and DISTINCT in function: {}",
261+
name.to_string(),
262+
));
263+
}
256264
let args = self.parse_optional_args()?;
257265
let over = if self.parse_keyword("OVER") {
258266
// TBD: support window names (`OVER mywin`) in place of inline specification
@@ -279,7 +287,12 @@ impl Parser {
279287
None
280288
};
281289

282-
Ok(ASTNode::SQLFunction { name, args, over })
290+
Ok(ASTNode::SQLFunction {
291+
name,
292+
args,
293+
over,
294+
distinct,
295+
})
283296
}
284297

285298
pub fn parse_window_frame(&mut self) -> Result<Option<SQLWindowFrame>, ParserError> {

tests/sqlparser_common.rs

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,11 +197,44 @@ fn parse_select_count_wildcard() {
197197
name: SQLObjectName(vec!["COUNT".to_string()]),
198198
args: vec![ASTNode::SQLWildcard],
199199
over: None,
200+
distinct: false,
200201
},
201202
expr_from_projection(only(&select.projection))
202203
);
203204
}
204205

206+
#[test]
207+
fn parse_select_count_distinct() {
208+
let sql = "SELECT COUNT(DISTINCT + x) FROM customer";
209+
let select = verified_only_select(sql);
210+
assert_eq!(
211+
&ASTNode::SQLFunction {
212+
name: SQLObjectName(vec!["COUNT".to_string()]),
213+
args: vec![ASTNode::SQLUnary {
214+
operator: SQLOperator::Plus,
215+
expr: Box::new(ASTNode::SQLIdentifier("x".to_string()))
216+
}],
217+
over: None,
218+
distinct: true,
219+
},
220+
expr_from_projection(only(&select.projection))
221+
);
222+
223+
one_statement_parses_to(
224+
"SELECT COUNT(ALL + x) FROM customer",
225+
"SELECT COUNT(+ x) FROM customer",
226+
);
227+
228+
let sql = "SELECT COUNT(ALL DISTINCT + x) FROM customer";
229+
let res = parse_sql_statements(sql);
230+
assert_eq!(
231+
ParserError::ParserError(
232+
"Cannot specify both ALL and DISTINCT in function: COUNT".to_string()
233+
),
234+
res.unwrap_err()
235+
);
236+
}
237+
205238
#[test]
206239
fn parse_not() {
207240
let sql = "SELECT id FROM customer WHERE NOT salary = ''";
@@ -662,6 +695,7 @@ fn parse_scalar_function_in_projection() {
662695
name: SQLObjectName(vec!["sqrt".to_string()]),
663696
args: vec![ASTNode::SQLIdentifier("id".to_string())],
664697
over: None,
698+
distinct: false,
665699
},
666700
expr_from_projection(only(&select.projection))
667701
);
@@ -690,7 +724,8 @@ fn parse_window_functions() {
690724
asc: Some(false)
691725
}],
692726
window_frame: None,
693-
})
727+
}),
728+
distinct: false,
694729
},
695730
expr_from_projection(&select.projection[0])
696731
);
@@ -762,6 +797,7 @@ fn parse_delimited_identifiers() {
762797
name: SQLObjectName(vec![r#""myfun""#.to_string()]),
763798
args: vec![],
764799
over: None,
800+
distinct: false,
765801
},
766802
expr_from_projection(&select.projection[1]),
767803
);

0 commit comments

Comments
 (0)