diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e37c5cc870..55f46a6cbf 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -31,20 +31,21 @@ jobs: services: postgres: - image: postgres:11 + image: "postgres:15" env: - POSTGRES_USER: postgres - POSTGRES_PASSWORD: postgres POSTGRES_DB: postgres + POSTGRES_PASSWORD: postgres + POSTGRES_USER: postgres ports: - 5432:5432 # needed because the postgres container does not provide a healthcheck options: --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5 mysql: - image: mysql:8 + image: "mysql/mysql-server:8.0" env: - MYSQL_ROOT_PASSWORD: mysecretpassword MYSQL_DATABASE: mysql + MYSQL_ROOT_HOST: '%' + MYSQL_ROOT_PASSWORD: mysecretpassword ports: - 3306:3306 diff --git a/docker-compose.yml b/docker-compose.yml index 606c70b441..9579b04e0f 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -6,16 +6,16 @@ services: - "3306:3306" restart: always environment: - MYSQL_DATABASE: dinotest + MYSQL_DATABASE: mysql MYSQL_ROOT_PASSWORD: mysecretpassword MYSQL_ROOT_HOST: '%' postgresql: - image: "postgres:13" + image: "postgres:15" ports: - "5432:5432" restart: always environment: - POSTGRES_DB: dinotest + POSTGRES_DB: postgres POSTGRES_PASSWORD: mysecretpassword POSTGRES_USER: postgres diff --git a/examples/authors/sqlc.json b/examples/authors/sqlc.json index 558ea97cbe..50627fe8a5 100644 --- a/examples/authors/sqlc.json +++ b/examples/authors/sqlc.json @@ -19,6 +19,9 @@ "schema": "mysql/schema.sql", "queries": "mysql/query.sql", "engine": "mysql", + "database": { + "url": "'root:%s@tcp(%s:%s)/authors?multiStatements=true&parseTime=true'.format([env.MYSQL_ROOT_PASSWORD, env.MYSQL_HOST, env.MYSQL_PORT])" + }, "gen": { "go": { "package": "authors", diff --git a/examples/booktest/sqlc.json b/examples/booktest/sqlc.json index c0176d1f23..16fa552026 100644 --- a/examples/booktest/sqlc.json +++ b/examples/booktest/sqlc.json @@ -16,7 +16,10 @@ "path": "mysql", "schema": "mysql/schema.sql", "queries": "mysql/query.sql", - "engine": "mysql" + "engine": "mysql", + "database": { + "url": "'root:%s@tcp(%s:%s)/booktest?multiStatements=true&parseTime=true'.format([env.MYSQL_ROOT_PASSWORD, env.MYSQL_HOST, env.MYSQL_PORT])" + } }, { "name": "booktest", diff --git a/examples/ondeck/sqlc.json b/examples/ondeck/sqlc.json index f3ae36698c..d4fd765024 100644 --- a/examples/ondeck/sqlc.json +++ b/examples/ondeck/sqlc.json @@ -20,6 +20,9 @@ "schema": "mysql/schema", "queries": "mysql/query", "engine": "mysql", + "database": { + "url": "'root:%s@tcp(%s:%s)/ondeck?multiStatements=true&parseTime=true'.format([env.MYSQL_ROOT_PASSWORD, env.MYSQL_HOST, env.MYSQL_PORT])" + }, "emit_json_tags": true, "emit_prepared_queries": true, "emit_interface": true diff --git a/internal/cmd/vet.go b/internal/cmd/vet.go index f034b057aa..bbdc59c95d 100644 --- a/internal/cmd/vet.go +++ b/internal/cmd/vet.go @@ -2,6 +2,7 @@ package cmd import ( "context" + "database/sql" "errors" "fmt" "io" @@ -11,6 +12,7 @@ import ( "strings" "time" + _ "github.com/go-sql-driver/mysql" "github.com/google/cel-go/cel" "github.com/google/cel-go/ext" "github.com/jackc/pgx/v5" @@ -140,17 +142,6 @@ func Vet(ctx context.Context, e Env, dir, filename string, stderr io.Writer) err return nil } -type checker struct { - Checks map[string]cel.Program - Conf *config.Config - Dbenv *cel.Env - Dir string - Env *cel.Env - Envmap map[string]string - Msgs map[string]string - Stderr io.Writer -} - // Determine if a query can be prepared based on the engine and the statement // type. func prepareable(sql config.SQL, raw *ast.RawStmt) bool { @@ -169,74 +160,134 @@ func prepareable(sql config.SQL, raw *ast.RawStmt) bool { return false } } + // Almost all statements in MySQL can be prepared, so I'm just going to assume they can be + // https://dev.mysql.com/doc/refman/8.0/en/sql-prepared-statements.html + if sql.Engine == config.EngineMySQL { + return true + } return false } -func (c *checker) checkSQL(ctx context.Context, sql config.SQL) error { +type preparer interface { + Prepare(context.Context, string, string) error +} + +type pgxPreparer struct { + c *pgx.Conn +} + +func (p *pgxPreparer) Prepare(ctx context.Context, name, query string) error { + _, err := p.c.Prepare(ctx, name, query) + return err +} + +type dbPreparer struct { + db *sql.DB +} + +func (p *dbPreparer) Prepare(ctx context.Context, name, query string) error { + _, err := p.db.PrepareContext(ctx, query) + return err +} + +type checker struct { + Checks map[string]cel.Program + Conf *config.Config + Dbenv *cel.Env + Dir string + Env *cel.Env + Envmap map[string]string + Msgs map[string]string + Stderr io.Writer +} + +func (c *checker) DSN(expr string) (string, error) { + ast, issues := c.Dbenv.Compile(expr) + if issues != nil && issues.Err() != nil { + return "", fmt.Errorf("type-check error: database url %s", issues.Err()) + } + prg, err := c.Dbenv.Program(ast) + if err != nil { + return "", fmt.Errorf("program construction error: database url %s", err) + } + // Populate the environment variable map if it is empty + if len(c.Envmap) == 0 { + for _, e := range os.Environ() { + k, v, _ := strings.Cut(e, "=") + c.Envmap[k] = v + } + } + out, _, err := prg.Eval(map[string]any{ + "env": c.Envmap, + }) + if err != nil { + return "", fmt.Errorf("expression error: %s", err) + } + dsn, ok := out.Value().(string) + if !ok { + return "", fmt.Errorf("expression returned non-string value: %v", out.Value()) + } + return dsn, nil +} + +func (c *checker) checkSQL(ctx context.Context, s config.SQL) error { // TODO: Create a separate function for this logic so we can - combo := config.Combine(*c.Conf, sql) + combo := config.Combine(*c.Conf, s) // TODO: This feels like a hack that will bite us later - joined := make([]string, 0, len(sql.Schema)) - for _, s := range sql.Schema { + joined := make([]string, 0, len(s.Schema)) + for _, s := range s.Schema { joined = append(joined, filepath.Join(c.Dir, s)) } - sql.Schema = joined + s.Schema = joined - joined = make([]string, 0, len(sql.Queries)) - for _, q := range sql.Queries { + joined = make([]string, 0, len(s.Queries)) + for _, q := range s.Queries { joined = append(joined, filepath.Join(c.Dir, q)) } - sql.Queries = joined + s.Queries = joined var name string parseOpts := opts.Parser{ Debug: debug.Debug, } - result, failed := parse(ctx, name, c.Dir, sql, combo, parseOpts, c.Stderr) + result, failed := parse(ctx, name, c.Dir, s, combo, parseOpts, c.Stderr) if failed { return ErrFailedChecks } // TODO: Add MySQL support - var pgconn *pgx.Conn - if sql.Engine == config.EnginePostgreSQL && sql.Database != nil { - ast, issues := c.Dbenv.Compile(sql.Database.URL) - if issues != nil && issues.Err() != nil { - return fmt.Errorf("type-check error: database url %s", issues.Err()) - } - prg, err := c.Dbenv.Program(ast) + var prep preparer + if s.Database != nil { + dburl, err := c.DSN(s.Database.URL) if err != nil { - return fmt.Errorf("program construction error: database url %s", err) + return err } - // Populate the environment variable map if it is empty - if len(c.Envmap) == 0 { - for _, e := range os.Environ() { - k, v, _ := strings.Cut(e, "=") - c.Envmap[k] = v + switch s.Engine { + case config.EnginePostgreSQL: + conn, err := pgx.Connect(ctx, dburl) + if err != nil { + return fmt.Errorf("database: connection error: %s", err) } + if err := conn.Ping(ctx); err != nil { + return fmt.Errorf("database: connection error: %s", err) + } + defer conn.Close(ctx) + prep = &pgxPreparer{conn} + case config.EngineMySQL: + db, err := sql.Open("mysql", dburl) + if err != nil { + return fmt.Errorf("database: connection error: %s", err) + } + if err := db.PingContext(ctx); err != nil { + return fmt.Errorf("database: connection error: %s", err) + } + defer db.Close() + prep = &dbPreparer{db} + default: + return fmt.Errorf("unsupported database url: %s", s.Engine) } - out, _, err := prg.Eval(map[string]any{ - "env": c.Envmap, - }) - if err != nil { - return fmt.Errorf("expression error: %s", err) - } - dburl, ok := out.Value().(string) - if !ok { - return fmt.Errorf("expression returned non-string value: %v", out.Value()) - } - fmt.Println("URL", dburl) - conn, err := pgx.Connect(ctx, dburl) - if err != nil { - return fmt.Errorf("database: connection error: %s", err) - } - if err := conn.Ping(ctx); err != nil { - return fmt.Errorf("database: connection error: %s", err) - } - defer conn.Close(ctx) - pgconn = conn } errored := false @@ -244,17 +295,16 @@ func (c *checker) checkSQL(ctx context.Context, sql config.SQL) error { cfg := vetConfig(req) for i, query := range req.Queries { original := result.Queries[i] - if pgconn != nil && prepareable(sql, original.RawStmt) { + if prep != nil && prepareable(s, original.RawStmt) { name := fmt.Sprintf("sqlc_vet_%d_%d", time.Now().Unix(), i) - _, err := pgconn.Prepare(ctx, name, query.Text) - if err != nil { + if err := prep.Prepare(ctx, name, query.Text); err != nil { fmt.Fprintf(c.Stderr, "%s: error preparing %s: %s\n", query.Filename, query.Name, err) errored = true continue } } q := vetQuery(query) - for _, name := range sql.Rules { + for _, name := range s.Rules { prg, ok := c.Checks[name] if !ok { return fmt.Errorf("type-check error: a check with the name '%s' does not exist", name) diff --git a/internal/endtoend/vet_test.go b/internal/endtoend/vet_test.go index 94108107dc..50f62512a9 100644 --- a/internal/endtoend/vet_test.go +++ b/internal/endtoend/vet_test.go @@ -14,18 +14,16 @@ import ( "github.com/kyleconroy/sqlc/internal/sqltest" ) -func findSchema(t *testing.T, path string) string { - t.Helper() - schemaFile := filepath.Join(path, "postgresql", "schema.sql") +func findSchema(t *testing.T, path string) (string, bool) { + schemaFile := filepath.Join(path, "schema.sql") if _, err := os.Stat(schemaFile); !os.IsNotExist(err) { - return schemaFile + return schemaFile, true } - schemaDir := filepath.Join(path, "postgresql", "schema") + schemaDir := filepath.Join(path, "schema") if _, err := os.Stat(schemaDir); !os.IsNotExist(err) { - return schemaDir + return schemaDir, true } - t.Fatalf("error: can't find schema files in %s", path) - return "" + return "", false } func TestExamplesVet(t *testing.T) { @@ -52,9 +50,16 @@ func TestExamplesVet(t *testing.T) { path := filepath.Join(examples, tc) if tc != "kotlin" && tc != "python" { - sqltest.CreatePostgreSQLDatabase(t, tc, []string{ - findSchema(t, path), - }) + if s, found := findSchema(t, filepath.Join(path, "postgresql")); found { + db, cleanup := sqltest.CreatePostgreSQLDatabase(t, tc, false, []string{s}) + defer db.Close() + defer cleanup() + } + if s, found := findSchema(t, filepath.Join(path, "mysql")); found { + db, cleanup := sqltest.CreateMySQLDatabase(t, tc, []string{s}) + defer db.Close() + defer cleanup() + } } var stderr bytes.Buffer diff --git a/internal/sqltest/mysql.go b/internal/sqltest/mysql.go index 4e64dc5eb6..2cbda0b0d2 100644 --- a/internal/sqltest/mysql.go +++ b/internal/sqltest/mysql.go @@ -13,6 +13,12 @@ import ( ) func MySQL(t *testing.T, migrations []string) (*sql.DB, func()) { + // For each test, pick a new database name at random. + name := "sqltest_mysql_" + id() + return CreateMySQLDatabase(t, name, migrations) +} + +func CreateMySQLDatabase(t *testing.T, name string, migrations []string) (*sql.DB, func()) { t.Helper() data := os.Getenv("MYSQL_DATABASE") @@ -49,13 +55,11 @@ func MySQL(t *testing.T, migrations []string) (*sql.DB, func()) { t.Fatal(err) } - // For each test, pick a new database name at random. - dbName := "sqltest_mysql_" + id() - if _, err := db.Exec("CREATE DATABASE " + dbName); err != nil { + if _, err := db.Exec("CREATE DATABASE " + name); err != nil { t.Fatal(err) } - source = fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?multiStatements=true&parseTime=true", user, pass, host, port, dbName) + source = fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?multiStatements=true&parseTime=true", user, pass, host, port, name) sdb, err := sql.Open("mysql", source) if err != nil { t.Fatal(err) @@ -77,7 +81,7 @@ func MySQL(t *testing.T, migrations []string) (*sql.DB, func()) { return sdb, func() { // Drop the test db after test runs - if _, err := db.Exec("DROP DATABASE " + dbName); err != nil { + if _, err := db.Exec("DROP DATABASE " + name); err != nil { t.Fatal(err) } } diff --git a/internal/sqltest/postgres.go b/internal/sqltest/postgres.go index ad58e2c019..682585daa4 100644 --- a/internal/sqltest/postgres.go +++ b/internal/sqltest/postgres.go @@ -28,9 +28,13 @@ func id() string { return string(b) } -// Disable random new schema -// Override database name -func CreatePostgreSQLDatabase(t *testing.T, newDB string, migrations []string) *sql.DB { +func PostgreSQL(t *testing.T, migrations []string) (*sql.DB, func()) { + // For each test, pick a new schema name at random. + schema := "sqltest_postgresql_" + id() + return CreatePostgreSQLDatabase(t, schema, true, migrations) +} + +func CreatePostgreSQLDatabase(t *testing.T, name string, schema bool, migrations []string) (*sql.DB, func()) { t.Helper() pgUser := os.Getenv("PG_USER") @@ -66,92 +70,24 @@ func CreatePostgreSQLDatabase(t *testing.T, newDB string, migrations []string) * if err != nil { t.Fatal(err) } - defer db.Close() - - var exists bool - dberr := db.QueryRow(`SELECT true FROM pg_database WHERE datname = $1`, newDB).Scan(&exists) - if dberr != nil && dberr != sql.ErrNoRows { - t.Fatal(err) - } - if !exists { - if _, err := db.Exec("CREATE DATABASE " + newDB); err != nil { + // For each test, pick a new schema name at random. + var newsource, dropQuery string + if schema { + if _, err := db.Exec("CREATE SCHEMA " + name); err != nil { t.Fatal(err) } + newsource = source + "&search_path=" + name + dropQuery = "DROP SCHEMA " + name + " CASCADE" } else { - t.Logf("database '%s' exists, not creating", newDB) - } - - newSource := fmt.Sprintf("postgres://%s:%s@%s:%s/%s?sslmode=disable", pgUser, pgPass, pgHost, pgPort, newDB) - t.Logf("newdb: %s", newSource) - - sdb, err := sql.Open("postgres", newSource) - if err != nil { - t.Fatal(err) - } - - if !exists { - files, err := sqlpath.Glob(migrations) - if err != nil { + if _, err := db.Exec("CREATE DATABASE " + name); err != nil { t.Fatal(err) } - for _, f := range files { - blob, err := os.ReadFile(f) - if err != nil { - t.Fatal(err) - } - if _, err := sdb.Exec(string(blob)); err != nil { - t.Fatalf("%s: %s", filepath.Base(f), err) - } - } - } - return sdb -} - -func PostgreSQL(t *testing.T, migrations []string) (*sql.DB, func()) { - t.Helper() - - pgUser := os.Getenv("PG_USER") - pgHost := os.Getenv("PG_HOST") - pgPort := os.Getenv("PG_PORT") - pgPass := os.Getenv("PG_PASSWORD") - pgDB := os.Getenv("PG_DATABASE") - - if pgUser == "" { - pgUser = "postgres" - } - - if pgPass == "" { - pgPass = "mysecretpassword" - } - - if pgPort == "" { - pgPort = "5432" - } - - if pgHost == "" { - pgHost = "127.0.0.1" - } - - if pgDB == "" { - pgDB = "dinotest" - } - - source := fmt.Sprintf("postgres://%s:%s@%s:%s/%s?sslmode=disable", pgUser, pgPass, pgHost, pgPort, pgDB) - t.Logf("db: %s", source) - - db, err := sql.Open("postgres", source) - if err != nil { - t.Fatal(err) - } - - // For each test, pick a new schema name at random. - schema := "sqltest_postgresql_" + id() - if _, err := db.Exec("CREATE SCHEMA " + schema); err != nil { - t.Fatal(err) + newsource = fmt.Sprintf("postgres://%s:%s@%s:%s/%s?sslmode=disable", pgUser, pgPass, pgHost, pgPort, name) + dropQuery = "DROP DATABASE IF EXISTS " + name + " WITH (FORCE)" } - sdb, err := sql.Open("postgres", source+"&search_path="+schema) + sdb, err := sql.Open("postgres", newsource) if err != nil { t.Fatal(err) } @@ -171,8 +107,9 @@ func PostgreSQL(t *testing.T, migrations []string) (*sql.DB, func()) { } return sdb, func() { - if _, err := db.Exec("DROP SCHEMA " + schema + " CASCADE"); err != nil { + if _, err := db.Exec(dropQuery); err != nil { t.Fatal(err) } + db.Close() } }