@@ -14,7 +14,6 @@ import (
14
14
"unicode"
15
15
16
16
"github.com/davecgh/go-spew/spew"
17
- sq "github.com/elgris/sqrl"
18
17
"github.com/kyleconroy/strongdb/postgres"
19
18
pg "github.com/lfittl/pg_query_go"
20
19
nodes "github.com/lfittl/pg_query_go/nodes"
@@ -133,7 +132,7 @@ func ParseQueries(s *postgres.Schema, dir string) (*Result, error) {
133
132
if err != nil {
134
133
return nil , err
135
134
}
136
- parseFuncs (s , & r , tree )
135
+ parseFuncs (s , & r , string ( blob ), tree )
137
136
return & r , nil
138
137
}
139
138
return nil , nil
@@ -156,88 +155,106 @@ func parseQueries(t []byte) []Query {
156
155
return q
157
156
}
158
157
159
- func parseFuncs (s * postgres.Schema , r * Result , tree pg.ParsetreeList ) {
158
+ func pluckQuery (source string , n nodes.RawStmt ) (string , error ) {
159
+ // TODO: Bounds checking
160
+ head := n .StmtLocation
161
+ tail := n .StmtLocation + n .StmtLen
162
+ return strings .TrimSpace (source [head :tail ]), nil
163
+ }
164
+
165
+ func parseFuncs (s * postgres.Schema , r * Result , source string , tree pg.ParsetreeList ) {
160
166
for i , stmt := range tree .Statements {
161
167
raw , ok := stmt .(nodes.RawStmt )
162
168
if ! ok {
163
169
continue
164
170
}
171
+
165
172
switch n := raw .Stmt .(type ) {
166
173
case nodes.SelectStmt :
167
174
t := tableName (n )
168
-
169
175
c := columnNames (s , t )
170
- args := []string {}
171
- psql := sq .StatementBuilder .PlaceholderFormat (sq .Dollar )
172
- q := psql .Select (c ... ).From (t )
173
- q , args = where (q , n , args )
174
- q = orderBy (q , n )
175
- query , _ , _ := q .ToSql ()
176
+
177
+ rawSQL , _ := pluckQuery (source , raw )
178
+ refs := extractArgs (n )
176
179
177
180
tab := getTable (s , t )
178
181
r .Queries [i ].Table = tab
179
- r .Queries [i ].Args = parseArgs (tab , args )
180
- r .Queries [i ].SQL = query
182
+ r .Queries [i ].Args = parseArgs (tab , refs )
183
+ r .Queries [i ].SQL = strings . Replace ( rawSQL , "*" , strings . Join ( c , ", " ), 1 )
181
184
default :
182
185
log .Printf ("%T\n " , n )
183
186
}
184
187
}
185
188
}
186
189
187
- func where (q * sq.SelectBuilder , n nodes.SelectStmt , args []string ) (* sq.SelectBuilder , []string ) {
188
- // Only equality supported
189
- eq := sq.Eq {}
190
- found := false
191
- switch a := n .WhereClause .(type ) {
192
- case nodes.A_Expr :
193
- switch n := a .Lexpr .(type ) {
194
- case nodes.ColumnRef :
195
- key := ""
196
- for _ , n := range n .Fields .Items {
197
- switch n := n .(type ) {
198
- case nodes.String :
199
- key += n .Str
200
- }
201
- }
202
- found = true
203
- args = append (args , key )
204
- eq [key ] = "?"
205
- }
206
- // switch n := a.Lexpr.(type) {
207
- // case nodes.ParamRef:
208
- // }
190
+ func extractArgs (n nodes.Node ) []paramRef {
191
+ refs := findRefs ([]paramRef {}, n , nil )
192
+ sort .Slice (refs , func (i , j int ) bool { return refs [i ].ref .Number < refs [j ].ref .Number })
193
+ return refs
194
+ }
195
+
196
+ type paramRef struct {
197
+ parent nodes.Node
198
+ ref nodes.ParamRef
199
+ }
200
+
201
+ func findRefs (r []paramRef , parent , n nodes.Node ) []paramRef {
202
+ if n == nil {
203
+ n = parent
209
204
}
210
- if ! found {
211
- return q , args
205
+ switch n := n .(type ) {
206
+ case nodes.RawStmt :
207
+ r = findRefs (r , n .Stmt , nil )
208
+ case nodes.SelectStmt :
209
+ r = findRefs (r , n .WhereClause , nil )
210
+ r = findRefs (r , n .LimitCount , nil )
211
+ r = findRefs (r , n .LimitOffset , nil )
212
+ case nodes.BoolExpr :
213
+ for _ , item := range n .Args .Items {
214
+ r = findRefs (r , item , nil )
215
+ }
216
+ case nodes.A_Expr :
217
+ r = findRefs (r , n , n .Lexpr )
218
+ r = findRefs (r , n , n .Rexpr )
219
+ case nodes.ParamRef :
220
+ r = append (r , paramRef {
221
+ parent : parent ,
222
+ ref : n ,
223
+ })
224
+ case nodes.ColumnRef :
225
+ case nil :
226
+ default :
227
+ log .Printf ("%T\n " , n )
212
228
}
213
- return q . Where ( eq ), args
229
+ return r
214
230
}
215
231
216
- func orderBy (q * sq.SelectBuilder , n nodes.SelectStmt ) * sq.SelectBuilder {
217
- for _ , n := range n .SortClause .Items {
218
- switch n := n .(type ) {
219
- case nodes.SortBy :
220
- switch n := n .Node .(type ) {
232
+ func parseArgs (t postgres.Table , args []paramRef ) []Arg {
233
+ typeMap := map [string ]string {}
234
+ for _ , c := range t .Columns {
235
+ typeMap [c .Name ] = "string"
236
+ }
237
+ a := []Arg {}
238
+ for _ , ref := range args {
239
+ switch n := ref .parent .(type ) {
240
+ case nodes.A_Expr :
241
+ switch n := n .Lexpr .(type ) {
221
242
case nodes.ColumnRef :
243
+ key := ""
222
244
for _ , n := range n .Fields .Items {
223
245
switch n := n .(type ) {
224
246
case nodes.String :
225
- q = q . OrderBy ( n .Str )
247
+ key += n .Str
226
248
}
227
249
}
250
+ if typ , ok := typeMap [key ]; ok {
251
+ a = append (a , Arg {Name : key , Type : typ })
252
+ } else {
253
+ panic ("unknown column: " + key )
254
+ }
228
255
}
229
- }
230
- }
231
- return q
232
- }
233
-
234
- func parseArgs (t postgres.Table , args []string ) []Arg {
235
- a := []Arg {}
236
- for _ , arg := range args {
237
- for _ , c := range t .Columns {
238
- if c .Name == arg {
239
- a = append (a , Arg {Name : c .Name , Type : "string" })
240
- }
256
+ default :
257
+ panic (fmt .Sprintf ("unsupported type: %T" , n ))
241
258
}
242
259
}
243
260
return a
@@ -320,8 +337,7 @@ func (q *Queries) WithTx(tx *sql.Tx) *Queries {
320
337
}
321
338
322
339
{{range .Queries}}
323
- const {{.QueryName}} = {{$.Q}}
324
- {{.SQL}}
340
+ const {{.QueryName}} = {{$.Q}}{{.SQL}}
325
341
{{$.Q}}
326
342
327
343
{{if eq .Type ":one"}}
0 commit comments