Skip to content

Commit 6c049b7

Browse files
committed
Remove sqrl
1 parent f386e3c commit 6c049b7

File tree

4 files changed

+148
-64
lines changed

4 files changed

+148
-64
lines changed

go.mod

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,5 @@ go 1.12
44

55
require (
66
github.com/davecgh/go-spew v1.1.1
7-
github.com/elgris/sqrl v0.0.0-20181124135704-90ecf730640a
87
github.com/lfittl/pg_query_go v1.0.0
98
)

parser.go

Lines changed: 73 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ import (
1414
"unicode"
1515

1616
"github.com/davecgh/go-spew/spew"
17-
sq "github.com/elgris/sqrl"
1817
"github.com/kyleconroy/strongdb/postgres"
1918
pg "github.com/lfittl/pg_query_go"
2019
nodes "github.com/lfittl/pg_query_go/nodes"
@@ -133,7 +132,7 @@ func ParseQueries(s *postgres.Schema, dir string) (*Result, error) {
133132
if err != nil {
134133
return nil, err
135134
}
136-
parseFuncs(s, &r, tree)
135+
parseFuncs(s, &r, string(blob), tree)
137136
return &r, nil
138137
}
139138
return nil, nil
@@ -156,88 +155,106 @@ func parseQueries(t []byte) []Query {
156155
return q
157156
}
158157

159-
func parseFuncs(s *postgres.Schema, r *Result, tree pg.ParsetreeList) {
158+
func pluckQuery(source string, n nodes.RawStmt) (string, error) {
159+
// TODO: Bounds checking
160+
head := n.StmtLocation
161+
tail := n.StmtLocation + n.StmtLen
162+
return strings.TrimSpace(source[head:tail]), nil
163+
}
164+
165+
func parseFuncs(s *postgres.Schema, r *Result, source string, tree pg.ParsetreeList) {
160166
for i, stmt := range tree.Statements {
161167
raw, ok := stmt.(nodes.RawStmt)
162168
if !ok {
163169
continue
164170
}
171+
165172
switch n := raw.Stmt.(type) {
166173
case nodes.SelectStmt:
167174
t := tableName(n)
168-
169175
c := columnNames(s, t)
170-
args := []string{}
171-
psql := sq.StatementBuilder.PlaceholderFormat(sq.Dollar)
172-
q := psql.Select(c...).From(t)
173-
q, args = where(q, n, args)
174-
q = orderBy(q, n)
175-
query, _, _ := q.ToSql()
176+
177+
rawSQL, _ := pluckQuery(source, raw)
178+
refs := extractArgs(n)
176179

177180
tab := getTable(s, t)
178181
r.Queries[i].Table = tab
179-
r.Queries[i].Args = parseArgs(tab, args)
180-
r.Queries[i].SQL = query
182+
r.Queries[i].Args = parseArgs(tab, refs)
183+
r.Queries[i].SQL = strings.Replace(rawSQL, "*", strings.Join(c, ", "), 1)
181184
default:
182185
log.Printf("%T\n", n)
183186
}
184187
}
185188
}
186189

187-
func where(q *sq.SelectBuilder, n nodes.SelectStmt, args []string) (*sq.SelectBuilder, []string) {
188-
// Only equality supported
189-
eq := sq.Eq{}
190-
found := false
191-
switch a := n.WhereClause.(type) {
192-
case nodes.A_Expr:
193-
switch n := a.Lexpr.(type) {
194-
case nodes.ColumnRef:
195-
key := ""
196-
for _, n := range n.Fields.Items {
197-
switch n := n.(type) {
198-
case nodes.String:
199-
key += n.Str
200-
}
201-
}
202-
found = true
203-
args = append(args, key)
204-
eq[key] = "?"
205-
}
206-
// switch n := a.Lexpr.(type) {
207-
// case nodes.ParamRef:
208-
// }
190+
func extractArgs(n nodes.Node) []paramRef {
191+
refs := findRefs([]paramRef{}, n, nil)
192+
sort.Slice(refs, func(i, j int) bool { return refs[i].ref.Number < refs[j].ref.Number })
193+
return refs
194+
}
195+
196+
type paramRef struct {
197+
parent nodes.Node
198+
ref nodes.ParamRef
199+
}
200+
201+
func findRefs(r []paramRef, parent, n nodes.Node) []paramRef {
202+
if n == nil {
203+
n = parent
209204
}
210-
if !found {
211-
return q, args
205+
switch n := n.(type) {
206+
case nodes.RawStmt:
207+
r = findRefs(r, n.Stmt, nil)
208+
case nodes.SelectStmt:
209+
r = findRefs(r, n.WhereClause, nil)
210+
r = findRefs(r, n.LimitCount, nil)
211+
r = findRefs(r, n.LimitOffset, nil)
212+
case nodes.BoolExpr:
213+
for _, item := range n.Args.Items {
214+
r = findRefs(r, item, nil)
215+
}
216+
case nodes.A_Expr:
217+
r = findRefs(r, n, n.Lexpr)
218+
r = findRefs(r, n, n.Rexpr)
219+
case nodes.ParamRef:
220+
r = append(r, paramRef{
221+
parent: parent,
222+
ref: n,
223+
})
224+
case nodes.ColumnRef:
225+
case nil:
226+
default:
227+
log.Printf("%T\n", n)
212228
}
213-
return q.Where(eq), args
229+
return r
214230
}
215231

216-
func orderBy(q *sq.SelectBuilder, n nodes.SelectStmt) *sq.SelectBuilder {
217-
for _, n := range n.SortClause.Items {
218-
switch n := n.(type) {
219-
case nodes.SortBy:
220-
switch n := n.Node.(type) {
232+
func parseArgs(t postgres.Table, args []paramRef) []Arg {
233+
typeMap := map[string]string{}
234+
for _, c := range t.Columns {
235+
typeMap[c.Name] = "string"
236+
}
237+
a := []Arg{}
238+
for _, ref := range args {
239+
switch n := ref.parent.(type) {
240+
case nodes.A_Expr:
241+
switch n := n.Lexpr.(type) {
221242
case nodes.ColumnRef:
243+
key := ""
222244
for _, n := range n.Fields.Items {
223245
switch n := n.(type) {
224246
case nodes.String:
225-
q = q.OrderBy(n.Str)
247+
key += n.Str
226248
}
227249
}
250+
if typ, ok := typeMap[key]; ok {
251+
a = append(a, Arg{Name: key, Type: typ})
252+
} else {
253+
panic("unknown column: " + key)
254+
}
228255
}
229-
}
230-
}
231-
return q
232-
}
233-
234-
func parseArgs(t postgres.Table, args []string) []Arg {
235-
a := []Arg{}
236-
for _, arg := range args {
237-
for _, c := range t.Columns {
238-
if c.Name == arg {
239-
a = append(a, Arg{Name: c.Name, Type: "string"})
240-
}
256+
default:
257+
panic(fmt.Sprintf("unsupported type: %T", n))
241258
}
242259
}
243260
return a
@@ -320,8 +337,7 @@ func (q *Queries) WithTx(tx *sql.Tx) *Queries {
320337
}
321338
322339
{{range .Queries}}
323-
const {{.QueryName}} = {{$.Q}}
324-
{{.SQL}}
340+
const {{.QueryName}} = {{$.Q}}{{.SQL}}
325341
{{$.Q}}
326342
327343
{{if eq .Type ":one"}}

parser_test.go

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,70 @@ import (
66
"log"
77
"path/filepath"
88
"testing"
9+
10+
pg "github.com/lfittl/pg_query_go"
11+
nodes "github.com/lfittl/pg_query_go/nodes"
912
)
1013

14+
const pluck = `
15+
SELECT * FROM venue WHERE slug = $1 AND city = $2;
16+
SELECT * FROM venue WHERE slug = $1;
17+
SELECT * FROM venue LIMIT $1;
18+
SELECT * FROM venue OFFSET $1;
19+
`
20+
21+
func TestPluck(t *testing.T) {
22+
tree, err := pg.Parse(pluck)
23+
if err != nil {
24+
t.Fatal(err)
25+
}
26+
27+
expected := []string{
28+
"SELECT * FROM venue WHERE slug = $1 AND city = $2",
29+
"SELECT * FROM venue WHERE slug = $1",
30+
"SELECT * FROM venue LIMIT $1",
31+
"SELECT * FROM venue OFFSET $1",
32+
}
33+
34+
for i, stmt := range tree.Statements {
35+
switch n := stmt.(type) {
36+
case nodes.RawStmt:
37+
q, err := pluckQuery(pluck, n)
38+
if err != nil {
39+
t.Error(err)
40+
continue
41+
}
42+
if q != expected[i] {
43+
t.Errorf("expected %s, got %s", expected[i], q)
44+
}
45+
default:
46+
t.Fatalf("wrong type; %T", n)
47+
}
48+
}
49+
}
50+
51+
func TestExtractArgs(t *testing.T) {
52+
queries := []string{
53+
"SELECT * FROM venue WHERE slug = $1 AND city = $2",
54+
"SELECT * FROM venue WHERE slug = $1",
55+
"SELECT * FROM venue LIMIT $1",
56+
"SELECT * FROM venue OFFSET $1",
57+
}
58+
for _, q := range queries {
59+
tree, err := pg.Parse(q)
60+
if err != nil {
61+
t.Fatal(err)
62+
}
63+
for _, stmt := range tree.Statements {
64+
refs := extractArgs(stmt)
65+
if err != nil {
66+
t.Error(err)
67+
}
68+
t.Logf("refs: %#v", refs)
69+
}
70+
}
71+
}
72+
1173
func TestParseSchema(t *testing.T) {
1274
s, err := ParseSchmea(filepath.Join("testdata", "ondeck", "schema"))
1375
if err != nil {

testdata/ondeck/db.go

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,10 @@ func (q *Queries) WithTx(tx *sql.Tx) *Queries {
6363
}
6464
}
6565

66-
const getCity = `
67-
SELECT slug, name FROM city WHERE slug = $1
66+
const getCity = `-- name: GetCity :one
67+
SELECT slug, name
68+
FROM city
69+
WHERE slug = $1
6870
`
6971

7072
func (q *Queries) GetCity(ctx context.Context, slug string) (City, error) {
@@ -82,8 +84,10 @@ func (q *Queries) GetCity(ctx context.Context, slug string) (City, error) {
8284
return i, err
8385
}
8486

85-
const listCities = `
86-
SELECT slug, name FROM city ORDER BY name
87+
const listCities = `-- name: ListCities :many
88+
SELECT slug, name
89+
FROM city
90+
ORDER BY name
8791
`
8892

8993
func (q *Queries) ListCities(ctx context.Context) ([]City, error) {
@@ -118,8 +122,11 @@ func (q *Queries) ListCities(ctx context.Context) ([]City, error) {
118122
return items, nil
119123
}
120124

121-
const listVenues = `
122-
SELECT slug, name, city, spotify_playlist, songkick_id FROM venue WHERE city = $1 ORDER BY name
125+
const listVenues = `-- name: ListVenues :many
126+
SELECT slug, name, city, spotify_playlist, songkick_id
127+
FROM venue
128+
WHERE city = $1
129+
ORDER BY name
123130
`
124131

125132
func (q *Queries) ListVenues(ctx context.Context, city string) ([]Venue, error) {

0 commit comments

Comments
 (0)