@@ -103,6 +103,14 @@ func isNotNull(n nodes.ColumnDef) bool {
103
103
return false
104
104
}
105
105
106
+ func isStar (n nodes.ColumnRef ) bool {
107
+ if len (n .Fields .Items ) != 1 {
108
+ return false
109
+ }
110
+ _ , aStar := n .Fields .Items [0 ].(nodes.A_Star )
111
+ return aStar
112
+ }
113
+
106
114
type Query struct {
107
115
Type string
108
116
MethodName string
@@ -112,6 +120,7 @@ type Query struct {
112
120
Args []Arg
113
121
Table postgres.Table
114
122
ReturnType string
123
+ ScanRecord bool
115
124
}
116
125
117
126
type Result struct {
@@ -179,46 +188,50 @@ func parseFuncs(s *postgres.Schema, r *Result, source string, tree pg.ParsetreeL
179
188
if ! ok {
180
189
continue
181
190
}
182
-
183
191
switch n := raw .Stmt .(type ) {
184
192
case nodes.SelectStmt :
185
- t := tableName (n )
186
- c := columnNames (s , t )
193
+ case nodes.DeleteStmt :
194
+ case nodes.InsertStmt :
195
+ default :
196
+ log .Printf ("%T\n " , n )
197
+ continue
198
+ }
187
199
188
- rawSQL , _ := pluckQuery ( source , raw )
189
- refs := extractArgs ( n )
200
+ t := tableName ( raw . Stmt )
201
+ c := columnNames ( s , t )
190
202
191
- tab := getTable (s , t )
192
- r .Queries [i ].Table = tab
193
- r .Queries [i ].ReturnType = tab .GoName
194
- r .Queries [i ].Args = parseArgs (tab , refs )
195
- r .Queries [i ].SQL = strings .Replace (rawSQL , "*" , strings .Join (c , ", " ), 1 )
196
- case nodes.DeleteStmt :
197
- t := tableName (n )
203
+ rawSQL , _ := pluckQuery (source , raw )
204
+ refs := extractArgs (raw .Stmt )
205
+ outs := findOutputs (nil , raw .Stmt )
198
206
199
- rawSQL , _ := pluckQuery (source , raw )
200
- refs := extractArgs (n )
207
+ tab := getTable (s , t )
208
+ r .Queries [i ].Table = tab
209
+ r .Queries [i ].Args = parseArgs (tab , refs )
201
210
202
- tab := getTable (s , t )
203
- r .Queries [i ].Table = tab
204
- r .Queries [i ].ReturnType = tab .GoName
205
- r .Queries [i ].Args = parseArgs (tab , refs )
211
+ if len (outs ) == 0 {
206
212
r .Queries [i ].SQL = rawSQL
207
- case nodes.InsertStmt :
208
- t := tableName (n )
209
- c := columnNames (s , t )
210
- rawSQL , _ := pluckQuery (source , raw )
211
- refs := extractArgs (n )
212
-
213
- tab := getTable (s , t )
214
- r .Queries [i ].Table = tab
213
+ } else if len (outs ) == 1 && isStar (outs [0 ]) {
215
214
r .Queries [i ].ReturnType = tab .GoName
216
- r .Queries [i ].Args = parseArgs ( tab , refs )
215
+ r .Queries [i ].ScanRecord = true
217
216
r .Queries [i ].SQL = strings .Replace (rawSQL , "*" , strings .Join (c , ", " ), 1 )
218
- default :
219
- log .Printf ("%T\n " , n )
217
+ } else {
218
+ r .Queries [i ].ReturnType = returnType (tab , outs )
219
+ r .Queries [i ].SQL = rawSQL
220
+ }
221
+ }
222
+ }
223
+
224
+ func returnType (t postgres.Table , refs []nodes.ColumnRef ) string {
225
+ if len (refs ) != 1 {
226
+ panic ("too many return columns" )
227
+ }
228
+ name := join (refs [0 ].Fields , "." )
229
+ for _ , c := range t .Columns {
230
+ if c .Name == name {
231
+ return c .GoType ()
220
232
}
221
233
}
234
+ return "interface{}"
222
235
}
223
236
224
237
func extractArgs (n nodes.Node ) []paramRef {
@@ -268,6 +281,32 @@ func findRefs(r []paramRef, parent, n nodes.Node) []paramRef {
268
281
ref : n ,
269
282
})
270
283
case nodes.ColumnRef :
284
+ case nodes.FuncCall :
285
+ case nil :
286
+ default :
287
+ log .Printf ("%T\n " , n )
288
+ }
289
+ return r
290
+ }
291
+
292
+ func findOutputs (r []nodes.ColumnRef , n nodes.Node ) []nodes.ColumnRef {
293
+ switch n := n .(type ) {
294
+ case nodes.RawStmt :
295
+ r = findOutputs (r , n .Stmt )
296
+ case nodes.DeleteStmt :
297
+ r = findOutputs (r , n .ReturningList )
298
+ case nodes.SelectStmt :
299
+ r = findOutputs (r , n .TargetList )
300
+ case nodes.InsertStmt :
301
+ r = findOutputs (r , n .ReturningList )
302
+ case nodes.List :
303
+ for _ , i := range n .Items {
304
+ r = findOutputs (r , i )
305
+ }
306
+ case nodes.ResTarget :
307
+ r = findOutputs (r , n .Val )
308
+ case nodes.ColumnRef :
309
+ r = append (r , n )
271
310
case nil :
272
311
default :
273
312
log .Printf ("%T\n " , n )
@@ -414,8 +453,12 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{range .Args}}{{.Name}}
414
453
default:
415
454
row = q.db.QueryRowContext(ctx, {{.QueryName}}, {{range .Args}}{{.Name}},{{end}})
416
455
}
417
- i := {{.Table.GoName}}{}
456
+ var i {{.ReturnType}}
457
+ {{- if .ScanRecord}}
418
458
err := row.Scan({{range .Table.Columns}}&i.{{.GoName}},{{end}})
459
+ {{- else}}
460
+ err := row.Scan(&i)
461
+ {{- end}}
419
462
return i, err
420
463
}
421
464
{{end}}
@@ -436,10 +479,14 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{range .Args}}{{.Name}}
436
479
return nil, err
437
480
}
438
481
defer rows.Close()
439
- items := []{{.Table.GoName }}{}
482
+ items := []{{.ReturnType }}{}
440
483
for rows.Next() {
441
- i := {{.Table.GoName}}{}
484
+ var i {{.ReturnType}}
485
+ {{- if .ScanRecord}}
442
486
if err := rows.Scan({{range .Table.Columns}}&i.{{.GoName}},{{end}}); err != nil {
487
+ {{- else}}
488
+ if err := rows.Scan(&i); err != nil {
489
+ {{- end}}
443
490
return nil, err
444
491
}
445
492
items = append(items, i)
@@ -525,6 +572,7 @@ func generate(r *Result, pkg string) string {
525
572
w .Flush ()
526
573
code , err := format .Source (b .Bytes ())
527
574
if err != nil {
575
+ fmt .Println (b .String ())
528
576
panic (fmt .Errorf ("source error: %s" , err ))
529
577
}
530
578
return string (code )
0 commit comments