Skip to content

Commit f386e3c

Browse files
committed
First working version
1 parent 1cbc5f4 commit f386e3c

File tree

5 files changed

+123
-21
lines changed

5 files changed

+123
-21
lines changed

cmd/strongdb/main.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
package main
2+
3+
import (
4+
"flag"
5+
"log"
6+
7+
"github.com/kyleconroy/strongdb"
8+
)
9+
10+
func main() {
11+
pkg := flag.String("package", "db", "package name for Go code")
12+
sch := flag.String("schema", "", "input directory of SQL migrations")
13+
out := flag.String("out", "db.go", "output file")
14+
flag.Parse()
15+
16+
if err := strongdb.Exec(*sch, flag.Arg(0), *pkg, *out); err != nil {
17+
log.Fatal(err)
18+
}
19+
}

exec.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
package strongdb
2+
3+
import (
4+
"io/ioutil"
5+
)
6+
7+
func Exec(schemaDir, queryDir, pkg, out string) error {
8+
s, err := ParseSchmea(schemaDir)
9+
if err != nil {
10+
return err
11+
}
12+
13+
q, err := ParseQueries(s, queryDir)
14+
if err != nil {
15+
return err
16+
}
17+
18+
source := generate(q, pkg)
19+
return ioutil.WriteFile(out, []byte(source), 0644)
20+
}

parser.go

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"io/ioutil"
99
"log"
1010
"path/filepath"
11+
"sort"
1112
"strings"
1213
"text/template"
1314
"unicode"
@@ -186,6 +187,7 @@ func parseFuncs(s *postgres.Schema, r *Result, tree pg.ParsetreeList) {
186187
func where(q *sq.SelectBuilder, n nodes.SelectStmt, args []string) (*sq.SelectBuilder, []string) {
187188
// Only equality supported
188189
eq := sq.Eq{}
190+
found := false
189191
switch a := n.WhereClause.(type) {
190192
case nodes.A_Expr:
191193
switch n := a.Lexpr.(type) {
@@ -197,13 +199,17 @@ func where(q *sq.SelectBuilder, n nodes.SelectStmt, args []string) (*sq.SelectBu
197199
key += n.Str
198200
}
199201
}
202+
found = true
200203
args = append(args, key)
201204
eq[key] = "?"
202205
}
203206
// switch n := a.Lexpr.(type) {
204207
// case nodes.ParamRef:
205208
// }
206209
}
210+
if !found {
211+
return q, args
212+
}
207213
return q.Where(eq), args
208214
}
209215

@@ -280,6 +286,20 @@ type dbtx interface {
280286
QueryRowContext(context.Context, string, ...interface{}) *sql.Row
281287
}
282288
289+
func New(db dbtx) *Queries {
290+
return &Queries{db: db}
291+
}
292+
293+
func Prepare(ctx context.Context, db dbtx) (*Queries, error) {
294+
q := Queries{db: db}
295+
var err error{{range .Queries}}
296+
if q.{{.StmtName}}, err = db.PrepareContext(ctx, {{.QueryName}}); err != nil {
297+
return nil, err
298+
}
299+
{{- end}}
300+
return &q, nil
301+
}
302+
283303
type Queries struct {
284304
db dbtx
285305
@@ -289,6 +309,16 @@ type Queries struct {
289309
{{- end}}
290310
}
291311
312+
func (q *Queries) WithTx(tx *sql.Tx) *Queries {
313+
return &Queries{
314+
tx: tx,
315+
db: tx,
316+
{{- range .Queries}}
317+
{{.StmtName}}: q.{{.StmtName}},
318+
{{- end}}
319+
}
320+
}
321+
292322
{{range .Queries}}
293323
const {{.QueryName}} = {{$.Q}}
294324
{{.SQL}}
@@ -382,6 +412,8 @@ func lowerTitle(s string) string {
382412
}
383413

384414
func generate(r *Result, pkg string) string {
415+
sort.Slice(r.Queries, func(i, j int) bool { return r.Queries[i].MethodName < r.Queries[j].MethodName })
416+
385417
funcMap := template.FuncMap{
386418
"lowerTitle": lowerTitle,
387419
}

parser_test.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package strongdb
22

33
import (
4+
"fmt"
45
"io/ioutil"
56
"log"
67
"path/filepath"
@@ -27,5 +28,6 @@ func TestParseSchema(t *testing.T) {
2728

2829
if source != string(blob) {
2930
t.Errorf("output differs")
31+
fmt.Println(source)
3032
}
3133
}

testdata/ondeck/db.go

Lines changed: 50 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -25,17 +25,65 @@ type dbtx interface {
2525
QueryRowContext(context.Context, string, ...interface{}) *sql.Row
2626
}
2727

28+
func New(db dbtx) *Queries {
29+
return &Queries{db: db}
30+
}
31+
32+
func Prepare(ctx context.Context, db dbtx) (*Queries, error) {
33+
q := Queries{db: db}
34+
var err error
35+
if q.getCity, err = db.PrepareContext(ctx, getCity); err != nil {
36+
return nil, err
37+
}
38+
if q.listCities, err = db.PrepareContext(ctx, listCities); err != nil {
39+
return nil, err
40+
}
41+
if q.listVenues, err = db.PrepareContext(ctx, listVenues); err != nil {
42+
return nil, err
43+
}
44+
return &q, nil
45+
}
46+
2847
type Queries struct {
2948
db dbtx
3049

3150
tx *sql.Tx
32-
listCities *sql.Stmt
3351
getCity *sql.Stmt
52+
listCities *sql.Stmt
3453
listVenues *sql.Stmt
3554
}
3655

56+
func (q *Queries) WithTx(tx *sql.Tx) *Queries {
57+
return &Queries{
58+
tx: tx,
59+
db: tx,
60+
getCity: q.getCity,
61+
listCities: q.listCities,
62+
listVenues: q.listVenues,
63+
}
64+
}
65+
66+
const getCity = `
67+
SELECT slug, name FROM city WHERE slug = $1
68+
`
69+
70+
func (q *Queries) GetCity(ctx context.Context, slug string) (City, error) {
71+
var row *sql.Row
72+
switch {
73+
case q.getCity != nil && q.tx != nil:
74+
row = q.tx.StmtContext(ctx, q.getCity).QueryRowContext(ctx, slug)
75+
case q.getCity != nil:
76+
row = q.getCity.QueryRowContext(ctx, slug)
77+
default:
78+
row = q.db.QueryRowContext(ctx, getCity, slug)
79+
}
80+
i := City{}
81+
err := row.Scan(&i.Slug, &i.Name)
82+
return i, err
83+
}
84+
3785
const listCities = `
38-
SELECT slug, name FROM city WHERE ORDER BY name
86+
SELECT slug, name FROM city ORDER BY name
3987
`
4088

4189
func (q *Queries) ListCities(ctx context.Context) ([]City, error) {
@@ -70,25 +118,6 @@ func (q *Queries) ListCities(ctx context.Context) ([]City, error) {
70118
return items, nil
71119
}
72120

73-
const getCity = `
74-
SELECT slug, name FROM city WHERE slug = $1
75-
`
76-
77-
func (q *Queries) GetCity(ctx context.Context, slug string) (City, error) {
78-
var row *sql.Row
79-
switch {
80-
case q.getCity != nil && q.tx != nil:
81-
row = q.tx.StmtContext(ctx, q.getCity).QueryRowContext(ctx, slug)
82-
case q.getCity != nil:
83-
row = q.getCity.QueryRowContext(ctx, slug)
84-
default:
85-
row = q.db.QueryRowContext(ctx, getCity, slug)
86-
}
87-
i := City{}
88-
err := row.Scan(&i.Slug, &i.Name)
89-
return i, err
90-
}
91-
92121
const listVenues = `
93122
SELECT slug, name, city, spotify_playlist, songkick_id FROM venue WHERE city = $1 ORDER BY name
94123
`

0 commit comments

Comments
 (0)