@@ -2,7 +2,6 @@ package rewrite
2
2
3
3
import (
4
4
"fmt"
5
- "strings"
6
5
7
6
"github.com/kyleconroy/sqlc/internal/sql/ast"
8
7
"github.com/kyleconroy/sqlc/internal/sql/astutils"
@@ -50,29 +49,20 @@ func Embeds(raw *ast.RawStmt) (*ast.RawStmt, EmbedSet) {
50
49
return false
51
50
}
52
51
53
- sw := & stringWalker {}
54
- astutils .Walk (sw , fun .Args )
55
- str := strings .Join (sw .Parts , "." )
56
-
57
- tableName , err := parseTable (sw )
58
- if err != nil {
59
- return false
60
- }
52
+ param , _ := flatten (fun .Args )
61
53
62
54
node := & ast.ColumnRef {
63
55
Fields : & ast.List {
64
- Items : []ast.Node {},
56
+ Items : []ast.Node {
57
+ & ast.String {Str : param },
58
+ & ast.A_Star {},
59
+ },
65
60
},
66
61
}
67
62
68
- for _ , s := range sw .Parts {
69
- node .Fields .Items = append (node .Fields .Items , & ast.String {Str : s })
70
- }
71
- node .Fields .Items = append (node .Fields .Items , & ast.A_Star {})
72
-
73
63
embeds = append (embeds , & Embed {
74
- Table : tableName ,
75
- param : str ,
64
+ Table : & ast. TableName { Name : param } ,
65
+ param : param ,
76
66
Node : node ,
77
67
})
78
68
@@ -99,27 +89,3 @@ func isEmbed(node ast.Node) bool {
99
89
isValid := call .Func .Schema == "sqlc" && call .Func .Name == "embed"
100
90
return isValid
101
91
}
102
-
103
- func parseTable (sw * stringWalker ) (* ast.TableName , error ) {
104
- parts := sw .Parts
105
-
106
- switch len (parts ) {
107
- case 1 :
108
- return & ast.TableName {
109
- Name : parts [0 ],
110
- }, nil
111
- case 2 :
112
- return & ast.TableName {
113
- Schema : parts [0 ],
114
- Name : parts [1 ],
115
- }, nil
116
- case 3 :
117
- return & ast.TableName {
118
- Catalog : parts [0 ],
119
- Schema : parts [1 ],
120
- Name : parts [2 ],
121
- }, nil
122
- default :
123
- return nil , fmt .Errorf ("invalid table name" )
124
- }
125
- }
0 commit comments