Skip to content

Commit 751b01f

Browse files
committed
internal/dinosql: Implement robust expansion
Exapnds `SELECT *` into the correct columns, no matter the location
1 parent b393d65 commit 751b01f

File tree

3 files changed

+173
-88
lines changed

3 files changed

+173
-88
lines changed

internal/dinosql/parser.go

Lines changed: 100 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ func lineno(source string, head int) (int, int) {
289289
func pluckQuery(source string, n nodes.RawStmt) (string, error) {
290290
head := n.StmtLocation
291291
tail := n.StmtLocation + n.StmtLen
292-
return strings.TrimSpace(source[head:tail]), nil
292+
return source[head:tail], nil
293293
}
294294

295295
func rangeVars(root nodes.Node) []nodes.RangeVar {
@@ -403,7 +403,7 @@ func parseQuery(c core.Catalog, stmt nodes.Node, source string) (*Query, error)
403403
if err := validateFuncCall(&c, raw); err != nil {
404404
return nil, err
405405
}
406-
name, cmd, err := parseMetadata(rawSQL)
406+
name, cmd, err := parseMetadata(strings.TrimSpace(rawSQL))
407407
if err != nil {
408408
return nil, err
409409
}
@@ -422,19 +422,28 @@ func parseQuery(c core.Catalog, stmt nodes.Node, source string) (*Query, error)
422422
return nil, err
423423
}
424424

425-
trimmed, comments, err := stripComments(rawSQL)
425+
// TODO: Limit calls to sourceTables
426+
tables, err := sourceTables(c, raw.Stmt)
427+
if err != nil {
428+
return nil, err
429+
}
430+
expanded, err := expand(raw, tables, rawSQL)
431+
if err != nil {
432+
return nil, err
433+
}
434+
435+
trimmed, comments, err := stripComments(strings.TrimSpace(expanded))
426436
if err != nil {
427437
return nil, err
428438
}
429439

430440
return &Query{
431-
Cmd: cmd,
432-
Comments: comments,
433-
Name: name,
434-
Params: params,
435-
Columns: cols,
436-
SQL: trimmed,
437-
NeedsEdit: needsEdit(stmt),
441+
Cmd: cmd,
442+
Comments: comments,
443+
Name: name,
444+
Params: params,
445+
Columns: cols,
446+
SQL: trimmed,
438447
}, nil
439448
}
440449

@@ -454,6 +463,86 @@ func stripComments(sql string) (string, []string, error) {
454463
return strings.Join(lines, "\n"), comments, s.Err()
455464
}
456465

466+
type edit struct {
467+
Location int
468+
Old string
469+
New string
470+
}
471+
472+
func expand(raw nodes.RawStmt, tables []core.Table, sql string) (string, error) {
473+
list := search(raw, func(node nodes.Node) bool {
474+
res, ok := node.(nodes.ResTarget)
475+
if !ok {
476+
return false
477+
}
478+
ref, ok := res.Val.(nodes.ColumnRef)
479+
if !ok {
480+
return false
481+
}
482+
return HasStarRef(ref)
483+
})
484+
if len(list.Items) == 0 {
485+
return sql, nil
486+
}
487+
var edits []edit
488+
for _, item := range list.Items {
489+
res := item.(nodes.ResTarget)
490+
ref := res.Val.(nodes.ColumnRef)
491+
var parts, cols []string
492+
for _, f := range ref.Fields.Items {
493+
switch field := f.(type) {
494+
case nodes.String:
495+
parts = append(parts, field.Str)
496+
case nodes.A_Star:
497+
parts = append(parts, "*")
498+
default:
499+
return "", fmt.Errorf("unknown field in ColumnRef: %T", f)
500+
}
501+
}
502+
for _, t := range tables {
503+
scope := join(ref.Fields, ".")
504+
if scope != "" && scope != t.Name {
505+
continue
506+
}
507+
for _, c := range t.Columns {
508+
cname := c.Name
509+
if res.Name != nil {
510+
cname = *res.Name
511+
}
512+
if scope != "" {
513+
cname = scope + "." + cname
514+
}
515+
cols = append(cols, cname)
516+
}
517+
}
518+
edits = append(edits, edit{
519+
Location: res.Location - raw.StmtLocation,
520+
Old: strings.Join(parts, "."),
521+
New: strings.Join(cols, ", "),
522+
})
523+
}
524+
return editQuery(sql, edits)
525+
}
526+
527+
func editQuery(raw string, a []edit) (string, error) {
528+
sort.Slice(a, func(i, j int) bool { return a[i].Location > a[j].Location })
529+
// TODO: Check bounds
530+
s := raw
531+
for _, edit := range a {
532+
// fmt.Printf("edit q %q\n", s)
533+
// fmt.Printf("edit e %q\n", edit)
534+
start := edit.Location
535+
stop := edit.Location + len(edit.Old) - 1 // Assumes edit.New is non-empty
536+
if stop < len(s) {
537+
s = s[:start] + edit.New + s[stop+1:]
538+
} else {
539+
s = s[:start] + edit.New
540+
}
541+
}
542+
// fmt.Printf("edit fixed %q\n", s)
543+
return s, nil
544+
}
545+
457546
type QueryCatalog struct {
458547
catalog core.Catalog
459548
ctes map[string]core.Table
@@ -653,6 +742,7 @@ func outputColumns(c core.Catalog, node nodes.Node) ([]core.Column, error) {
653742

654743
case nodes.ColumnRef:
655744
if HasStarRef(n) {
745+
// TODO: This code is copied in func expand()
656746
for _, t := range tables {
657747
scope := join(n.Fields, ".")
658748
if scope != "" && scope != t.Name {
@@ -916,24 +1006,6 @@ func findParameters(root nodes.Node) []paramRef {
9161006
return refs
9171007
}
9181008

919-
type starWalker struct {
920-
found bool
921-
}
922-
923-
func (s *starWalker) Visit(node nodes.Node) Visitor {
924-
if _, ok := node.(nodes.A_Star); ok {
925-
s.found = true
926-
return nil
927-
}
928-
return s
929-
}
930-
931-
func needsEdit(root nodes.Node) bool {
932-
v := &starWalker{}
933-
Walk(v, root)
934-
return v.found
935-
}
936-
9371009
type nodeSearch struct {
9381010
list nodes.List
9391011
check func(nodes.Node) bool

internal/dinosql/parser_test.go

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,10 @@ func TestPluck(t *testing.T) {
2626
}
2727

2828
expected := []string{
29-
"SELECT * FROM venue WHERE slug = $1 AND city = $2",
30-
"SELECT * FROM venue WHERE slug = $1",
31-
"SELECT * FROM venue LIMIT $1",
32-
"SELECT * FROM venue OFFSET $1",
29+
"\nSELECT * FROM venue WHERE slug = $1 AND city = $2",
30+
"\nSELECT * FROM venue WHERE slug = $1",
31+
"\nSELECT * FROM venue LIMIT $1",
32+
"\nSELECT * FROM venue OFFSET $1",
3333
}
3434

3535
for i, stmt := range tree.Statements {
@@ -220,3 +220,21 @@ func TestParseMetadata(t *testing.T) {
220220
}
221221
}
222222
}
223+
224+
func TestExpand(t *testing.T) {
225+
// pretend that foo has two columns, a and b
226+
raw := `SELECT *, *, foo.* FROM foo`
227+
expected := `SELECT a, b, a, b, foo.a, foo.b FROM foo`
228+
edits := []edit{
229+
{7, "*", "a, b"},
230+
{10, "*", "a, b"},
231+
{13, "foo.*", "foo.a, foo.b"},
232+
}
233+
actual, err := editQuery(raw, edits)
234+
if err != nil {
235+
t.Error(err)
236+
}
237+
if expected != actual {
238+
t.Errorf("mismatch:\nexpected: %s\n acutal: %s", expected, actual)
239+
}
240+
}

internal/dinosql/query_test.go

Lines changed: 51 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ import (
77

88
core "github.com/kyleconroy/sqlc/internal/pg"
99

10-
"github.com/davecgh/go-spew/spew"
1110
"github.com/google/go-cmp/cmp"
1211
pg "github.com/lfittl/pg_query_go"
1312
)
@@ -26,8 +25,6 @@ func parseSQL(in string) (Query, error) {
2625
if q == nil {
2726
return Query{}, err
2827
}
29-
q.SQL = ""
30-
q.NeedsEdit = false
3128
return *q, err
3229
}
3330

@@ -737,13 +734,63 @@ func TestQueries(t *testing.T) {
737734
},
738735
},
739736
},
737+
{
738+
"star-expansion",
739+
`
740+
CREATE TABLE foo (a text, b text);
741+
SELECT *, *, foo.* FROM foo;
742+
`,
743+
Query{
744+
Columns: []core.Column{
745+
{Name: "a", DataType: "text", Table: public("foo")},
746+
{Name: "b", DataType: "text", Table: public("foo")},
747+
{Name: "a", DataType: "text", Table: public("foo")},
748+
{Name: "b", DataType: "text", Table: public("foo")},
749+
{Name: "a", DataType: "text", Scope: "foo", Table: public("foo")},
750+
{Name: "b", DataType: "text", Scope: "foo", Table: public("foo")},
751+
},
752+
SQL: "SELECT a, b, a, b, foo.a, foo.b FROM foo",
753+
},
754+
},
755+
{
756+
"star-expansion-subquery",
757+
`
758+
CREATE TABLE foo (a text, b text);
759+
SELECT * FROM foo WHERE EXISTS (SELECT * FROM foo);
760+
`,
761+
Query{
762+
Columns: []core.Column{
763+
{Name: "a", DataType: "text", Table: public("foo")},
764+
{Name: "b", DataType: "text", Table: public("foo")},
765+
},
766+
SQL: "SELECT a, b FROM foo WHERE EXISTS (SELECT a, b FROM foo)",
767+
},
768+
},
769+
{
770+
"star-expansion-cte",
771+
`
772+
CREATE TABLE foo (a text, b text);
773+
CREATE TABLE bar (c text, d text);
774+
WITH cte AS (SELECT * FROM foo) SELECT * FROM bar;
775+
`,
776+
Query{
777+
Columns: []core.Column{
778+
{Name: "c", DataType: "text", Table: public("bar")},
779+
{Name: "d", DataType: "text", Table: public("bar")},
780+
},
781+
SQL: "WITH cte AS (SELECT a, b FROM foo) SELECT c, d FROM bar",
782+
},
783+
},
740784
} {
741785
test := tc
742786
t.Run(test.name, func(t *testing.T) {
743787
q, err := parseSQL(test.stmt)
744788
if err != nil {
745789
t.Fatal(err)
746790
}
791+
if test.query.SQL == "" {
792+
q.SQL = ""
793+
}
747794
if diff := cmp.Diff(test.query, q); diff != "" {
748795
t.Errorf("query mismatch: \n%s", diff)
749796
}
@@ -765,6 +812,7 @@ func TestComparisonOperators(t *testing.T) {
765812
t.Fatal(err)
766813
}
767814
expected := Query{
815+
SQL: q.SQL,
768816
Columns: []core.Column{
769817
{Name: "", DataType: "bool", NotNull: true},
770818
},
@@ -776,59 +824,6 @@ func TestComparisonOperators(t *testing.T) {
776824
}
777825
}
778826

779-
func TestStarWalker(t *testing.T) {
780-
for i, tc := range []struct {
781-
stmt string
782-
expected bool
783-
}{
784-
{
785-
`
786-
SELECT * FROM city ORDER BY name;
787-
`,
788-
true,
789-
},
790-
{
791-
`
792-
INSERT INTO city (
793-
name,
794-
slug
795-
) VALUES (
796-
$1,
797-
$2
798-
) RETURNING *;
799-
`,
800-
true,
801-
},
802-
{
803-
`
804-
UPDATE city SET name = $2 WHERE slug = $1;
805-
`,
806-
false,
807-
},
808-
{
809-
`
810-
UPDATE venue
811-
SET name = $2
812-
WHERE slug = $1
813-
RETURNING *;
814-
`,
815-
true,
816-
},
817-
} {
818-
test := tc
819-
t.Run(strconv.Itoa(i), func(t *testing.T) {
820-
tree, err := pg.Parse(test.stmt)
821-
if err != nil {
822-
t.Fatal(err)
823-
}
824-
if diff := cmp.Diff(test.expected, needsEdit(tree.Statements[0])); diff != "" {
825-
spew.Dump(tree.Statements[0])
826-
t.Errorf("query mismatch: \n%s", diff)
827-
}
828-
})
829-
}
830-
}
831-
832827
func TestInvalidQueries(t *testing.T) {
833828
for i, tc := range []struct {
834829
stmt string

0 commit comments

Comments
 (0)