@@ -420,13 +420,7 @@ func parseQuery(c core.Catalog, stmt nodes.Node, source string) (*Query, error)
420
420
if err != nil {
421
421
return nil , err
422
422
}
423
-
424
- // TODO: Limit calls to sourceTables
425
- tables , err := sourceTables (c , raw .Stmt )
426
- if err != nil {
427
- return nil , err
428
- }
429
- expanded , err := expand (raw , tables , rawSQL )
423
+ expanded , err := expand (c , raw , rawSQL )
430
424
if err != nil {
431
425
return nil , err
432
426
}
@@ -468,25 +462,65 @@ type edit struct {
468
462
New string
469
463
}
470
464
471
- func expand (raw nodes. RawStmt , tables []core. Table , sql string ) (string , error ) {
465
+ func expand (c core. Catalog , raw nodes. RawStmt , sql string ) (string , error ) {
472
466
list := search (raw , func (node nodes.Node ) bool {
473
- res , ok := node .(nodes. ResTarget )
474
- if ! ok {
475
- return false
476
- }
477
- ref , ok := res . Val .( nodes.ColumnRef )
478
- if ! ok {
467
+ switch node .(type ) {
468
+ case nodes. DeleteStmt :
469
+ case nodes. InsertStmt :
470
+ case nodes. SelectStmt :
471
+ case nodes.UpdateStmt :
472
+ default :
479
473
return false
480
474
}
481
- return HasStarRef ( ref )
475
+ return true
482
476
})
483
477
if len (list .Items ) == 0 {
484
478
return sql , nil
485
479
}
486
480
var edits []edit
487
481
for _ , item := range list .Items {
488
- res := item .(nodes.ResTarget )
489
- ref := res .Val .(nodes.ColumnRef )
482
+ edit , err := expandStmt (c , raw , item )
483
+ if err != nil {
484
+ return "" , err
485
+ }
486
+ edits = append (edits , edit ... )
487
+ }
488
+ return editQuery (sql , edits )
489
+ }
490
+
491
+ func expandStmt (c core.Catalog , raw nodes.RawStmt , node nodes.Node ) ([]edit , error ) {
492
+ tables , err := sourceTables (c , node )
493
+ if err != nil {
494
+ return nil , err
495
+ }
496
+
497
+ var targets nodes.List
498
+ switch n := node .(type ) {
499
+ case nodes.DeleteStmt :
500
+ targets = n .ReturningList
501
+ case nodes.InsertStmt :
502
+ targets = n .ReturningList
503
+ case nodes.SelectStmt :
504
+ targets = n .TargetList
505
+ case nodes.UpdateStmt :
506
+ targets = n .ReturningList
507
+ default :
508
+ return nil , fmt .Errorf ("outputColumns: unsupported node type: %T" , n )
509
+ }
510
+
511
+ var edits []edit
512
+ for _ , target := range targets .Items {
513
+ res , ok := target .(nodes.ResTarget )
514
+ if ! ok {
515
+ continue
516
+ }
517
+ ref , ok := res .Val .(nodes.ColumnRef )
518
+ if ! ok {
519
+ continue
520
+ }
521
+ if ! HasStarRef (ref ) {
522
+ continue
523
+ }
490
524
var parts , cols []string
491
525
for _ , f := range ref .Fields .Items {
492
526
switch field := f .(type ) {
@@ -495,7 +529,7 @@ func expand(raw nodes.RawStmt, tables []core.Table, sql string) (string, error)
495
529
case nodes.A_Star :
496
530
parts = append (parts , "*" )
497
531
default :
498
- return "" , fmt .Errorf ("unknown field in ColumnRef: %T" , f )
532
+ return nil , fmt .Errorf ("unknown field in ColumnRef: %T" , f )
499
533
}
500
534
}
501
535
for _ , t := range tables {
@@ -520,7 +554,7 @@ func expand(raw nodes.RawStmt, tables []core.Table, sql string) (string, error)
520
554
New : strings .Join (cols , ", " ),
521
555
})
522
556
}
523
- return editQuery ( sql , edits )
557
+ return edits , nil
524
558
}
525
559
526
560
func editQuery (raw string , a []edit ) (string , error ) {
0 commit comments