Skip to content

Commit e3e2fcb

Browse files
committed
Pick correct source tables for star expansion
1 parent 77de3bc commit e3e2fcb

File tree

2 files changed

+85
-19
lines changed

2 files changed

+85
-19
lines changed

internal/dinosql/parser.go

Lines changed: 53 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -420,13 +420,7 @@ func parseQuery(c core.Catalog, stmt nodes.Node, source string) (*Query, error)
420420
if err != nil {
421421
return nil, err
422422
}
423-
424-
// TODO: Limit calls to sourceTables
425-
tables, err := sourceTables(c, raw.Stmt)
426-
if err != nil {
427-
return nil, err
428-
}
429-
expanded, err := expand(raw, tables, rawSQL)
423+
expanded, err := expand(c, raw, rawSQL)
430424
if err != nil {
431425
return nil, err
432426
}
@@ -468,25 +462,65 @@ type edit struct {
468462
New string
469463
}
470464

471-
func expand(raw nodes.RawStmt, tables []core.Table, sql string) (string, error) {
465+
func expand(c core.Catalog, raw nodes.RawStmt, sql string) (string, error) {
472466
list := search(raw, func(node nodes.Node) bool {
473-
res, ok := node.(nodes.ResTarget)
474-
if !ok {
475-
return false
476-
}
477-
ref, ok := res.Val.(nodes.ColumnRef)
478-
if !ok {
467+
switch node.(type) {
468+
case nodes.DeleteStmt:
469+
case nodes.InsertStmt:
470+
case nodes.SelectStmt:
471+
case nodes.UpdateStmt:
472+
default:
479473
return false
480474
}
481-
return HasStarRef(ref)
475+
return true
482476
})
483477
if len(list.Items) == 0 {
484478
return sql, nil
485479
}
486480
var edits []edit
487481
for _, item := range list.Items {
488-
res := item.(nodes.ResTarget)
489-
ref := res.Val.(nodes.ColumnRef)
482+
edit, err := expandStmt(c, raw, item)
483+
if err != nil {
484+
return "", err
485+
}
486+
edits = append(edits, edit...)
487+
}
488+
return editQuery(sql, edits)
489+
}
490+
491+
func expandStmt(c core.Catalog, raw nodes.RawStmt, node nodes.Node) ([]edit, error) {
492+
tables, err := sourceTables(c, node)
493+
if err != nil {
494+
return nil, err
495+
}
496+
497+
var targets nodes.List
498+
switch n := node.(type) {
499+
case nodes.DeleteStmt:
500+
targets = n.ReturningList
501+
case nodes.InsertStmt:
502+
targets = n.ReturningList
503+
case nodes.SelectStmt:
504+
targets = n.TargetList
505+
case nodes.UpdateStmt:
506+
targets = n.ReturningList
507+
default:
508+
return nil, fmt.Errorf("outputColumns: unsupported node type: %T", n)
509+
}
510+
511+
var edits []edit
512+
for _, target := range targets.Items {
513+
res, ok := target.(nodes.ResTarget)
514+
if !ok {
515+
continue
516+
}
517+
ref, ok := res.Val.(nodes.ColumnRef)
518+
if !ok {
519+
continue
520+
}
521+
if !HasStarRef(ref) {
522+
continue
523+
}
490524
var parts, cols []string
491525
for _, f := range ref.Fields.Items {
492526
switch field := f.(type) {
@@ -495,7 +529,7 @@ func expand(raw nodes.RawStmt, tables []core.Table, sql string) (string, error)
495529
case nodes.A_Star:
496530
parts = append(parts, "*")
497531
default:
498-
return "", fmt.Errorf("unknown field in ColumnRef: %T", f)
532+
return nil, fmt.Errorf("unknown field in ColumnRef: %T", f)
499533
}
500534
}
501535
for _, t := range tables {
@@ -520,7 +554,7 @@ func expand(raw nodes.RawStmt, tables []core.Table, sql string) (string, error)
520554
New: strings.Join(cols, ", "),
521555
})
522556
}
523-
return editQuery(sql, edits)
557+
return edits, nil
524558
}
525559

526560
func editQuery(raw string, a []edit) (string, error) {

internal/dinosql/query_test.go

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -781,6 +781,38 @@ func TestQueries(t *testing.T) {
781781
SQL: "WITH cte AS (SELECT a, b FROM foo) SELECT c, d FROM bar",
782782
},
783783
},
784+
{
785+
"star-expansion-from-cte",
786+
`
787+
CREATE TABLE foo (a text, b text);
788+
CREATE TABLE bar (c text, d text);
789+
WITH cte AS (SELECT * FROM foo) SELECT * FROM cte;
790+
`,
791+
Query{
792+
Columns: []core.Column{
793+
{Name: "a", DataType: "text"},
794+
{Name: "b", DataType: "text"},
795+
},
796+
SQL: "WITH cte AS (SELECT a, b FROM foo) SELECT a, b FROM cte",
797+
},
798+
},
799+
{
800+
"star-expansion-join",
801+
`
802+
CREATE TABLE foo (a text, b text);
803+
CREATE TABLE bar (c text, d text);
804+
SELECT * FROM foo, bar;
805+
`,
806+
Query{
807+
Columns: []core.Column{
808+
{Name: "a", DataType: "text", Table: public("foo")},
809+
{Name: "b", DataType: "text", Table: public("foo")},
810+
{Name: "c", DataType: "text", Table: public("bar")},
811+
{Name: "d", DataType: "text", Table: public("bar")},
812+
},
813+
SQL: "SELECT a, b, c, d FROM foo, bar",
814+
},
815+
},
784816
} {
785817
test := tc
786818
t.Run(test.name, func(t *testing.T) {

0 commit comments

Comments
 (0)