Skip to content

Commit 7fd89bc

Browse files
authored
feat(cmd/vet): Prepare queries for MySQL (#2388)
* feat(cmd/vet): Prepare statements for MySQL * fix: Return an error unsupported engines * chore: Use PostgreSQL 15 and MySQL 8 * chore: Fix new MySQL image
1 parent 41c1520 commit 7fd89bc

File tree

9 files changed

+170
-164
lines changed

9 files changed

+170
-164
lines changed

.github/workflows/ci.yml

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,20 +31,21 @@ jobs:
3131

3232
services:
3333
postgres:
34-
image: postgres:11
34+
image: "postgres:15"
3535
env:
36-
POSTGRES_USER: postgres
37-
POSTGRES_PASSWORD: postgres
3836
POSTGRES_DB: postgres
37+
POSTGRES_PASSWORD: postgres
38+
POSTGRES_USER: postgres
3939
ports:
4040
- 5432:5432
4141
# needed because the postgres container does not provide a healthcheck
4242
options: --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5
4343
mysql:
44-
image: mysql:8
44+
image: "mysql/mysql-server:8.0"
4545
env:
46-
MYSQL_ROOT_PASSWORD: mysecretpassword
4746
MYSQL_DATABASE: mysql
47+
MYSQL_ROOT_HOST: '%'
48+
MYSQL_ROOT_PASSWORD: mysecretpassword
4849
ports:
4950
- 3306:3306
5051

docker-compose.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,16 @@ services:
66
- "3306:3306"
77
restart: always
88
environment:
9-
MYSQL_DATABASE: dinotest
9+
MYSQL_DATABASE: mysql
1010
MYSQL_ROOT_PASSWORD: mysecretpassword
1111
MYSQL_ROOT_HOST: '%'
1212

1313
postgresql:
14-
image: "postgres:13"
14+
image: "postgres:15"
1515
ports:
1616
- "5432:5432"
1717
restart: always
1818
environment:
19-
POSTGRES_DB: dinotest
19+
POSTGRES_DB: postgres
2020
POSTGRES_PASSWORD: mysecretpassword
2121
POSTGRES_USER: postgres

examples/authors/sqlc.json

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919
"schema": "mysql/schema.sql",
2020
"queries": "mysql/query.sql",
2121
"engine": "mysql",
22+
"database": {
23+
"url": "'root:%s@tcp(%s:%s)/authors?multiStatements=true&parseTime=true'.format([env.MYSQL_ROOT_PASSWORD, env.MYSQL_HOST, env.MYSQL_PORT])"
24+
},
2225
"gen": {
2326
"go": {
2427
"package": "authors",

examples/booktest/sqlc.json

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,10 @@
1616
"path": "mysql",
1717
"schema": "mysql/schema.sql",
1818
"queries": "mysql/query.sql",
19-
"engine": "mysql"
19+
"engine": "mysql",
20+
"database": {
21+
"url": "'root:%s@tcp(%s:%s)/booktest?multiStatements=true&parseTime=true'.format([env.MYSQL_ROOT_PASSWORD, env.MYSQL_HOST, env.MYSQL_PORT])"
22+
}
2023
},
2124
{
2225
"name": "booktest",

examples/ondeck/sqlc.json

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@
2020
"schema": "mysql/schema",
2121
"queries": "mysql/query",
2222
"engine": "mysql",
23+
"database": {
24+
"url": "'root:%s@tcp(%s:%s)/ondeck?multiStatements=true&parseTime=true'.format([env.MYSQL_ROOT_PASSWORD, env.MYSQL_HOST, env.MYSQL_PORT])"
25+
},
2326
"emit_json_tags": true,
2427
"emit_prepared_queries": true,
2528
"emit_interface": true

internal/cmd/vet.go

Lines changed: 107 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package cmd
22

33
import (
44
"context"
5+
"database/sql"
56
"errors"
67
"fmt"
78
"io"
@@ -11,6 +12,7 @@ import (
1112
"strings"
1213
"time"
1314

15+
_ "github.com/go-sql-driver/mysql"
1416
"github.com/google/cel-go/cel"
1517
"github.com/google/cel-go/ext"
1618
"github.com/jackc/pgx/v5"
@@ -140,17 +142,6 @@ func Vet(ctx context.Context, e Env, dir, filename string, stderr io.Writer) err
140142
return nil
141143
}
142144

143-
type checker struct {
144-
Checks map[string]cel.Program
145-
Conf *config.Config
146-
Dbenv *cel.Env
147-
Dir string
148-
Env *cel.Env
149-
Envmap map[string]string
150-
Msgs map[string]string
151-
Stderr io.Writer
152-
}
153-
154145
// Determine if a query can be prepared based on the engine and the statement
155146
// type.
156147
func prepareable(sql config.SQL, raw *ast.RawStmt) bool {
@@ -169,92 +160,151 @@ func prepareable(sql config.SQL, raw *ast.RawStmt) bool {
169160
return false
170161
}
171162
}
163+
// Almost all statements in MySQL can be prepared, so I'm just going to assume they can be
164+
// https://dev.mysql.com/doc/refman/8.0/en/sql-prepared-statements.html
165+
if sql.Engine == config.EngineMySQL {
166+
return true
167+
}
172168
return false
173169
}
174170

175-
func (c *checker) checkSQL(ctx context.Context, sql config.SQL) error {
171+
type preparer interface {
172+
Prepare(context.Context, string, string) error
173+
}
174+
175+
type pgxPreparer struct {
176+
c *pgx.Conn
177+
}
178+
179+
func (p *pgxPreparer) Prepare(ctx context.Context, name, query string) error {
180+
_, err := p.c.Prepare(ctx, name, query)
181+
return err
182+
}
183+
184+
type dbPreparer struct {
185+
db *sql.DB
186+
}
187+
188+
func (p *dbPreparer) Prepare(ctx context.Context, name, query string) error {
189+
_, err := p.db.PrepareContext(ctx, query)
190+
return err
191+
}
192+
193+
type checker struct {
194+
Checks map[string]cel.Program
195+
Conf *config.Config
196+
Dbenv *cel.Env
197+
Dir string
198+
Env *cel.Env
199+
Envmap map[string]string
200+
Msgs map[string]string
201+
Stderr io.Writer
202+
}
203+
204+
func (c *checker) DSN(expr string) (string, error) {
205+
ast, issues := c.Dbenv.Compile(expr)
206+
if issues != nil && issues.Err() != nil {
207+
return "", fmt.Errorf("type-check error: database url %s", issues.Err())
208+
}
209+
prg, err := c.Dbenv.Program(ast)
210+
if err != nil {
211+
return "", fmt.Errorf("program construction error: database url %s", err)
212+
}
213+
// Populate the environment variable map if it is empty
214+
if len(c.Envmap) == 0 {
215+
for _, e := range os.Environ() {
216+
k, v, _ := strings.Cut(e, "=")
217+
c.Envmap[k] = v
218+
}
219+
}
220+
out, _, err := prg.Eval(map[string]any{
221+
"env": c.Envmap,
222+
})
223+
if err != nil {
224+
return "", fmt.Errorf("expression error: %s", err)
225+
}
226+
dsn, ok := out.Value().(string)
227+
if !ok {
228+
return "", fmt.Errorf("expression returned non-string value: %v", out.Value())
229+
}
230+
return dsn, nil
231+
}
232+
233+
func (c *checker) checkSQL(ctx context.Context, s config.SQL) error {
176234
// TODO: Create a separate function for this logic so we can
177-
combo := config.Combine(*c.Conf, sql)
235+
combo := config.Combine(*c.Conf, s)
178236

179237
// TODO: This feels like a hack that will bite us later
180-
joined := make([]string, 0, len(sql.Schema))
181-
for _, s := range sql.Schema {
238+
joined := make([]string, 0, len(s.Schema))
239+
for _, s := range s.Schema {
182240
joined = append(joined, filepath.Join(c.Dir, s))
183241
}
184-
sql.Schema = joined
242+
s.Schema = joined
185243

186-
joined = make([]string, 0, len(sql.Queries))
187-
for _, q := range sql.Queries {
244+
joined = make([]string, 0, len(s.Queries))
245+
for _, q := range s.Queries {
188246
joined = append(joined, filepath.Join(c.Dir, q))
189247
}
190-
sql.Queries = joined
248+
s.Queries = joined
191249

192250
var name string
193251
parseOpts := opts.Parser{
194252
Debug: debug.Debug,
195253
}
196254

197-
result, failed := parse(ctx, name, c.Dir, sql, combo, parseOpts, c.Stderr)
255+
result, failed := parse(ctx, name, c.Dir, s, combo, parseOpts, c.Stderr)
198256
if failed {
199257
return ErrFailedChecks
200258
}
201259

202260
// TODO: Add MySQL support
203-
var pgconn *pgx.Conn
204-
if sql.Engine == config.EnginePostgreSQL && sql.Database != nil {
205-
ast, issues := c.Dbenv.Compile(sql.Database.URL)
206-
if issues != nil && issues.Err() != nil {
207-
return fmt.Errorf("type-check error: database url %s", issues.Err())
208-
}
209-
prg, err := c.Dbenv.Program(ast)
261+
var prep preparer
262+
if s.Database != nil {
263+
dburl, err := c.DSN(s.Database.URL)
210264
if err != nil {
211-
return fmt.Errorf("program construction error: database url %s", err)
265+
return err
212266
}
213-
// Populate the environment variable map if it is empty
214-
if len(c.Envmap) == 0 {
215-
for _, e := range os.Environ() {
216-
k, v, _ := strings.Cut(e, "=")
217-
c.Envmap[k] = v
267+
switch s.Engine {
268+
case config.EnginePostgreSQL:
269+
conn, err := pgx.Connect(ctx, dburl)
270+
if err != nil {
271+
return fmt.Errorf("database: connection error: %s", err)
218272
}
273+
if err := conn.Ping(ctx); err != nil {
274+
return fmt.Errorf("database: connection error: %s", err)
275+
}
276+
defer conn.Close(ctx)
277+
prep = &pgxPreparer{conn}
278+
case config.EngineMySQL:
279+
db, err := sql.Open("mysql", dburl)
280+
if err != nil {
281+
return fmt.Errorf("database: connection error: %s", err)
282+
}
283+
if err := db.PingContext(ctx); err != nil {
284+
return fmt.Errorf("database: connection error: %s", err)
285+
}
286+
defer db.Close()
287+
prep = &dbPreparer{db}
288+
default:
289+
return fmt.Errorf("unsupported database url: %s", s.Engine)
219290
}
220-
out, _, err := prg.Eval(map[string]any{
221-
"env": c.Envmap,
222-
})
223-
if err != nil {
224-
return fmt.Errorf("expression error: %s", err)
225-
}
226-
dburl, ok := out.Value().(string)
227-
if !ok {
228-
return fmt.Errorf("expression returned non-string value: %v", out.Value())
229-
}
230-
fmt.Println("URL", dburl)
231-
conn, err := pgx.Connect(ctx, dburl)
232-
if err != nil {
233-
return fmt.Errorf("database: connection error: %s", err)
234-
}
235-
if err := conn.Ping(ctx); err != nil {
236-
return fmt.Errorf("database: connection error: %s", err)
237-
}
238-
defer conn.Close(ctx)
239-
pgconn = conn
240291
}
241292

242293
errored := false
243294
req := codeGenRequest(result, combo)
244295
cfg := vetConfig(req)
245296
for i, query := range req.Queries {
246297
original := result.Queries[i]
247-
if pgconn != nil && prepareable(sql, original.RawStmt) {
298+
if prep != nil && prepareable(s, original.RawStmt) {
248299
name := fmt.Sprintf("sqlc_vet_%d_%d", time.Now().Unix(), i)
249-
_, err := pgconn.Prepare(ctx, name, query.Text)
250-
if err != nil {
300+
if err := prep.Prepare(ctx, name, query.Text); err != nil {
251301
fmt.Fprintf(c.Stderr, "%s: error preparing %s: %s\n", query.Filename, query.Name, err)
252302
errored = true
253303
continue
254304
}
255305
}
256306
q := vetQuery(query)
257-
for _, name := range sql.Rules {
307+
for _, name := range s.Rules {
258308
prg, ok := c.Checks[name]
259309
if !ok {
260310
return fmt.Errorf("type-check error: a check with the name '%s' does not exist", name)

internal/endtoend/vet_test.go

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,16 @@ import (
1414
"github.com/kyleconroy/sqlc/internal/sqltest"
1515
)
1616

17-
func findSchema(t *testing.T, path string) string {
18-
t.Helper()
19-
schemaFile := filepath.Join(path, "postgresql", "schema.sql")
17+
func findSchema(t *testing.T, path string) (string, bool) {
18+
schemaFile := filepath.Join(path, "schema.sql")
2019
if _, err := os.Stat(schemaFile); !os.IsNotExist(err) {
21-
return schemaFile
20+
return schemaFile, true
2221
}
23-
schemaDir := filepath.Join(path, "postgresql", "schema")
22+
schemaDir := filepath.Join(path, "schema")
2423
if _, err := os.Stat(schemaDir); !os.IsNotExist(err) {
25-
return schemaDir
24+
return schemaDir, true
2625
}
27-
t.Fatalf("error: can't find schema files in %s", path)
28-
return ""
26+
return "", false
2927
}
3028

3129
func TestExamplesVet(t *testing.T) {
@@ -52,9 +50,16 @@ func TestExamplesVet(t *testing.T) {
5250
path := filepath.Join(examples, tc)
5351

5452
if tc != "kotlin" && tc != "python" {
55-
sqltest.CreatePostgreSQLDatabase(t, tc, []string{
56-
findSchema(t, path),
57-
})
53+
if s, found := findSchema(t, filepath.Join(path, "postgresql")); found {
54+
db, cleanup := sqltest.CreatePostgreSQLDatabase(t, tc, false, []string{s})
55+
defer db.Close()
56+
defer cleanup()
57+
}
58+
if s, found := findSchema(t, filepath.Join(path, "mysql")); found {
59+
db, cleanup := sqltest.CreateMySQLDatabase(t, tc, []string{s})
60+
defer db.Close()
61+
defer cleanup()
62+
}
5863
}
5964

6065
var stderr bytes.Buffer

internal/sqltest/mysql.go

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,12 @@ import (
1313
)
1414

1515
func MySQL(t *testing.T, migrations []string) (*sql.DB, func()) {
16+
// For each test, pick a new database name at random.
17+
name := "sqltest_mysql_" + id()
18+
return CreateMySQLDatabase(t, name, migrations)
19+
}
20+
21+
func CreateMySQLDatabase(t *testing.T, name string, migrations []string) (*sql.DB, func()) {
1622
t.Helper()
1723

1824
data := os.Getenv("MYSQL_DATABASE")
@@ -49,13 +55,11 @@ func MySQL(t *testing.T, migrations []string) (*sql.DB, func()) {
4955
t.Fatal(err)
5056
}
5157

52-
// For each test, pick a new database name at random.
53-
dbName := "sqltest_mysql_" + id()
54-
if _, err := db.Exec("CREATE DATABASE " + dbName); err != nil {
58+
if _, err := db.Exec("CREATE DATABASE " + name); err != nil {
5559
t.Fatal(err)
5660
}
5761

58-
source = fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?multiStatements=true&parseTime=true", user, pass, host, port, dbName)
62+
source = fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?multiStatements=true&parseTime=true", user, pass, host, port, name)
5963
sdb, err := sql.Open("mysql", source)
6064
if err != nil {
6165
t.Fatal(err)
@@ -77,7 +81,7 @@ func MySQL(t *testing.T, migrations []string) (*sql.DB, func()) {
7781

7882
return sdb, func() {
7983
// Drop the test db after test runs
80-
if _, err := db.Exec("DROP DATABASE " + dbName); err != nil {
84+
if _, err := db.Exec("DROP DATABASE " + name); err != nil {
8185
t.Fatal(err)
8286
}
8387
}

0 commit comments

Comments
 (0)