Skip to content

Commit c35def9

Browse files
committed
Add update support
1 parent 922f5f3 commit c35def9

File tree

3 files changed

+113
-39
lines changed

3 files changed

+113
-39
lines changed

parser.go

Lines changed: 34 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,7 @@ func parseFuncs(s *postgres.Schema, r *Result, source string, tree pg.ParsetreeL
192192
case nodes.SelectStmt:
193193
case nodes.DeleteStmt:
194194
case nodes.InsertStmt:
195+
case nodes.UpdateStmt:
195196
default:
196197
log.Printf("%T\n", n)
197198
continue
@@ -250,14 +251,15 @@ func findRefs(r []paramRef, parent, n nodes.Node) []paramRef {
250251
n = parent
251252
}
252253
switch n := n.(type) {
253-
case nodes.RawStmt:
254-
r = findRefs(r, n.Stmt, nil)
254+
case nodes.A_Expr:
255+
r = findRefs(r, n, n.Lexpr)
256+
r = findRefs(r, n, n.Rexpr)
257+
case nodes.ColumnRef:
258+
case nodes.BoolExpr:
259+
r = findRefs(r, n.Args, nil)
255260
case nodes.DeleteStmt:
256261
r = findRefs(r, n.WhereClause, nil)
257-
case nodes.SelectStmt:
258-
r = findRefs(r, n.WhereClause, nil)
259-
r = findRefs(r, n.LimitCount, nil)
260-
r = findRefs(r, n.LimitOffset, nil)
262+
case nodes.FuncCall:
261263
case nodes.InsertStmt:
262264
switch s := n.SelectStmt.(type) {
263265
case nodes.SelectStmt:
@@ -268,20 +270,26 @@ func findRefs(r []paramRef, parent, n nodes.Node) []paramRef {
268270
}
269271
}
270272
}
271-
case nodes.BoolExpr:
272-
for _, item := range n.Args.Items {
273+
case nodes.List:
274+
for _, item := range n.Items {
273275
r = findRefs(r, item, nil)
274276
}
275-
case nodes.A_Expr:
276-
r = findRefs(r, n, n.Lexpr)
277-
r = findRefs(r, n, n.Rexpr)
278277
case nodes.ParamRef:
279278
r = append(r, paramRef{
280279
parent: parent,
281280
ref: n,
282281
})
283-
case nodes.ColumnRef:
284-
case nodes.FuncCall:
282+
case nodes.RawStmt:
283+
r = findRefs(r, n.Stmt, nil)
284+
case nodes.ResTarget:
285+
r = findRefs(r, n, n.Val)
286+
case nodes.SelectStmt:
287+
r = findRefs(r, n.WhereClause, nil)
288+
r = findRefs(r, n.LimitCount, nil)
289+
r = findRefs(r, n.LimitOffset, nil)
290+
case nodes.UpdateStmt:
291+
r = findRefs(r, n.TargetList, nil)
292+
r = findRefs(r, n.WhereClause, nil)
285293
case nil:
286294
default:
287295
log.Printf("%T\n", n)
@@ -291,22 +299,24 @@ func findRefs(r []paramRef, parent, n nodes.Node) []paramRef {
291299

292300
func findOutputs(r []nodes.ColumnRef, n nodes.Node) []nodes.ColumnRef {
293301
switch n := n.(type) {
294-
case nodes.RawStmt:
295-
r = findOutputs(r, n.Stmt)
302+
case nodes.ColumnRef:
303+
r = append(r, n)
296304
case nodes.DeleteStmt:
297305
r = findOutputs(r, n.ReturningList)
298-
case nodes.SelectStmt:
299-
r = findOutputs(r, n.TargetList)
300306
case nodes.InsertStmt:
301307
r = findOutputs(r, n.ReturningList)
302308
case nodes.List:
303309
for _, i := range n.Items {
304310
r = findOutputs(r, i)
305311
}
312+
case nodes.RawStmt:
313+
r = findOutputs(r, n.Stmt)
306314
case nodes.ResTarget:
307315
r = findOutputs(r, n.Val)
308-
case nodes.ColumnRef:
309-
r = append(r, n)
316+
case nodes.SelectStmt:
317+
r = findOutputs(r, n.TargetList)
318+
case nodes.UpdateStmt:
319+
r = findOutputs(r, n.ReturningList)
310320
case nil:
311321
default:
312322
log.Printf("%T\n", n)
@@ -369,16 +379,18 @@ func columnNames(s *postgres.Schema, table string) []string {
369379

370380
func tableName(n nodes.Node) string {
371381
switch n := n.(type) {
382+
case nodes.DeleteStmt:
383+
return *n.Relation.Relname
384+
case nodes.InsertStmt:
385+
return *n.Relation.Relname
372386
case nodes.SelectStmt:
373387
for _, item := range n.FromClause.Items {
374388
switch i := item.(type) {
375389
case nodes.RangeVar:
376390
return *i.Relname
377391
}
378392
}
379-
case nodes.DeleteStmt:
380-
return *n.Relation.Relname
381-
case nodes.InsertStmt:
393+
case nodes.UpdateStmt:
382394
return *n.Relation.Relname
383395
}
384396
return ""

testdata/ondeck/db.go

Lines changed: 68 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -56,33 +56,43 @@ func Prepare(ctx context.Context, db dbtx) (*Queries, error) {
5656
if q.listVenues, err = db.PrepareContext(ctx, listVenues); err != nil {
5757
return nil, err
5858
}
59+
if q.updateCityName, err = db.PrepareContext(ctx, updateCityName); err != nil {
60+
return nil, err
61+
}
62+
if q.updateVenueName, err = db.PrepareContext(ctx, updateVenueName); err != nil {
63+
return nil, err
64+
}
5965
return &q, nil
6066
}
6167

6268
type Queries struct {
6369
db dbtx
6470

65-
tx *sql.Tx
66-
createCity *sql.Stmt
67-
createVenue *sql.Stmt
68-
deleteVenue *sql.Stmt
69-
getCity *sql.Stmt
70-
getVenue *sql.Stmt
71-
listCities *sql.Stmt
72-
listVenues *sql.Stmt
71+
tx *sql.Tx
72+
createCity *sql.Stmt
73+
createVenue *sql.Stmt
74+
deleteVenue *sql.Stmt
75+
getCity *sql.Stmt
76+
getVenue *sql.Stmt
77+
listCities *sql.Stmt
78+
listVenues *sql.Stmt
79+
updateCityName *sql.Stmt
80+
updateVenueName *sql.Stmt
7381
}
7482

7583
func (q *Queries) WithTx(tx *sql.Tx) *Queries {
7684
return &Queries{
77-
tx: tx,
78-
db: tx,
79-
createCity: q.createCity,
80-
createVenue: q.createVenue,
81-
deleteVenue: q.deleteVenue,
82-
getCity: q.getCity,
83-
getVenue: q.getVenue,
84-
listCities: q.listCities,
85-
listVenues: q.listVenues,
85+
tx: tx,
86+
db: tx,
87+
createCity: q.createCity,
88+
createVenue: q.createVenue,
89+
deleteVenue: q.deleteVenue,
90+
getCity: q.getCity,
91+
getVenue: q.getVenue,
92+
listCities: q.listCities,
93+
listVenues: q.listVenues,
94+
updateCityName: q.updateCityName,
95+
updateVenueName: q.updateVenueName,
8696
}
8797
}
8898

@@ -278,3 +288,44 @@ func (q *Queries) ListVenues(ctx context.Context, city string) ([]Venue, error)
278288
}
279289
return items, nil
280290
}
291+
292+
const updateCityName = `-- name: UpdateCityName :exec
293+
UPDATE city
294+
SET name = $2
295+
WHERE slug = $1
296+
`
297+
298+
func (q *Queries) UpdateCityName(ctx context.Context, slug string, name string) error {
299+
var err error
300+
switch {
301+
case q.updateCityName != nil && q.tx != nil:
302+
_, err = q.tx.StmtContext(ctx, q.updateCityName).ExecContext(ctx, slug, name)
303+
case q.updateCityName != nil:
304+
_, err = q.updateCityName.ExecContext(ctx, slug, name)
305+
default:
306+
_, err = q.db.ExecContext(ctx, updateCityName, slug, name)
307+
}
308+
return err
309+
}
310+
311+
const updateVenueName = `-- name: UpdateVenueName :one
312+
UPDATE venue
313+
SET name = $2
314+
WHERE slug = $1
315+
RETURNING id
316+
`
317+
318+
func (q *Queries) UpdateVenueName(ctx context.Context, slug string, name string) (int, error) {
319+
var row *sql.Row
320+
switch {
321+
case q.updateVenueName != nil && q.tx != nil:
322+
row = q.tx.StmtContext(ctx, q.updateVenueName).QueryRowContext(ctx, slug, name)
323+
case q.updateVenueName != nil:
324+
row = q.updateVenueName.QueryRowContext(ctx, slug, name)
325+
default:
326+
row = q.db.QueryRowContext(ctx, updateVenueName, slug, name)
327+
}
328+
var i int
329+
err := row.Scan(&i)
330+
return i, err
331+
}

testdata/ondeck/query/queries.sql

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,14 @@ INSERT INTO venue (
4646
$3,
4747
$4
4848
) RETURNING id;
49+
50+
-- name: UpdateCityName :exec
51+
UPDATE city
52+
SET name = $2
53+
WHERE slug = $1;
54+
55+
-- name: UpdateVenueName :one
56+
UPDATE venue
57+
SET name = $2
58+
WHERE slug = $1
59+
RETURNING id;

0 commit comments

Comments
 (0)