Skip to content

Commit c869859

Browse files
authored
internal/dinosql: Implement robust expansion (#186)
* internal/dinosql: Implement robust expansion Expand `SELECT|RETURNING *` into the correct columns, no matter the location in the query
2 parents b393d65 + 573256a commit c869859

File tree

4 files changed

+249
-110
lines changed

4 files changed

+249
-110
lines changed

internal/dinosql/gen.go

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -666,32 +666,13 @@ func (r Result) GoQueries() []GoQuery {
666666
continue
667667
}
668668

669-
code := query.SQL
670-
671-
// TODO: Will horribly break sometimes
672-
if query.NeedsEdit {
673-
var cols []string
674-
find := "*"
675-
for _, c := range query.Columns {
676-
if c.Scope != "" {
677-
find = c.Scope + ".*"
678-
}
679-
name := c.Name
680-
if c.Scope != "" {
681-
name = c.Scope + "." + name
682-
}
683-
cols = append(cols, name)
684-
}
685-
code = strings.Replace(query.SQL, find, strings.Join(cols, ", "), 1)
686-
}
687-
688669
gq := GoQuery{
689670
Cmd: query.Cmd,
690671
ConstantName: lowerTitle(query.Name),
691672
FieldName: lowerTitle(query.Name) + "Stmt",
692673
MethodName: query.Name,
693674
SourceName: query.Filename,
694-
SQL: code,
675+
SQL: query.SQL,
695676
Comments: query.Comments,
696677
}
697678

internal/dinosql/parser.go

Lines changed: 143 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -167,8 +167,7 @@ type Query struct {
167167
Comments []string
168168

169169
// XXX: Hack
170-
NeedsEdit bool
171-
Filename string
170+
Filename string
172171
}
173172

174173
type Result struct {
@@ -289,7 +288,7 @@ func lineno(source string, head int) (int, int) {
289288
func pluckQuery(source string, n nodes.RawStmt) (string, error) {
290289
head := n.StmtLocation
291290
tail := n.StmtLocation + n.StmtLen
292-
return strings.TrimSpace(source[head:tail]), nil
291+
return source[head:tail], nil
293292
}
294293

295294
func rangeVars(root nodes.Node) []nodes.RangeVar {
@@ -403,7 +402,7 @@ func parseQuery(c core.Catalog, stmt nodes.Node, source string) (*Query, error)
403402
if err := validateFuncCall(&c, raw); err != nil {
404403
return nil, err
405404
}
406-
name, cmd, err := parseMetadata(rawSQL)
405+
name, cmd, err := parseMetadata(strings.TrimSpace(rawSQL))
407406
if err != nil {
408407
return nil, err
409408
}
@@ -421,20 +420,23 @@ func parseQuery(c core.Catalog, stmt nodes.Node, source string) (*Query, error)
421420
if err != nil {
422421
return nil, err
423422
}
423+
expanded, err := expand(c, raw, rawSQL)
424+
if err != nil {
425+
return nil, err
426+
}
424427

425-
trimmed, comments, err := stripComments(rawSQL)
428+
trimmed, comments, err := stripComments(strings.TrimSpace(expanded))
426429
if err != nil {
427430
return nil, err
428431
}
429432

430433
return &Query{
431-
Cmd: cmd,
432-
Comments: comments,
433-
Name: name,
434-
Params: params,
435-
Columns: cols,
436-
SQL: trimmed,
437-
NeedsEdit: needsEdit(stmt),
434+
Cmd: cmd,
435+
Comments: comments,
436+
Name: name,
437+
Params: params,
438+
Columns: cols,
439+
SQL: trimmed,
438440
}, nil
439441
}
440442

@@ -454,6 +456,134 @@ func stripComments(sql string) (string, []string, error) {
454456
return strings.Join(lines, "\n"), comments, s.Err()
455457
}
456458

459+
type edit struct {
460+
Location int
461+
Old string
462+
New string
463+
}
464+
465+
func expand(c core.Catalog, raw nodes.RawStmt, sql string) (string, error) {
466+
list := search(raw, func(node nodes.Node) bool {
467+
switch node.(type) {
468+
case nodes.DeleteStmt:
469+
case nodes.InsertStmt:
470+
case nodes.SelectStmt:
471+
case nodes.UpdateStmt:
472+
default:
473+
return false
474+
}
475+
return true
476+
})
477+
if len(list.Items) == 0 {
478+
return sql, nil
479+
}
480+
var edits []edit
481+
for _, item := range list.Items {
482+
edit, err := expandStmt(c, raw, item)
483+
if err != nil {
484+
return "", err
485+
}
486+
edits = append(edits, edit...)
487+
}
488+
return editQuery(sql, edits)
489+
}
490+
491+
func expandStmt(c core.Catalog, raw nodes.RawStmt, node nodes.Node) ([]edit, error) {
492+
tables, err := sourceTables(c, node)
493+
if err != nil {
494+
return nil, err
495+
}
496+
497+
var targets nodes.List
498+
switch n := node.(type) {
499+
case nodes.DeleteStmt:
500+
targets = n.ReturningList
501+
case nodes.InsertStmt:
502+
targets = n.ReturningList
503+
case nodes.SelectStmt:
504+
targets = n.TargetList
505+
case nodes.UpdateStmt:
506+
targets = n.ReturningList
507+
default:
508+
return nil, fmt.Errorf("outputColumns: unsupported node type: %T", n)
509+
}
510+
511+
var edits []edit
512+
for _, target := range targets.Items {
513+
res, ok := target.(nodes.ResTarget)
514+
if !ok {
515+
continue
516+
}
517+
ref, ok := res.Val.(nodes.ColumnRef)
518+
if !ok {
519+
continue
520+
}
521+
if !HasStarRef(ref) {
522+
continue
523+
}
524+
var parts, cols []string
525+
for _, f := range ref.Fields.Items {
526+
switch field := f.(type) {
527+
case nodes.String:
528+
parts = append(parts, field.Str)
529+
case nodes.A_Star:
530+
parts = append(parts, "*")
531+
default:
532+
return nil, fmt.Errorf("unknown field in ColumnRef: %T", f)
533+
}
534+
}
535+
for _, t := range tables {
536+
scope := join(ref.Fields, ".")
537+
if scope != "" && scope != t.Name {
538+
continue
539+
}
540+
for _, c := range t.Columns {
541+
cname := c.Name
542+
if res.Name != nil {
543+
cname = *res.Name
544+
}
545+
if scope != "" {
546+
cname = scope + "." + cname
547+
}
548+
cols = append(cols, cname)
549+
}
550+
}
551+
edits = append(edits, edit{
552+
Location: res.Location - raw.StmtLocation,
553+
Old: strings.Join(parts, "."),
554+
New: strings.Join(cols, ", "),
555+
})
556+
}
557+
return edits, nil
558+
}
559+
560+
func editQuery(raw string, a []edit) (string, error) {
561+
if len(a) == 0 {
562+
return raw, nil
563+
}
564+
sort.Slice(a, func(i, j int) bool { return a[i].Location > a[j].Location })
565+
s := raw
566+
for _, edit := range a {
567+
start := edit.Location
568+
if start > len(s) {
569+
return "", fmt.Errorf("edit start location is out of bounds")
570+
}
571+
if len(edit.New) <= 0 {
572+
return "", fmt.Errorf("empty edit contents")
573+
}
574+
if len(edit.Old) <= 0 {
575+
return "", fmt.Errorf("empty edit contents")
576+
}
577+
stop := edit.Location + len(edit.Old) - 1 // Assumes edit.New is non-empty
578+
if stop < len(s) {
579+
s = s[:start] + edit.New + s[stop+1:]
580+
} else {
581+
s = s[:start] + edit.New
582+
}
583+
}
584+
return s, nil
585+
}
586+
457587
type QueryCatalog struct {
458588
catalog core.Catalog
459589
ctes map[string]core.Table
@@ -653,6 +783,7 @@ func outputColumns(c core.Catalog, node nodes.Node) ([]core.Column, error) {
653783

654784
case nodes.ColumnRef:
655785
if HasStarRef(n) {
786+
// TODO: This code is copied in func expand()
656787
for _, t := range tables {
657788
scope := join(n.Fields, ".")
658789
if scope != "" && scope != t.Name {
@@ -916,24 +1047,6 @@ func findParameters(root nodes.Node) []paramRef {
9161047
return refs
9171048
}
9181049

919-
type starWalker struct {
920-
found bool
921-
}
922-
923-
func (s *starWalker) Visit(node nodes.Node) Visitor {
924-
if _, ok := node.(nodes.A_Star); ok {
925-
s.found = true
926-
return nil
927-
}
928-
return s
929-
}
930-
931-
func needsEdit(root nodes.Node) bool {
932-
v := &starWalker{}
933-
Walk(v, root)
934-
return v.found
935-
}
936-
9371050
type nodeSearch struct {
9381051
list nodes.List
9391052
check func(nodes.Node) bool

internal/dinosql/parser_test.go

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,10 @@ func TestPluck(t *testing.T) {
2626
}
2727

2828
expected := []string{
29-
"SELECT * FROM venue WHERE slug = $1 AND city = $2",
30-
"SELECT * FROM venue WHERE slug = $1",
31-
"SELECT * FROM venue LIMIT $1",
32-
"SELECT * FROM venue OFFSET $1",
29+
"\nSELECT * FROM venue WHERE slug = $1 AND city = $2",
30+
"\nSELECT * FROM venue WHERE slug = $1",
31+
"\nSELECT * FROM venue LIMIT $1",
32+
"\nSELECT * FROM venue OFFSET $1",
3333
}
3434

3535
for i, stmt := range tree.Statements {
@@ -220,3 +220,21 @@ func TestParseMetadata(t *testing.T) {
220220
}
221221
}
222222
}
223+
224+
func TestExpand(t *testing.T) {
225+
// pretend that foo has two columns, a and b
226+
raw := `SELECT *, *, foo.* FROM foo`
227+
expected := `SELECT a, b, a, b, foo.a, foo.b FROM foo`
228+
edits := []edit{
229+
{7, "*", "a, b"},
230+
{10, "*", "a, b"},
231+
{13, "foo.*", "foo.a, foo.b"},
232+
}
233+
actual, err := editQuery(raw, edits)
234+
if err != nil {
235+
t.Error(err)
236+
}
237+
if expected != actual {
238+
t.Errorf("mismatch:\nexpected: %s\n acutal: %s", expected, actual)
239+
}
240+
}

0 commit comments

Comments
 (0)