diff --git a/internal/dinosql/gen.go b/internal/dinosql/gen.go index 19276e0126..2c7166d2a9 100644 --- a/internal/dinosql/gen.go +++ b/internal/dinosql/gen.go @@ -666,32 +666,13 @@ func (r Result) GoQueries() []GoQuery { continue } - code := query.SQL - - // TODO: Will horribly break sometimes - if query.NeedsEdit { - var cols []string - find := "*" - for _, c := range query.Columns { - if c.Scope != "" { - find = c.Scope + ".*" - } - name := c.Name - if c.Scope != "" { - name = c.Scope + "." + name - } - cols = append(cols, name) - } - code = strings.Replace(query.SQL, find, strings.Join(cols, ", "), 1) - } - gq := GoQuery{ Cmd: query.Cmd, ConstantName: lowerTitle(query.Name), FieldName: lowerTitle(query.Name) + "Stmt", MethodName: query.Name, SourceName: query.Filename, - SQL: code, + SQL: query.SQL, Comments: query.Comments, } diff --git a/internal/dinosql/parser.go b/internal/dinosql/parser.go index 86df0f4b9c..d0e75532b3 100644 --- a/internal/dinosql/parser.go +++ b/internal/dinosql/parser.go @@ -167,8 +167,7 @@ type Query struct { Comments []string // XXX: Hack - NeedsEdit bool - Filename string + Filename string } type Result struct { @@ -289,7 +288,7 @@ func lineno(source string, head int) (int, int) { func pluckQuery(source string, n nodes.RawStmt) (string, error) { head := n.StmtLocation tail := n.StmtLocation + n.StmtLen - return strings.TrimSpace(source[head:tail]), nil + return source[head:tail], nil } func rangeVars(root nodes.Node) []nodes.RangeVar { @@ -403,7 +402,7 @@ func parseQuery(c core.Catalog, stmt nodes.Node, source string) (*Query, error) if err := validateFuncCall(&c, raw); err != nil { return nil, err } - name, cmd, err := parseMetadata(rawSQL) + name, cmd, err := parseMetadata(strings.TrimSpace(rawSQL)) if err != nil { return nil, err } @@ -421,20 +420,23 @@ func parseQuery(c core.Catalog, stmt nodes.Node, source string) (*Query, error) if err != nil { return nil, err } + expanded, err := expand(c, raw, rawSQL) + if err != nil { + return nil, err + } - trimmed, comments, err := stripComments(rawSQL) + trimmed, comments, err := stripComments(strings.TrimSpace(expanded)) if err != nil { return nil, err } return &Query{ - Cmd: cmd, - Comments: comments, - Name: name, - Params: params, - Columns: cols, - SQL: trimmed, - NeedsEdit: needsEdit(stmt), + Cmd: cmd, + Comments: comments, + Name: name, + Params: params, + Columns: cols, + SQL: trimmed, }, nil } @@ -454,6 +456,134 @@ func stripComments(sql string) (string, []string, error) { return strings.Join(lines, "\n"), comments, s.Err() } +type edit struct { + Location int + Old string + New string +} + +func expand(c core.Catalog, raw nodes.RawStmt, sql string) (string, error) { + list := search(raw, func(node nodes.Node) bool { + switch node.(type) { + case nodes.DeleteStmt: + case nodes.InsertStmt: + case nodes.SelectStmt: + case nodes.UpdateStmt: + default: + return false + } + return true + }) + if len(list.Items) == 0 { + return sql, nil + } + var edits []edit + for _, item := range list.Items { + edit, err := expandStmt(c, raw, item) + if err != nil { + return "", err + } + edits = append(edits, edit...) + } + return editQuery(sql, edits) +} + +func expandStmt(c core.Catalog, raw nodes.RawStmt, node nodes.Node) ([]edit, error) { + tables, err := sourceTables(c, node) + if err != nil { + return nil, err + } + + var targets nodes.List + switch n := node.(type) { + case nodes.DeleteStmt: + targets = n.ReturningList + case nodes.InsertStmt: + targets = n.ReturningList + case nodes.SelectStmt: + targets = n.TargetList + case nodes.UpdateStmt: + targets = n.ReturningList + default: + return nil, fmt.Errorf("outputColumns: unsupported node type: %T", n) + } + + var edits []edit + for _, target := range targets.Items { + res, ok := target.(nodes.ResTarget) + if !ok { + continue + } + ref, ok := res.Val.(nodes.ColumnRef) + if !ok { + continue + } + if !HasStarRef(ref) { + continue + } + var parts, cols []string + for _, f := range ref.Fields.Items { + switch field := f.(type) { + case nodes.String: + parts = append(parts, field.Str) + case nodes.A_Star: + parts = append(parts, "*") + default: + return nil, fmt.Errorf("unknown field in ColumnRef: %T", f) + } + } + for _, t := range tables { + scope := join(ref.Fields, ".") + if scope != "" && scope != t.Name { + continue + } + for _, c := range t.Columns { + cname := c.Name + if res.Name != nil { + cname = *res.Name + } + if scope != "" { + cname = scope + "." + cname + } + cols = append(cols, cname) + } + } + edits = append(edits, edit{ + Location: res.Location - raw.StmtLocation, + Old: strings.Join(parts, "."), + New: strings.Join(cols, ", "), + }) + } + return edits, nil +} + +func editQuery(raw string, a []edit) (string, error) { + if len(a) == 0 { + return raw, nil + } + sort.Slice(a, func(i, j int) bool { return a[i].Location > a[j].Location }) + s := raw + for _, edit := range a { + start := edit.Location + if start > len(s) { + return "", fmt.Errorf("edit start location is out of bounds") + } + if len(edit.New) <= 0 { + return "", fmt.Errorf("empty edit contents") + } + if len(edit.Old) <= 0 { + return "", fmt.Errorf("empty edit contents") + } + stop := edit.Location + len(edit.Old) - 1 // Assumes edit.New is non-empty + if stop < len(s) { + s = s[:start] + edit.New + s[stop+1:] + } else { + s = s[:start] + edit.New + } + } + return s, nil +} + type QueryCatalog struct { catalog core.Catalog ctes map[string]core.Table @@ -653,6 +783,7 @@ func outputColumns(c core.Catalog, node nodes.Node) ([]core.Column, error) { case nodes.ColumnRef: if HasStarRef(n) { + // TODO: This code is copied in func expand() for _, t := range tables { scope := join(n.Fields, ".") if scope != "" && scope != t.Name { @@ -916,24 +1047,6 @@ func findParameters(root nodes.Node) []paramRef { return refs } -type starWalker struct { - found bool -} - -func (s *starWalker) Visit(node nodes.Node) Visitor { - if _, ok := node.(nodes.A_Star); ok { - s.found = true - return nil - } - return s -} - -func needsEdit(root nodes.Node) bool { - v := &starWalker{} - Walk(v, root) - return v.found -} - type nodeSearch struct { list nodes.List check func(nodes.Node) bool diff --git a/internal/dinosql/parser_test.go b/internal/dinosql/parser_test.go index 64437c8745..398a18c1f6 100644 --- a/internal/dinosql/parser_test.go +++ b/internal/dinosql/parser_test.go @@ -26,10 +26,10 @@ func TestPluck(t *testing.T) { } expected := []string{ - "SELECT * FROM venue WHERE slug = $1 AND city = $2", - "SELECT * FROM venue WHERE slug = $1", - "SELECT * FROM venue LIMIT $1", - "SELECT * FROM venue OFFSET $1", + "\nSELECT * FROM venue WHERE slug = $1 AND city = $2", + "\nSELECT * FROM venue WHERE slug = $1", + "\nSELECT * FROM venue LIMIT $1", + "\nSELECT * FROM venue OFFSET $1", } for i, stmt := range tree.Statements { @@ -220,3 +220,21 @@ func TestParseMetadata(t *testing.T) { } } } + +func TestExpand(t *testing.T) { + // pretend that foo has two columns, a and b + raw := `SELECT *, *, foo.* FROM foo` + expected := `SELECT a, b, a, b, foo.a, foo.b FROM foo` + edits := []edit{ + {7, "*", "a, b"}, + {10, "*", "a, b"}, + {13, "foo.*", "foo.a, foo.b"}, + } + actual, err := editQuery(raw, edits) + if err != nil { + t.Error(err) + } + if expected != actual { + t.Errorf("mismatch:\nexpected: %s\n acutal: %s", expected, actual) + } +} diff --git a/internal/dinosql/query_test.go b/internal/dinosql/query_test.go index 2f183556d3..2ab89701cd 100644 --- a/internal/dinosql/query_test.go +++ b/internal/dinosql/query_test.go @@ -7,7 +7,6 @@ import ( core "github.com/kyleconroy/sqlc/internal/pg" - "github.com/davecgh/go-spew/spew" "github.com/google/go-cmp/cmp" pg "github.com/lfittl/pg_query_go" ) @@ -26,8 +25,6 @@ func parseSQL(in string) (Query, error) { if q == nil { return Query{}, err } - q.SQL = "" - q.NeedsEdit = false return *q, err } @@ -737,6 +734,85 @@ func TestQueries(t *testing.T) { }, }, }, + { + "star-expansion", + ` + CREATE TABLE foo (a text, b text); + SELECT *, *, foo.* FROM foo; + `, + Query{ + Columns: []core.Column{ + {Name: "a", DataType: "text", Table: public("foo")}, + {Name: "b", DataType: "text", Table: public("foo")}, + {Name: "a", DataType: "text", Table: public("foo")}, + {Name: "b", DataType: "text", Table: public("foo")}, + {Name: "a", DataType: "text", Scope: "foo", Table: public("foo")}, + {Name: "b", DataType: "text", Scope: "foo", Table: public("foo")}, + }, + SQL: "SELECT a, b, a, b, foo.a, foo.b FROM foo", + }, + }, + { + "star-expansion-subquery", + ` + CREATE TABLE foo (a text, b text); + SELECT * FROM foo WHERE EXISTS (SELECT * FROM foo); + `, + Query{ + Columns: []core.Column{ + {Name: "a", DataType: "text", Table: public("foo")}, + {Name: "b", DataType: "text", Table: public("foo")}, + }, + SQL: "SELECT a, b FROM foo WHERE EXISTS (SELECT a, b FROM foo)", + }, + }, + { + "star-expansion-cte", + ` + CREATE TABLE foo (a text, b text); + CREATE TABLE bar (c text, d text); + WITH cte AS (SELECT * FROM foo) SELECT * FROM bar; + `, + Query{ + Columns: []core.Column{ + {Name: "c", DataType: "text", Table: public("bar")}, + {Name: "d", DataType: "text", Table: public("bar")}, + }, + SQL: "WITH cte AS (SELECT a, b FROM foo) SELECT c, d FROM bar", + }, + }, + { + "star-expansion-from-cte", + ` + CREATE TABLE foo (a text, b text); + CREATE TABLE bar (c text, d text); + WITH cte AS (SELECT * FROM foo) SELECT * FROM cte; + `, + Query{ + Columns: []core.Column{ + {Name: "a", DataType: "text"}, + {Name: "b", DataType: "text"}, + }, + SQL: "WITH cte AS (SELECT a, b FROM foo) SELECT a, b FROM cte", + }, + }, + { + "star-expansion-join", + ` + CREATE TABLE foo (a text, b text); + CREATE TABLE bar (c text, d text); + SELECT * FROM foo, bar; + `, + Query{ + Columns: []core.Column{ + {Name: "a", DataType: "text", Table: public("foo")}, + {Name: "b", DataType: "text", Table: public("foo")}, + {Name: "c", DataType: "text", Table: public("bar")}, + {Name: "d", DataType: "text", Table: public("bar")}, + }, + SQL: "SELECT a, b, c, d FROM foo, bar", + }, + }, } { test := tc t.Run(test.name, func(t *testing.T) { @@ -744,6 +820,9 @@ func TestQueries(t *testing.T) { if err != nil { t.Fatal(err) } + if test.query.SQL == "" { + q.SQL = "" + } if diff := cmp.Diff(test.query, q); diff != "" { t.Errorf("query mismatch: \n%s", diff) } @@ -765,6 +844,7 @@ func TestComparisonOperators(t *testing.T) { t.Fatal(err) } expected := Query{ + SQL: q.SQL, Columns: []core.Column{ {Name: "", DataType: "bool", NotNull: true}, }, @@ -776,59 +856,6 @@ func TestComparisonOperators(t *testing.T) { } } -func TestStarWalker(t *testing.T) { - for i, tc := range []struct { - stmt string - expected bool - }{ - { - ` - SELECT * FROM city ORDER BY name; - `, - true, - }, - { - ` - INSERT INTO city ( - name, - slug - ) VALUES ( - $1, - $2 - ) RETURNING *; - `, - true, - }, - { - ` - UPDATE city SET name = $2 WHERE slug = $1; - `, - false, - }, - { - ` - UPDATE venue - SET name = $2 - WHERE slug = $1 - RETURNING *; - `, - true, - }, - } { - test := tc - t.Run(strconv.Itoa(i), func(t *testing.T) { - tree, err := pg.Parse(test.stmt) - if err != nil { - t.Fatal(err) - } - if diff := cmp.Diff(test.expected, needsEdit(tree.Statements[0])); diff != "" { - spew.Dump(tree.Statements[0]) - t.Errorf("query mismatch: \n%s", diff) - } - }) - } -} - func TestInvalidQueries(t *testing.T) { for i, tc := range []struct { stmt string