@@ -167,8 +167,7 @@ type Query struct {
167
167
Comments []string
168
168
169
169
// XXX: Hack
170
- NeedsEdit bool
171
- Filename string
170
+ Filename string
172
171
}
173
172
174
173
type Result struct {
@@ -289,7 +288,7 @@ func lineno(source string, head int) (int, int) {
289
288
func pluckQuery (source string , n nodes.RawStmt ) (string , error ) {
290
289
head := n .StmtLocation
291
290
tail := n .StmtLocation + n .StmtLen
292
- return strings . TrimSpace ( source [head :tail ]) , nil
291
+ return source [head :tail ], nil
293
292
}
294
293
295
294
func rangeVars (root nodes.Node ) []nodes.RangeVar {
@@ -403,7 +402,7 @@ func parseQuery(c core.Catalog, stmt nodes.Node, source string) (*Query, error)
403
402
if err := validateFuncCall (& c , raw ); err != nil {
404
403
return nil , err
405
404
}
406
- name , cmd , err := parseMetadata (rawSQL )
405
+ name , cmd , err := parseMetadata (strings . TrimSpace ( rawSQL ) )
407
406
if err != nil {
408
407
return nil , err
409
408
}
@@ -421,20 +420,23 @@ func parseQuery(c core.Catalog, stmt nodes.Node, source string) (*Query, error)
421
420
if err != nil {
422
421
return nil , err
423
422
}
423
+ expanded , err := expand (c , raw , rawSQL )
424
+ if err != nil {
425
+ return nil , err
426
+ }
424
427
425
- trimmed , comments , err := stripComments (rawSQL )
428
+ trimmed , comments , err := stripComments (strings . TrimSpace ( expanded ) )
426
429
if err != nil {
427
430
return nil , err
428
431
}
429
432
430
433
return & Query {
431
- Cmd : cmd ,
432
- Comments : comments ,
433
- Name : name ,
434
- Params : params ,
435
- Columns : cols ,
436
- SQL : trimmed ,
437
- NeedsEdit : needsEdit (stmt ),
434
+ Cmd : cmd ,
435
+ Comments : comments ,
436
+ Name : name ,
437
+ Params : params ,
438
+ Columns : cols ,
439
+ SQL : trimmed ,
438
440
}, nil
439
441
}
440
442
@@ -454,6 +456,134 @@ func stripComments(sql string) (string, []string, error) {
454
456
return strings .Join (lines , "\n " ), comments , s .Err ()
455
457
}
456
458
459
+ type edit struct {
460
+ Location int
461
+ Old string
462
+ New string
463
+ }
464
+
465
+ func expand (c core.Catalog , raw nodes.RawStmt , sql string ) (string , error ) {
466
+ list := search (raw , func (node nodes.Node ) bool {
467
+ switch node .(type ) {
468
+ case nodes.DeleteStmt :
469
+ case nodes.InsertStmt :
470
+ case nodes.SelectStmt :
471
+ case nodes.UpdateStmt :
472
+ default :
473
+ return false
474
+ }
475
+ return true
476
+ })
477
+ if len (list .Items ) == 0 {
478
+ return sql , nil
479
+ }
480
+ var edits []edit
481
+ for _ , item := range list .Items {
482
+ edit , err := expandStmt (c , raw , item )
483
+ if err != nil {
484
+ return "" , err
485
+ }
486
+ edits = append (edits , edit ... )
487
+ }
488
+ return editQuery (sql , edits )
489
+ }
490
+
491
+ func expandStmt (c core.Catalog , raw nodes.RawStmt , node nodes.Node ) ([]edit , error ) {
492
+ tables , err := sourceTables (c , node )
493
+ if err != nil {
494
+ return nil , err
495
+ }
496
+
497
+ var targets nodes.List
498
+ switch n := node .(type ) {
499
+ case nodes.DeleteStmt :
500
+ targets = n .ReturningList
501
+ case nodes.InsertStmt :
502
+ targets = n .ReturningList
503
+ case nodes.SelectStmt :
504
+ targets = n .TargetList
505
+ case nodes.UpdateStmt :
506
+ targets = n .ReturningList
507
+ default :
508
+ return nil , fmt .Errorf ("outputColumns: unsupported node type: %T" , n )
509
+ }
510
+
511
+ var edits []edit
512
+ for _ , target := range targets .Items {
513
+ res , ok := target .(nodes.ResTarget )
514
+ if ! ok {
515
+ continue
516
+ }
517
+ ref , ok := res .Val .(nodes.ColumnRef )
518
+ if ! ok {
519
+ continue
520
+ }
521
+ if ! HasStarRef (ref ) {
522
+ continue
523
+ }
524
+ var parts , cols []string
525
+ for _ , f := range ref .Fields .Items {
526
+ switch field := f .(type ) {
527
+ case nodes.String :
528
+ parts = append (parts , field .Str )
529
+ case nodes.A_Star :
530
+ parts = append (parts , "*" )
531
+ default :
532
+ return nil , fmt .Errorf ("unknown field in ColumnRef: %T" , f )
533
+ }
534
+ }
535
+ for _ , t := range tables {
536
+ scope := join (ref .Fields , "." )
537
+ if scope != "" && scope != t .Name {
538
+ continue
539
+ }
540
+ for _ , c := range t .Columns {
541
+ cname := c .Name
542
+ if res .Name != nil {
543
+ cname = * res .Name
544
+ }
545
+ if scope != "" {
546
+ cname = scope + "." + cname
547
+ }
548
+ cols = append (cols , cname )
549
+ }
550
+ }
551
+ edits = append (edits , edit {
552
+ Location : res .Location - raw .StmtLocation ,
553
+ Old : strings .Join (parts , "." ),
554
+ New : strings .Join (cols , ", " ),
555
+ })
556
+ }
557
+ return edits , nil
558
+ }
559
+
560
+ func editQuery (raw string , a []edit ) (string , error ) {
561
+ if len (a ) == 0 {
562
+ return raw , nil
563
+ }
564
+ sort .Slice (a , func (i , j int ) bool { return a [i ].Location > a [j ].Location })
565
+ s := raw
566
+ for _ , edit := range a {
567
+ start := edit .Location
568
+ if start > len (s ) {
569
+ return "" , fmt .Errorf ("edit start location is out of bounds" )
570
+ }
571
+ if len (edit .New ) <= 0 {
572
+ return "" , fmt .Errorf ("empty edit contents" )
573
+ }
574
+ if len (edit .Old ) <= 0 {
575
+ return "" , fmt .Errorf ("empty edit contents" )
576
+ }
577
+ stop := edit .Location + len (edit .Old ) - 1 // Assumes edit.New is non-empty
578
+ if stop < len (s ) {
579
+ s = s [:start ] + edit .New + s [stop + 1 :]
580
+ } else {
581
+ s = s [:start ] + edit .New
582
+ }
583
+ }
584
+ return s , nil
585
+ }
586
+
457
587
type QueryCatalog struct {
458
588
catalog core.Catalog
459
589
ctes map [string ]core.Table
@@ -653,6 +783,7 @@ func outputColumns(c core.Catalog, node nodes.Node) ([]core.Column, error) {
653
783
654
784
case nodes.ColumnRef :
655
785
if HasStarRef (n ) {
786
+ // TODO: This code is copied in func expand()
656
787
for _ , t := range tables {
657
788
scope := join (n .Fields , "." )
658
789
if scope != "" && scope != t .Name {
@@ -916,24 +1047,6 @@ func findParameters(root nodes.Node) []paramRef {
916
1047
return refs
917
1048
}
918
1049
919
- type starWalker struct {
920
- found bool
921
- }
922
-
923
- func (s * starWalker ) Visit (node nodes.Node ) Visitor {
924
- if _ , ok := node .(nodes.A_Star ); ok {
925
- s .found = true
926
- return nil
927
- }
928
- return s
929
- }
930
-
931
- func needsEdit (root nodes.Node ) bool {
932
- v := & starWalker {}
933
- Walk (v , root )
934
- return v .found
935
- }
936
-
937
1050
type nodeSearch struct {
938
1051
list nodes.List
939
1052
check func (nodes.Node ) bool
0 commit comments