Skip to content

Commit 922f5f3

Browse files
committed
Add support for returning one column
1 parent d24b1db commit 922f5f3

File tree

3 files changed

+136
-37
lines changed

3 files changed

+136
-37
lines changed

parser.go

Lines changed: 80 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,14 @@ func isNotNull(n nodes.ColumnDef) bool {
103103
return false
104104
}
105105

106+
func isStar(n nodes.ColumnRef) bool {
107+
if len(n.Fields.Items) != 1 {
108+
return false
109+
}
110+
_, aStar := n.Fields.Items[0].(nodes.A_Star)
111+
return aStar
112+
}
113+
106114
type Query struct {
107115
Type string
108116
MethodName string
@@ -112,6 +120,7 @@ type Query struct {
112120
Args []Arg
113121
Table postgres.Table
114122
ReturnType string
123+
ScanRecord bool
115124
}
116125

117126
type Result struct {
@@ -179,46 +188,50 @@ func parseFuncs(s *postgres.Schema, r *Result, source string, tree pg.ParsetreeL
179188
if !ok {
180189
continue
181190
}
182-
183191
switch n := raw.Stmt.(type) {
184192
case nodes.SelectStmt:
185-
t := tableName(n)
186-
c := columnNames(s, t)
193+
case nodes.DeleteStmt:
194+
case nodes.InsertStmt:
195+
default:
196+
log.Printf("%T\n", n)
197+
continue
198+
}
187199

188-
rawSQL, _ := pluckQuery(source, raw)
189-
refs := extractArgs(n)
200+
t := tableName(raw.Stmt)
201+
c := columnNames(s, t)
190202

191-
tab := getTable(s, t)
192-
r.Queries[i].Table = tab
193-
r.Queries[i].ReturnType = tab.GoName
194-
r.Queries[i].Args = parseArgs(tab, refs)
195-
r.Queries[i].SQL = strings.Replace(rawSQL, "*", strings.Join(c, ", "), 1)
196-
case nodes.DeleteStmt:
197-
t := tableName(n)
203+
rawSQL, _ := pluckQuery(source, raw)
204+
refs := extractArgs(raw.Stmt)
205+
outs := findOutputs(nil, raw.Stmt)
198206

199-
rawSQL, _ := pluckQuery(source, raw)
200-
refs := extractArgs(n)
207+
tab := getTable(s, t)
208+
r.Queries[i].Table = tab
209+
r.Queries[i].Args = parseArgs(tab, refs)
201210

202-
tab := getTable(s, t)
203-
r.Queries[i].Table = tab
204-
r.Queries[i].ReturnType = tab.GoName
205-
r.Queries[i].Args = parseArgs(tab, refs)
211+
if len(outs) == 0 {
206212
r.Queries[i].SQL = rawSQL
207-
case nodes.InsertStmt:
208-
t := tableName(n)
209-
c := columnNames(s, t)
210-
rawSQL, _ := pluckQuery(source, raw)
211-
refs := extractArgs(n)
212-
213-
tab := getTable(s, t)
214-
r.Queries[i].Table = tab
213+
} else if len(outs) == 1 && isStar(outs[0]) {
215214
r.Queries[i].ReturnType = tab.GoName
216-
r.Queries[i].Args = parseArgs(tab, refs)
215+
r.Queries[i].ScanRecord = true
217216
r.Queries[i].SQL = strings.Replace(rawSQL, "*", strings.Join(c, ", "), 1)
218-
default:
219-
log.Printf("%T\n", n)
217+
} else {
218+
r.Queries[i].ReturnType = returnType(tab, outs)
219+
r.Queries[i].SQL = rawSQL
220+
}
221+
}
222+
}
223+
224+
func returnType(t postgres.Table, refs []nodes.ColumnRef) string {
225+
if len(refs) != 1 {
226+
panic("too many return columns")
227+
}
228+
name := join(refs[0].Fields, ".")
229+
for _, c := range t.Columns {
230+
if c.Name == name {
231+
return c.GoType()
220232
}
221233
}
234+
return "interface{}"
222235
}
223236

224237
func extractArgs(n nodes.Node) []paramRef {
@@ -268,6 +281,32 @@ func findRefs(r []paramRef, parent, n nodes.Node) []paramRef {
268281
ref: n,
269282
})
270283
case nodes.ColumnRef:
284+
case nodes.FuncCall:
285+
case nil:
286+
default:
287+
log.Printf("%T\n", n)
288+
}
289+
return r
290+
}
291+
292+
func findOutputs(r []nodes.ColumnRef, n nodes.Node) []nodes.ColumnRef {
293+
switch n := n.(type) {
294+
case nodes.RawStmt:
295+
r = findOutputs(r, n.Stmt)
296+
case nodes.DeleteStmt:
297+
r = findOutputs(r, n.ReturningList)
298+
case nodes.SelectStmt:
299+
r = findOutputs(r, n.TargetList)
300+
case nodes.InsertStmt:
301+
r = findOutputs(r, n.ReturningList)
302+
case nodes.List:
303+
for _, i := range n.Items {
304+
r = findOutputs(r, i)
305+
}
306+
case nodes.ResTarget:
307+
r = findOutputs(r, n.Val)
308+
case nodes.ColumnRef:
309+
r = append(r, n)
271310
case nil:
272311
default:
273312
log.Printf("%T\n", n)
@@ -414,8 +453,12 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{range .Args}}{{.Name}}
414453
default:
415454
row = q.db.QueryRowContext(ctx, {{.QueryName}}, {{range .Args}}{{.Name}},{{end}})
416455
}
417-
i := {{.Table.GoName}}{}
456+
var i {{.ReturnType}}
457+
{{- if .ScanRecord}}
418458
err := row.Scan({{range .Table.Columns}}&i.{{.GoName}},{{end}})
459+
{{- else}}
460+
err := row.Scan(&i)
461+
{{- end}}
419462
return i, err
420463
}
421464
{{end}}
@@ -436,10 +479,14 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{range .Args}}{{.Name}}
436479
return nil, err
437480
}
438481
defer rows.Close()
439-
items := []{{.Table.GoName}}{}
482+
items := []{{.ReturnType}}{}
440483
for rows.Next() {
441-
i := {{.Table.GoName}}{}
484+
var i {{.ReturnType}}
485+
{{- if .ScanRecord}}
442486
if err := rows.Scan({{range .Table.Columns}}&i.{{.GoName}},{{end}}); err != nil {
487+
{{- else}}
488+
if err := rows.Scan(&i); err != nil {
489+
{{- end}}
443490
return nil, err
444491
}
445492
items = append(items, i)
@@ -525,6 +572,7 @@ func generate(r *Result, pkg string) string {
525572
w.Flush()
526573
code, err := format.Source(b.Bytes())
527574
if err != nil {
575+
fmt.Println(b.String())
528576
panic(fmt.Errorf("source error: %s", err))
529577
}
530578
return string(code)

testdata/ondeck/db.go

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ func Prepare(ctx context.Context, db dbtx) (*Queries, error) {
3838
if q.createCity, err = db.PrepareContext(ctx, createCity); err != nil {
3939
return nil, err
4040
}
41+
if q.createVenue, err = db.PrepareContext(ctx, createVenue); err != nil {
42+
return nil, err
43+
}
4144
if q.deleteVenue, err = db.PrepareContext(ctx, deleteVenue); err != nil {
4245
return nil, err
4346
}
@@ -61,6 +64,7 @@ type Queries struct {
6164

6265
tx *sql.Tx
6366
createCity *sql.Stmt
67+
createVenue *sql.Stmt
6468
deleteVenue *sql.Stmt
6569
getCity *sql.Stmt
6670
getVenue *sql.Stmt
@@ -73,6 +77,7 @@ func (q *Queries) WithTx(tx *sql.Tx) *Queries {
7377
tx: tx,
7478
db: tx,
7579
createCity: q.createCity,
80+
createVenue: q.createVenue,
7681
deleteVenue: q.deleteVenue,
7782
getCity: q.getCity,
7883
getVenue: q.getVenue,
@@ -101,11 +106,42 @@ func (q *Queries) CreateCity(ctx context.Context, name string, slug string) (Cit
101106
default:
102107
row = q.db.QueryRowContext(ctx, createCity, name, slug)
103108
}
104-
i := City{}
109+
var i City
105110
err := row.Scan(&i.Slug, &i.Name)
106111
return i, err
107112
}
108113

114+
const createVenue = `-- name: CreateVenue :one
115+
INSERT INTO venue (
116+
name,
117+
slug,
118+
created_at,
119+
spotify_playlist,
120+
city
121+
) VALUES (
122+
$1,
123+
$2,
124+
NOW(),
125+
$3,
126+
$4
127+
) RETURNING id
128+
`
129+
130+
func (q *Queries) CreateVenue(ctx context.Context, name string, slug string, spotify_playlist string, city string) (int, error) {
131+
var row *sql.Row
132+
switch {
133+
case q.createVenue != nil && q.tx != nil:
134+
row = q.tx.StmtContext(ctx, q.createVenue).QueryRowContext(ctx, name, slug, spotify_playlist, city)
135+
case q.createVenue != nil:
136+
row = q.createVenue.QueryRowContext(ctx, name, slug, spotify_playlist, city)
137+
default:
138+
row = q.db.QueryRowContext(ctx, createVenue, name, slug, spotify_playlist, city)
139+
}
140+
var i int
141+
err := row.Scan(&i)
142+
return i, err
143+
}
144+
109145
const deleteVenue = `-- name: DeleteVenue :exec
110146
DELETE FROM venue
111147
WHERE slug = $1
@@ -140,7 +176,7 @@ func (q *Queries) GetCity(ctx context.Context, slug string) (City, error) {
140176
default:
141177
row = q.db.QueryRowContext(ctx, getCity, slug)
142178
}
143-
i := City{}
179+
var i City
144180
err := row.Scan(&i.Slug, &i.Name)
145181
return i, err
146182
}
@@ -161,7 +197,7 @@ func (q *Queries) GetVenue(ctx context.Context, slug string, city string) (Venue
161197
default:
162198
row = q.db.QueryRowContext(ctx, getVenue, slug, city)
163199
}
164-
i := Venue{}
200+
var i Venue
165201
err := row.Scan(&i.ID, &i.CreatedAt, &i.Slug, &i.Name, &i.City, &i.SpotifyPlaylist, &i.SongkickID)
166202
return i, err
167203
}
@@ -189,7 +225,7 @@ func (q *Queries) ListCities(ctx context.Context) ([]City, error) {
189225
defer rows.Close()
190226
items := []City{}
191227
for rows.Next() {
192-
i := City{}
228+
var i City
193229
if err := rows.Scan(&i.Slug, &i.Name); err != nil {
194230
return nil, err
195231
}
@@ -228,7 +264,7 @@ func (q *Queries) ListVenues(ctx context.Context, city string) ([]Venue, error)
228264
defer rows.Close()
229265
items := []Venue{}
230266
for rows.Next() {
231-
i := Venue{}
267+
var i Venue
232268
if err := rows.Scan(&i.ID, &i.CreatedAt, &i.Slug, &i.Name, &i.City, &i.SpotifyPlaylist, &i.SongkickID); err != nil {
233269
return nil, err
234270
}

testdata/ondeck/query/queries.sql

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,18 @@ INSERT INTO city (
3131
$1,
3232
$2
3333
) RETURNING *;
34+
35+
-- name: CreateVenue :one
36+
INSERT INTO venue (
37+
name,
38+
slug,
39+
created_at,
40+
spotify_playlist,
41+
city
42+
) VALUES (
43+
$1,
44+
$2,
45+
NOW(),
46+
$3,
47+
$4
48+
) RETURNING id;

0 commit comments

Comments
 (0)