Skip to content

internal/dinosql: Implement robust expansion #186

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Dec 17, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 1 addition & 20 deletions internal/dinosql/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -666,32 +666,13 @@ func (r Result) GoQueries() []GoQuery {
continue
}

code := query.SQL

// TODO: Will horribly break sometimes
if query.NeedsEdit {
var cols []string
find := "*"
for _, c := range query.Columns {
if c.Scope != "" {
find = c.Scope + ".*"
}
name := c.Name
if c.Scope != "" {
name = c.Scope + "." + name
}
cols = append(cols, name)
}
code = strings.Replace(query.SQL, find, strings.Join(cols, ", "), 1)
}

gq := GoQuery{
Cmd: query.Cmd,
ConstantName: lowerTitle(query.Name),
FieldName: lowerTitle(query.Name) + "Stmt",
MethodName: query.Name,
SourceName: query.Filename,
SQL: code,
SQL: query.SQL,
Comments: query.Comments,
}

Expand Down
173 changes: 143 additions & 30 deletions internal/dinosql/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,7 @@ type Query struct {
Comments []string

// XXX: Hack
NeedsEdit bool
Filename string
Filename string
}

type Result struct {
Expand Down Expand Up @@ -289,7 +288,7 @@ func lineno(source string, head int) (int, int) {
func pluckQuery(source string, n nodes.RawStmt) (string, error) {
head := n.StmtLocation
tail := n.StmtLocation + n.StmtLen
return strings.TrimSpace(source[head:tail]), nil
return source[head:tail], nil
}

func rangeVars(root nodes.Node) []nodes.RangeVar {
Expand Down Expand Up @@ -403,7 +402,7 @@ func parseQuery(c core.Catalog, stmt nodes.Node, source string) (*Query, error)
if err := validateFuncCall(&c, raw); err != nil {
return nil, err
}
name, cmd, err := parseMetadata(rawSQL)
name, cmd, err := parseMetadata(strings.TrimSpace(rawSQL))
if err != nil {
return nil, err
}
Expand All @@ -421,20 +420,23 @@ func parseQuery(c core.Catalog, stmt nodes.Node, source string) (*Query, error)
if err != nil {
return nil, err
}
expanded, err := expand(c, raw, rawSQL)
if err != nil {
return nil, err
}

trimmed, comments, err := stripComments(rawSQL)
trimmed, comments, err := stripComments(strings.TrimSpace(expanded))
if err != nil {
return nil, err
}

return &Query{
Cmd: cmd,
Comments: comments,
Name: name,
Params: params,
Columns: cols,
SQL: trimmed,
NeedsEdit: needsEdit(stmt),
Cmd: cmd,
Comments: comments,
Name: name,
Params: params,
Columns: cols,
SQL: trimmed,
}, nil
}

Expand All @@ -454,6 +456,134 @@ func stripComments(sql string) (string, []string, error) {
return strings.Join(lines, "\n"), comments, s.Err()
}

type edit struct {
Location int
Old string
New string
}

func expand(c core.Catalog, raw nodes.RawStmt, sql string) (string, error) {
list := search(raw, func(node nodes.Node) bool {
switch node.(type) {
case nodes.DeleteStmt:
case nodes.InsertStmt:
case nodes.SelectStmt:
case nodes.UpdateStmt:
default:
return false
}
return true
})
if len(list.Items) == 0 {
return sql, nil
}
var edits []edit
for _, item := range list.Items {
edit, err := expandStmt(c, raw, item)
if err != nil {
return "", err
}
edits = append(edits, edit...)
}
return editQuery(sql, edits)
}

func expandStmt(c core.Catalog, raw nodes.RawStmt, node nodes.Node) ([]edit, error) {
tables, err := sourceTables(c, node)
if err != nil {
return nil, err
}

var targets nodes.List
switch n := node.(type) {
case nodes.DeleteStmt:
targets = n.ReturningList
case nodes.InsertStmt:
targets = n.ReturningList
case nodes.SelectStmt:
targets = n.TargetList
case nodes.UpdateStmt:
targets = n.ReturningList
default:
return nil, fmt.Errorf("outputColumns: unsupported node type: %T", n)
}

var edits []edit
for _, target := range targets.Items {
res, ok := target.(nodes.ResTarget)
if !ok {
continue
}
ref, ok := res.Val.(nodes.ColumnRef)
if !ok {
continue
}
if !HasStarRef(ref) {
continue
}
var parts, cols []string
for _, f := range ref.Fields.Items {
switch field := f.(type) {
case nodes.String:
parts = append(parts, field.Str)
case nodes.A_Star:
parts = append(parts, "*")
default:
return nil, fmt.Errorf("unknown field in ColumnRef: %T", f)
}
}
for _, t := range tables {
scope := join(ref.Fields, ".")
if scope != "" && scope != t.Name {
continue
}
for _, c := range t.Columns {
cname := c.Name
if res.Name != nil {
cname = *res.Name
}
if scope != "" {
cname = scope + "." + cname
}
cols = append(cols, cname)
}
}
edits = append(edits, edit{
Location: res.Location - raw.StmtLocation,
Old: strings.Join(parts, "."),
New: strings.Join(cols, ", "),
})
}
return edits, nil
}

func editQuery(raw string, a []edit) (string, error) {
if len(a) == 0 {
return raw, nil
}
sort.Slice(a, func(i, j int) bool { return a[i].Location > a[j].Location })
s := raw
for _, edit := range a {
start := edit.Location
if start > len(s) {
return "", fmt.Errorf("edit start location is out of bounds")
}
if len(edit.New) <= 0 {
return "", fmt.Errorf("empty edit contents")
}
if len(edit.Old) <= 0 {
return "", fmt.Errorf("empty edit contents")
}
stop := edit.Location + len(edit.Old) - 1 // Assumes edit.New is non-empty
if stop < len(s) {
s = s[:start] + edit.New + s[stop+1:]
} else {
s = s[:start] + edit.New
}
}
return s, nil
}

type QueryCatalog struct {
catalog core.Catalog
ctes map[string]core.Table
Expand Down Expand Up @@ -653,6 +783,7 @@ func outputColumns(c core.Catalog, node nodes.Node) ([]core.Column, error) {

case nodes.ColumnRef:
if HasStarRef(n) {
// TODO: This code is copied in func expand()
for _, t := range tables {
scope := join(n.Fields, ".")
if scope != "" && scope != t.Name {
Expand Down Expand Up @@ -916,24 +1047,6 @@ func findParameters(root nodes.Node) []paramRef {
return refs
}

type starWalker struct {
found bool
}

func (s *starWalker) Visit(node nodes.Node) Visitor {
if _, ok := node.(nodes.A_Star); ok {
s.found = true
return nil
}
return s
}

func needsEdit(root nodes.Node) bool {
v := &starWalker{}
Walk(v, root)
return v.found
}

type nodeSearch struct {
list nodes.List
check func(nodes.Node) bool
Expand Down
26 changes: 22 additions & 4 deletions internal/dinosql/parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ func TestPluck(t *testing.T) {
}

expected := []string{
"SELECT * FROM venue WHERE slug = $1 AND city = $2",
"SELECT * FROM venue WHERE slug = $1",
"SELECT * FROM venue LIMIT $1",
"SELECT * FROM venue OFFSET $1",
"\nSELECT * FROM venue WHERE slug = $1 AND city = $2",
"\nSELECT * FROM venue WHERE slug = $1",
"\nSELECT * FROM venue LIMIT $1",
"\nSELECT * FROM venue OFFSET $1",
}

for i, stmt := range tree.Statements {
Expand Down Expand Up @@ -220,3 +220,21 @@ func TestParseMetadata(t *testing.T) {
}
}
}

func TestExpand(t *testing.T) {
// pretend that foo has two columns, a and b
raw := `SELECT *, *, foo.* FROM foo`
expected := `SELECT a, b, a, b, foo.a, foo.b FROM foo`
edits := []edit{
{7, "*", "a, b"},
{10, "*", "a, b"},
{13, "foo.*", "foo.a, foo.b"},
}
actual, err := editQuery(raw, edits)
if err != nil {
t.Error(err)
}
if expected != actual {
t.Errorf("mismatch:\nexpected: %s\n acutal: %s", expected, actual)
}
}
Loading