Skip to content

Commit 710cc21

Browse files
authored
fix(compiler): Add validation for GROUP BY clause column references (#1285)
* fix(compiler): Add validation for GROUP BY clause column references * Add MySQL test
1 parent 466c3e1 commit 710cc21

File tree

8 files changed

+125
-3
lines changed

8 files changed

+125
-3
lines changed

internal/compiler/output_columns.go

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,20 @@ func outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, error) {
6464
targets = n.ReturningList
6565
case *ast.SelectStmt:
6666
targets = n.TargetList
67+
68+
if n.GroupClause != nil {
69+
for _, item := range n.GroupClause.Items {
70+
ref, ok := item.(*ast.ColumnRef)
71+
if !ok {
72+
continue
73+
}
74+
75+
if err := findColumnForRef(ref, tables); err != nil {
76+
return nil, err
77+
}
78+
}
79+
}
80+
6781
// For UNION queries, targets is empty and we need to look for the
6882
// columns in Largs.
6983
if len(targets.Items) == 0 && n.Larg != nil {
@@ -470,3 +484,43 @@ func outputColumnRefs(res *ast.ResTarget, tables []*Table, node *ast.ColumnRef)
470484
}
471485
return cols, nil
472486
}
487+
488+
func findColumnForRef(ref *ast.ColumnRef, tables []*Table) error {
489+
parts := stringSlice(ref.Fields)
490+
var alias, name string
491+
if len(parts) == 1 {
492+
name = parts[0]
493+
} else if len(parts) == 2 {
494+
alias = parts[0]
495+
name = parts[1]
496+
}
497+
498+
var found int
499+
for _, t := range tables {
500+
if alias != "" && t.Rel.Name != alias {
501+
continue
502+
}
503+
for _, c := range t.Columns {
504+
if c.Name == name {
505+
found++
506+
}
507+
}
508+
}
509+
510+
if found == 0 {
511+
return &sqlerr.Error{
512+
Code: "42703",
513+
Message: fmt.Sprintf("column reference \"%s\" not found", name),
514+
Location: ref.Location,
515+
}
516+
}
517+
if found > 1 {
518+
return &sqlerr.Error{
519+
Code: "42703",
520+
Message: fmt.Sprintf("column reference \"%s\" is ambiguous", name),
521+
Location: ref.Location,
522+
}
523+
}
524+
525+
return nil
526+
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
CREATE TABLE authors (
2+
id BIGINT NOT NULL AUTO_INCREMENT PRIMARY KEY,
3+
name text NOT NULL,
4+
bio text,
5+
UNIQUE(name)
6+
);
7+
8+
-- name: ListAuthors :many
9+
SELECT *
10+
FROM authors
11+
GROUP BY invalid_reference;
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
{
2+
"version": "1",
3+
"packages": [
4+
{
5+
"path": "go",
6+
"engine": "mysql",
7+
"name": "querytest",
8+
"schema": "query.sql",
9+
"queries": "query.sql"
10+
}
11+
]
12+
}
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# package querytest
2+
query.sql:9:1: column reference "invalid_reference" not found
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
CREATE TABLE authors (
2+
id BIGSERIAL PRIMARY KEY,
3+
name text NOT NULL,
4+
bio text
5+
);
6+
7+
-- name: ListAuthors :many
8+
SELECT *
9+
FROM authors
10+
GROUP BY invalid_reference;
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
{
2+
"version": "1",
3+
"packages": [
4+
{
5+
"path": "go",
6+
"engine": "postgresql",
7+
"name": "querytest",
8+
"schema": "query.sql",
9+
"queries": "query.sql"
10+
}
11+
]
12+
}
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# package querytest
2+
query.sql:10:10: column reference "invalid_reference" not found

internal/engine/dolphin/convert.go

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -474,6 +474,7 @@ func (c *cc) convertSelectStmt(n *pcast.SelectStmt) *ast.SelectStmt {
474474
stmt := &ast.SelectStmt{
475475
TargetList: c.convertFieldList(n.Fields),
476476
FromClause: c.convertTableRefsClause(n.From),
477+
GroupClause: c.convertGroupByClause(n.GroupBy),
477478
WhereClause: c.convert(n.Where),
478479
WithClause: c.convertWithClause(n.With),
479480
WindowClause: windowClause,
@@ -677,7 +678,14 @@ func (c *cc) convertBinlogStmt(n *pcast.BinlogStmt) ast.Node {
677678
}
678679

679680
func (c *cc) convertByItem(n *pcast.ByItem) ast.Node {
680-
return todo(n)
681+
switch n.Expr.(type) {
682+
case *pcast.PositionExpr:
683+
return c.convertPositionExpr(n.Expr.(*pcast.PositionExpr))
684+
case *pcast.ColumnNameExpr:
685+
return c.convertColumnNameExpr(n.Expr.(*pcast.ColumnNameExpr))
686+
default:
687+
return todo(n)
688+
}
681689
}
682690

683691
func (c *cc) convertCaseExpr(n *pcast.CaseExpr) ast.Node {
@@ -858,8 +866,19 @@ func (c *cc) convertGrantStmt(n *pcast.GrantStmt) ast.Node {
858866
return todo(n)
859867
}
860868

861-
func (c *cc) convertGroupByClause(n *pcast.GroupByClause) ast.Node {
862-
return todo(n)
869+
func (c *cc) convertGroupByClause(n *pcast.GroupByClause) *ast.List {
870+
if n == nil {
871+
return &ast.List{}
872+
}
873+
874+
var items []ast.Node
875+
for _, item := range n.Items {
876+
items = append(items, c.convertByItem(item))
877+
}
878+
879+
return &ast.List{
880+
Items: items,
881+
}
863882
}
864883

865884
func (c *cc) convertHavingClause(n *pcast.HavingClause) ast.Node {

0 commit comments

Comments
 (0)