@@ -2,6 +2,7 @@ package cmd
2
2
3
3
import (
4
4
"context"
5
+ "database/sql"
5
6
"errors"
6
7
"fmt"
7
8
"io"
@@ -11,6 +12,7 @@ import (
11
12
"strings"
12
13
"time"
13
14
15
+ _ "github.com/go-sql-driver/mysql"
14
16
"github.com/google/cel-go/cel"
15
17
"github.com/google/cel-go/ext"
16
18
"github.com/jackc/pgx/v5"
@@ -140,17 +142,6 @@ func Vet(ctx context.Context, e Env, dir, filename string, stderr io.Writer) err
140
142
return nil
141
143
}
142
144
143
- type checker struct {
144
- Checks map [string ]cel.Program
145
- Conf * config.Config
146
- Dbenv * cel.Env
147
- Dir string
148
- Env * cel.Env
149
- Envmap map [string ]string
150
- Msgs map [string ]string
151
- Stderr io.Writer
152
- }
153
-
154
145
// Determine if a query can be prepared based on the engine and the statement
155
146
// type.
156
147
func prepareable (sql config.SQL , raw * ast.RawStmt ) bool {
@@ -169,92 +160,151 @@ func prepareable(sql config.SQL, raw *ast.RawStmt) bool {
169
160
return false
170
161
}
171
162
}
163
+ // Almost all statements in MySQL can be prepared, so I'm just going to assume they can be
164
+ // https://dev.mysql.com/doc/refman/8.0/en/sql-prepared-statements.html
165
+ if sql .Engine == config .EngineMySQL {
166
+ return true
167
+ }
172
168
return false
173
169
}
174
170
175
- func (c * checker ) checkSQL (ctx context.Context , sql config.SQL ) error {
171
+ type preparer interface {
172
+ Prepare (context.Context , string , string ) error
173
+ }
174
+
175
+ type pgxPreparer struct {
176
+ c * pgx.Conn
177
+ }
178
+
179
+ func (p * pgxPreparer ) Prepare (ctx context.Context , name , query string ) error {
180
+ _ , err := p .c .Prepare (ctx , name , query )
181
+ return err
182
+ }
183
+
184
+ type dbPreparer struct {
185
+ db * sql.DB
186
+ }
187
+
188
+ func (p * dbPreparer ) Prepare (ctx context.Context , name , query string ) error {
189
+ _ , err := p .db .PrepareContext (ctx , query )
190
+ return err
191
+ }
192
+
193
+ type checker struct {
194
+ Checks map [string ]cel.Program
195
+ Conf * config.Config
196
+ Dbenv * cel.Env
197
+ Dir string
198
+ Env * cel.Env
199
+ Envmap map [string ]string
200
+ Msgs map [string ]string
201
+ Stderr io.Writer
202
+ }
203
+
204
+ func (c * checker ) DSN (expr string ) (string , error ) {
205
+ ast , issues := c .Dbenv .Compile (expr )
206
+ if issues != nil && issues .Err () != nil {
207
+ return "" , fmt .Errorf ("type-check error: database url %s" , issues .Err ())
208
+ }
209
+ prg , err := c .Dbenv .Program (ast )
210
+ if err != nil {
211
+ return "" , fmt .Errorf ("program construction error: database url %s" , err )
212
+ }
213
+ // Populate the environment variable map if it is empty
214
+ if len (c .Envmap ) == 0 {
215
+ for _ , e := range os .Environ () {
216
+ k , v , _ := strings .Cut (e , "=" )
217
+ c .Envmap [k ] = v
218
+ }
219
+ }
220
+ out , _ , err := prg .Eval (map [string ]any {
221
+ "env" : c .Envmap ,
222
+ })
223
+ if err != nil {
224
+ return "" , fmt .Errorf ("expression error: %s" , err )
225
+ }
226
+ dsn , ok := out .Value ().(string )
227
+ if ! ok {
228
+ return "" , fmt .Errorf ("expression returned non-string value: %v" , out .Value ())
229
+ }
230
+ return dsn , nil
231
+ }
232
+
233
+ func (c * checker ) checkSQL (ctx context.Context , s config.SQL ) error {
176
234
// TODO: Create a separate function for this logic so we can
177
- combo := config .Combine (* c .Conf , sql )
235
+ combo := config .Combine (* c .Conf , s )
178
236
179
237
// TODO: This feels like a hack that will bite us later
180
- joined := make ([]string , 0 , len (sql .Schema ))
181
- for _ , s := range sql .Schema {
238
+ joined := make ([]string , 0 , len (s .Schema ))
239
+ for _ , s := range s .Schema {
182
240
joined = append (joined , filepath .Join (c .Dir , s ))
183
241
}
184
- sql .Schema = joined
242
+ s .Schema = joined
185
243
186
- joined = make ([]string , 0 , len (sql .Queries ))
187
- for _ , q := range sql .Queries {
244
+ joined = make ([]string , 0 , len (s .Queries ))
245
+ for _ , q := range s .Queries {
188
246
joined = append (joined , filepath .Join (c .Dir , q ))
189
247
}
190
- sql .Queries = joined
248
+ s .Queries = joined
191
249
192
250
var name string
193
251
parseOpts := opts.Parser {
194
252
Debug : debug .Debug ,
195
253
}
196
254
197
- result , failed := parse (ctx , name , c .Dir , sql , combo , parseOpts , c .Stderr )
255
+ result , failed := parse (ctx , name , c .Dir , s , combo , parseOpts , c .Stderr )
198
256
if failed {
199
257
return ErrFailedChecks
200
258
}
201
259
202
260
// TODO: Add MySQL support
203
- var pgconn * pgx.Conn
204
- if sql .Engine == config .EnginePostgreSQL && sql .Database != nil {
205
- ast , issues := c .Dbenv .Compile (sql .Database .URL )
206
- if issues != nil && issues .Err () != nil {
207
- return fmt .Errorf ("type-check error: database url %s" , issues .Err ())
208
- }
209
- prg , err := c .Dbenv .Program (ast )
261
+ var prep preparer
262
+ if s .Database != nil {
263
+ dburl , err := c .DSN (s .Database .URL )
210
264
if err != nil {
211
- return fmt . Errorf ( "program construction error: database url %s" , err )
265
+ return err
212
266
}
213
- // Populate the environment variable map if it is empty
214
- if len ( c . Envmap ) == 0 {
215
- for _ , e := range os . Environ () {
216
- k , v , _ := strings . Cut ( e , "=" )
217
- c . Envmap [ k ] = v
267
+ switch s . Engine {
268
+ case config . EnginePostgreSQL :
269
+ conn , err := pgx . Connect ( ctx , dburl )
270
+ if err != nil {
271
+ return fmt . Errorf ( "database: connection error: %s" , err )
218
272
}
273
+ if err := conn .Ping (ctx ); err != nil {
274
+ return fmt .Errorf ("database: connection error: %s" , err )
275
+ }
276
+ defer conn .Close (ctx )
277
+ prep = & pgxPreparer {conn }
278
+ case config .EngineMySQL :
279
+ db , err := sql .Open ("mysql" , dburl )
280
+ if err != nil {
281
+ return fmt .Errorf ("database: connection error: %s" , err )
282
+ }
283
+ if err := db .PingContext (ctx ); err != nil {
284
+ return fmt .Errorf ("database: connection error: %s" , err )
285
+ }
286
+ defer db .Close ()
287
+ prep = & dbPreparer {db }
288
+ default :
289
+ return fmt .Errorf ("unsupported database url: %s" , s .Engine )
219
290
}
220
- out , _ , err := prg .Eval (map [string ]any {
221
- "env" : c .Envmap ,
222
- })
223
- if err != nil {
224
- return fmt .Errorf ("expression error: %s" , err )
225
- }
226
- dburl , ok := out .Value ().(string )
227
- if ! ok {
228
- return fmt .Errorf ("expression returned non-string value: %v" , out .Value ())
229
- }
230
- fmt .Println ("URL" , dburl )
231
- conn , err := pgx .Connect (ctx , dburl )
232
- if err != nil {
233
- return fmt .Errorf ("database: connection error: %s" , err )
234
- }
235
- if err := conn .Ping (ctx ); err != nil {
236
- return fmt .Errorf ("database: connection error: %s" , err )
237
- }
238
- defer conn .Close (ctx )
239
- pgconn = conn
240
291
}
241
292
242
293
errored := false
243
294
req := codeGenRequest (result , combo )
244
295
cfg := vetConfig (req )
245
296
for i , query := range req .Queries {
246
297
original := result .Queries [i ]
247
- if pgconn != nil && prepareable (sql , original .RawStmt ) {
298
+ if prep != nil && prepareable (s , original .RawStmt ) {
248
299
name := fmt .Sprintf ("sqlc_vet_%d_%d" , time .Now ().Unix (), i )
249
- _ , err := pgconn .Prepare (ctx , name , query .Text )
250
- if err != nil {
300
+ if err := prep .Prepare (ctx , name , query .Text ); err != nil {
251
301
fmt .Fprintf (c .Stderr , "%s: error preparing %s: %s\n " , query .Filename , query .Name , err )
252
302
errored = true
253
303
continue
254
304
}
255
305
}
256
306
q := vetQuery (query )
257
- for _ , name := range sql .Rules {
307
+ for _ , name := range s .Rules {
258
308
prg , ok := c .Checks [name ]
259
309
if ! ok {
260
310
return fmt .Errorf ("type-check error: a check with the name '%s' does not exist" , name )
0 commit comments