Skip to content

Commit d24b1db

Browse files
committed
Add support for insert stmts
1 parent 3ac9041 commit d24b1db

File tree

8 files changed

+153
-11
lines changed

8 files changed

+153
-11
lines changed

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,6 @@ go 1.12
44

55
require (
66
github.com/davecgh/go-spew v1.1.1
7+
github.com/google/go-cmp v0.3.0
78
github.com/lfittl/pg_query_go v1.0.0
89
)

go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
44
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
55
github.com/elgris/sqrl v0.0.0-20181124135704-90ecf730640a h1:VRAv/FIe+jL6t/kHB43hrwhUIyP2/cq+vC/9bdS8v4o=
66
github.com/elgris/sqrl v0.0.0-20181124135704-90ecf730640a/go.mod h1:hQPgqeM4LmbfKCaBkcedRq5y1yfb8Qb8iYdbuNjE4FU=
7+
github.com/google/go-cmp v0.3.0 h1:crn/baboCvb5fXaQ0IJ1SGTsTVrWpDsCWC8EGETZijY=
8+
github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
79
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
810
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
911
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=

parser.go

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,10 @@ func parse(s *postgres.Schema, tree pg.ParsetreeList) {
5959
for _, elt := range n.TableElts.Items {
6060
switch n := elt.(type) {
6161
case nodes.ColumnDef:
62-
// spew.Dump(n)
6362
// log.Printf("not null: %t", n.IsNotNull)
6463
table.Columns = append(table.Columns, postgres.Column{
6564
Name: *n.Colname,
65+
Type: join(n.TypeName.Names, "."),
6666
GoName: structName(*n.Colname),
6767
NotNull: isNotNull(n),
6868
})
@@ -75,6 +75,16 @@ func parse(s *postgres.Schema, tree pg.ParsetreeList) {
7575
}
7676
}
7777

78+
func join(list nodes.List, sep string) string {
79+
items := []string{}
80+
for _, item := range list.Items {
81+
if n, ok := item.(nodes.String); ok {
82+
items = append(items, n.Str)
83+
}
84+
}
85+
return strings.Join(items, sep)
86+
}
87+
7888
func isNotNull(n nodes.ColumnDef) bool {
7989
if n.IsNotNull {
8090
return true
@@ -101,6 +111,7 @@ type Query struct {
101111
SQL string
102112
Args []Arg
103113
Table postgres.Table
114+
ReturnType string
104115
}
105116

106117
type Result struct {
@@ -179,6 +190,7 @@ func parseFuncs(s *postgres.Schema, r *Result, source string, tree pg.ParsetreeL
179190

180191
tab := getTable(s, t)
181192
r.Queries[i].Table = tab
193+
r.Queries[i].ReturnType = tab.GoName
182194
r.Queries[i].Args = parseArgs(tab, refs)
183195
r.Queries[i].SQL = strings.Replace(rawSQL, "*", strings.Join(c, ", "), 1)
184196
case nodes.DeleteStmt:
@@ -189,8 +201,20 @@ func parseFuncs(s *postgres.Schema, r *Result, source string, tree pg.ParsetreeL
189201

190202
tab := getTable(s, t)
191203
r.Queries[i].Table = tab
204+
r.Queries[i].ReturnType = tab.GoName
192205
r.Queries[i].Args = parseArgs(tab, refs)
193206
r.Queries[i].SQL = rawSQL
207+
case nodes.InsertStmt:
208+
t := tableName(n)
209+
c := columnNames(s, t)
210+
rawSQL, _ := pluckQuery(source, raw)
211+
refs := extractArgs(n)
212+
213+
tab := getTable(s, t)
214+
r.Queries[i].Table = tab
215+
r.Queries[i].ReturnType = tab.GoName
216+
r.Queries[i].Args = parseArgs(tab, refs)
217+
r.Queries[i].SQL = strings.Replace(rawSQL, "*", strings.Join(c, ", "), 1)
194218
default:
195219
log.Printf("%T\n", n)
196220
}
@@ -221,6 +245,16 @@ func findRefs(r []paramRef, parent, n nodes.Node) []paramRef {
221245
r = findRefs(r, n.WhereClause, nil)
222246
r = findRefs(r, n.LimitCount, nil)
223247
r = findRefs(r, n.LimitOffset, nil)
248+
case nodes.InsertStmt:
249+
switch s := n.SelectStmt.(type) {
250+
case nodes.SelectStmt:
251+
for _, vl := range s.ValuesLists {
252+
for i, v := range vl {
253+
// TODO: Index error
254+
r = findRefs(r, n.Cols.Items[i], v)
255+
}
256+
}
257+
}
224258
case nodes.BoolExpr:
225259
for _, item := range n.Args.Items {
226260
r = findRefs(r, item, nil)
@@ -265,6 +299,15 @@ func parseArgs(t postgres.Table, args []paramRef) []Arg {
265299
panic("unknown column: " + key)
266300
}
267301
}
302+
case nodes.ResTarget:
303+
key := *n.Name
304+
if typ, ok := typeMap[key]; ok {
305+
a = append(a, Arg{Name: key, Type: typ})
306+
} else {
307+
panic("unknown column: " + key)
308+
}
309+
case nodes.ParamRef:
310+
a = append(a, Arg{Name: "_", Type: "interface{}"})
268311
default:
269312
panic(fmt.Sprintf("unsupported type: %T", n))
270313
}
@@ -296,6 +339,8 @@ func tableName(n nodes.Node) string {
296339
}
297340
case nodes.DeleteStmt:
298341
return *n.Relation.Relname
342+
case nodes.InsertStmt:
343+
return *n.Relation.Relname
299344
}
300345
return ""
301346
}
@@ -304,11 +349,12 @@ var hh = `package {{.Package}}
304349
import (
305350
"context"
306351
"database/sql"
352+
"time"
307353
)
308354
309355
{{range .Schema.Tables}}
310356
type {{.GoName}} struct { {{- range .Columns}}
311-
{{.GoName}} {{if .NotNull }}string{{else}}sql.NullString{{end}}
357+
{{.GoName}} {{.GoType}}
312358
{{- end}}
313359
}
314360
{{end}}
@@ -358,7 +404,7 @@ const {{.QueryName}} = {{$.Q}}{{.SQL}}
358404
{{$.Q}}
359405
360406
{{if eq .Type ":one"}}
361-
func (q *Queries) {{.MethodName}}(ctx context.Context, {{range .Args}}{{.Name}} {{.Type}},{{end}}) ({{.Table.GoName}}, error) {
407+
func (q *Queries) {{.MethodName}}(ctx context.Context, {{range .Args}}{{.Name}} {{.Type}},{{end}}) ({{.ReturnType}}, error) {
362408
var row *sql.Row
363409
switch {
364410
case q.{{.StmtName}} != nil && q.tx != nil:
@@ -375,7 +421,7 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{range .Args}}{{.Name}}
375421
{{end}}
376422
377423
{{if eq .Type ":many"}}
378-
func (q *Queries) {{.MethodName}}(ctx context.Context, {{range .Args}}{{.Name}} {{.Type}},{{end}}) ([]{{.Table.GoName}}, error) {
424+
func (q *Queries) {{.MethodName}}(ctx context.Context, {{range .Args}}{{.Name}} {{.Type}},{{end}}) ([]{{.ReturnType}}, error) {
379425
var rows *sql.Rows
380426
var err error
381427
switch {

parser_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
package strongdb
22

33
import (
4-
"fmt"
54
"io/ioutil"
65
"log"
76
"path/filepath"
87
"testing"
98

9+
"github.com/google/go-cmp/cmp"
1010
pg "github.com/lfittl/pg_query_go"
1111
nodes "github.com/lfittl/pg_query_go/nodes"
1212
)
@@ -88,8 +88,8 @@ func TestParseSchema(t *testing.T) {
8888
log.Fatal(err)
8989
}
9090

91-
if source != string(blob) {
92-
t.Errorf("output differs")
93-
fmt.Println(source)
91+
if diff := cmp.Diff(source, string(blob)); diff != "" {
92+
t.Errorf("genreated code differed (-want +got):\n%s", diff)
93+
t.Log(source)
9494
}
9595
}

postgres/schema.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,21 @@ type Column struct {
1616
Type string
1717
NotNull bool
1818
}
19+
20+
func (c Column) GoType() string {
21+
// {{.GoName}} {{if .NotNull }}string{{else}}sql.NullString{{end}}
22+
switch c.Type {
23+
case "text":
24+
if c.NotNull {
25+
return "string"
26+
} else {
27+
return "sql.NullString"
28+
}
29+
case "serial":
30+
return "int"
31+
case "pg_catalog.timestamp":
32+
return "time.Time"
33+
default:
34+
return "interface{}"
35+
}
36+
}

testdata/ondeck/db.go

Lines changed: 61 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package ondeck
33
import (
44
"context"
55
"database/sql"
6+
"time"
67
)
78

89
type City struct {
@@ -11,6 +12,8 @@ type City struct {
1112
}
1213

1314
type Venue struct {
15+
ID int
16+
CreatedAt time.Time
1417
Slug string
1518
Name string
1619
City sql.NullString
@@ -32,12 +35,18 @@ func New(db dbtx) *Queries {
3235
func Prepare(ctx context.Context, db dbtx) (*Queries, error) {
3336
q := Queries{db: db}
3437
var err error
38+
if q.createCity, err = db.PrepareContext(ctx, createCity); err != nil {
39+
return nil, err
40+
}
3541
if q.deleteVenue, err = db.PrepareContext(ctx, deleteVenue); err != nil {
3642
return nil, err
3743
}
3844
if q.getCity, err = db.PrepareContext(ctx, getCity); err != nil {
3945
return nil, err
4046
}
47+
if q.getVenue, err = db.PrepareContext(ctx, getVenue); err != nil {
48+
return nil, err
49+
}
4150
if q.listCities, err = db.PrepareContext(ctx, listCities); err != nil {
4251
return nil, err
4352
}
@@ -51,8 +60,10 @@ type Queries struct {
5160
db dbtx
5261

5362
tx *sql.Tx
63+
createCity *sql.Stmt
5464
deleteVenue *sql.Stmt
5565
getCity *sql.Stmt
66+
getVenue *sql.Stmt
5667
listCities *sql.Stmt
5768
listVenues *sql.Stmt
5869
}
@@ -61,13 +72,40 @@ func (q *Queries) WithTx(tx *sql.Tx) *Queries {
6172
return &Queries{
6273
tx: tx,
6374
db: tx,
75+
createCity: q.createCity,
6476
deleteVenue: q.deleteVenue,
6577
getCity: q.getCity,
78+
getVenue: q.getVenue,
6679
listCities: q.listCities,
6780
listVenues: q.listVenues,
6881
}
6982
}
7083

84+
const createCity = `-- name: CreateCity :one
85+
INSERT INTO city (
86+
name,
87+
slug
88+
) VALUES (
89+
$1,
90+
$2
91+
) RETURNING slug, name
92+
`
93+
94+
func (q *Queries) CreateCity(ctx context.Context, name string, slug string) (City, error) {
95+
var row *sql.Row
96+
switch {
97+
case q.createCity != nil && q.tx != nil:
98+
row = q.tx.StmtContext(ctx, q.createCity).QueryRowContext(ctx, name, slug)
99+
case q.createCity != nil:
100+
row = q.createCity.QueryRowContext(ctx, name, slug)
101+
default:
102+
row = q.db.QueryRowContext(ctx, createCity, name, slug)
103+
}
104+
i := City{}
105+
err := row.Scan(&i.Slug, &i.Name)
106+
return i, err
107+
}
108+
71109
const deleteVenue = `-- name: DeleteVenue :exec
72110
DELETE FROM venue
73111
WHERE slug = $1
@@ -107,6 +145,27 @@ func (q *Queries) GetCity(ctx context.Context, slug string) (City, error) {
107145
return i, err
108146
}
109147

148+
const getVenue = `-- name: GetVenue :one
149+
SELECT id, created_at, slug, name, city, spotify_playlist, songkick_id
150+
FROM venue
151+
WHERE slug = $1 AND city = $2
152+
`
153+
154+
func (q *Queries) GetVenue(ctx context.Context, slug string, city string) (Venue, error) {
155+
var row *sql.Row
156+
switch {
157+
case q.getVenue != nil && q.tx != nil:
158+
row = q.tx.StmtContext(ctx, q.getVenue).QueryRowContext(ctx, slug, city)
159+
case q.getVenue != nil:
160+
row = q.getVenue.QueryRowContext(ctx, slug, city)
161+
default:
162+
row = q.db.QueryRowContext(ctx, getVenue, slug, city)
163+
}
164+
i := Venue{}
165+
err := row.Scan(&i.ID, &i.CreatedAt, &i.Slug, &i.Name, &i.City, &i.SpotifyPlaylist, &i.SongkickID)
166+
return i, err
167+
}
168+
110169
const listCities = `-- name: ListCities :many
111170
SELECT slug, name
112171
FROM city
@@ -146,7 +205,7 @@ func (q *Queries) ListCities(ctx context.Context) ([]City, error) {
146205
}
147206

148207
const listVenues = `-- name: ListVenues :many
149-
SELECT slug, name, city, spotify_playlist, songkick_id
208+
SELECT id, created_at, slug, name, city, spotify_playlist, songkick_id
150209
FROM venue
151210
WHERE city = $1
152211
ORDER BY name
@@ -170,7 +229,7 @@ func (q *Queries) ListVenues(ctx context.Context, city string) ([]Venue, error)
170229
items := []Venue{}
171230
for rows.Next() {
172231
i := Venue{}
173-
if err := rows.Scan(&i.Slug, &i.Name, &i.City, &i.SpotifyPlaylist, &i.SongkickID); err != nil {
232+
if err := rows.Scan(&i.ID, &i.CreatedAt, &i.Slug, &i.Name, &i.City, &i.SpotifyPlaylist, &i.SongkickID); err != nil {
174233
return nil, err
175234
}
176235
items = append(items, i)

testdata/ondeck/query/queries.sql

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,17 @@ ORDER BY name;
1717
-- name: DeleteVenue :exec
1818
DELETE FROM venue
1919
WHERE slug = $1;
20+
21+
-- name: GetVenue :one
22+
SELECT *
23+
FROM venue
24+
WHERE slug = $1 AND city = $2;
25+
26+
-- name: CreateCity :one
27+
INSERT INTO city (
28+
name,
29+
slug
30+
) VALUES (
31+
$1,
32+
$2
33+
) RETURNING *;

testdata/ondeck/schema/0002_venue.sql

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
CREATE TABLE venue (
2-
slug text primary key,
2+
id SERIAL primary key,
3+
created_at TIMESTAMP NOT NULL DEFAULT NOW(),
4+
slug text not null,
35
name text not null,
46
city text references city(slug),
57
spotify_playlist text not null,

0 commit comments

Comments
 (0)