Skip to content

Commit 7d6377f

Browse files
fix(compiler): correctly validate alias in order/group by clauses for joins (#2537)
* fix(compiler): correctly validate alias in order/group by clauses for joins Resolves #1886 Resolves #2398 Resolves #2399 * remove dead code and split up test
1 parent fec8949 commit 7d6377f

File tree

11 files changed

+209
-14
lines changed

11 files changed

+209
-14
lines changed

internal/compiler/output_columns.go

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ func (c *Compiler) outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, er
6969

7070
if n.GroupClause != nil {
7171
for _, item := range n.GroupClause.Items {
72-
if err := findColumnForNode(item, tables, n); err != nil {
72+
if err := findColumnForNode(item, tables, targets); err != nil {
7373
return nil, err
7474
}
7575
}
@@ -85,7 +85,7 @@ func (c *Compiler) outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, er
8585
if !ok {
8686
continue
8787
}
88-
if err := findColumnForNode(sb.Node, tables, n); err != nil {
88+
if err := findColumnForNode(sb.Node, tables, targets); err != nil {
8989
return nil, fmt.Errorf("%v: if you want to skip this validation, set 'strict_order_by' to false", err)
9090
}
9191
}
@@ -101,7 +101,7 @@ func (c *Compiler) outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, er
101101
if !ok {
102102
continue
103103
}
104-
if err := findColumnForNode(caseExpr.Xpr, tables, n); err != nil {
104+
if err := findColumnForNode(caseExpr.Xpr, tables, targets); err != nil {
105105
return nil, fmt.Errorf("%v: if you want to skip this validation, set 'strict_order_by' to false", err)
106106
}
107107
}
@@ -650,15 +650,15 @@ func outputColumnRefs(res *ast.ResTarget, tables []*Table, node *ast.ColumnRef)
650650
return cols, nil
651651
}
652652

653-
func findColumnForNode(item ast.Node, tables []*Table, n *ast.SelectStmt) error {
653+
func findColumnForNode(item ast.Node, tables []*Table, targetList *ast.List) error {
654654
ref, ok := item.(*ast.ColumnRef)
655655
if !ok {
656656
return nil
657657
}
658-
return findColumnForRef(ref, tables, n)
658+
return findColumnForRef(ref, tables, targetList)
659659
}
660660

661-
func findColumnForRef(ref *ast.ColumnRef, tables []*Table, selectStatement *ast.SelectStmt) error {
661+
func findColumnForRef(ref *ast.ColumnRef, tables []*Table, targetList *ast.List) error {
662662
parts := stringSlice(ref.Fields)
663663
var alias, name string
664664
if len(parts) == 1 {
@@ -675,20 +675,17 @@ func findColumnForRef(ref *ast.ColumnRef, tables []*Table, selectStatement *ast.
675675
}
676676

677677
// Find matching column
678-
var foundColumn bool
679678
for _, c := range t.Columns {
680679
if c.Name == name {
681680
found++
682-
foundColumn = true
681+
break
683682
}
684683
}
684+
}
685685

686-
if foundColumn {
687-
continue
688-
}
689-
690-
// Find matching alias
691-
for _, c := range selectStatement.TargetList.Items {
686+
// Find matching alias if necessary
687+
if found == 0 {
688+
for _, c := range targetList.Items {
692689
resTarget, ok := c.(*ast.ResTarget)
693690
if !ok {
694691
continue

internal/endtoend/testdata/join_group_by_alias/postgresql/stdlib/go/db.go

Lines changed: 31 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

internal/endtoend/testdata/join_group_by_alias/postgresql/stdlib/go/models.go

Lines changed: 11 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

internal/endtoend/testdata/join_group_by_alias/postgresql/stdlib/go/query.sql.go

Lines changed: 39 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
CREATE TABLE foo (email text not null);
2+
3+
-- name: ColumnAsGroupBy :many
4+
SELECT a.email AS id
5+
FROM foo a JOIN foo b ON a.email = b.email
6+
GROUP BY id;
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
{
2+
"version": "1",
3+
"packages": [
4+
{
5+
"engine": "postgresql",
6+
"path": "go",
7+
"name": "querytest",
8+
"schema": "query.sql",
9+
"queries": "query.sql"
10+
}
11+
]
12+
}

internal/endtoend/testdata/join_order_by_alias/postgresql/stdlib/go/db.go

Lines changed: 31 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

internal/endtoend/testdata/join_order_by_alias/postgresql/stdlib/go/models.go

Lines changed: 11 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

internal/endtoend/testdata/join_order_by_alias/postgresql/stdlib/go/query.sql.go

Lines changed: 39 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
CREATE TABLE foo (email text not null);
2+
3+
-- name: ColumnAsOrderBy :many
4+
SELECT a.email AS id
5+
FROM foo a JOIN foo b ON a.email = b.email
6+
ORDER BY id;
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
{
2+
"version": "1",
3+
"packages": [
4+
{
5+
"engine": "postgresql",
6+
"path": "go",
7+
"name": "querytest",
8+
"schema": "query.sql",
9+
"queries": "query.sql"
10+
}
11+
]
12+
}

0 commit comments

Comments
 (0)