Skip to content

Commit b489f5a

Browse files
committed
feat(createdb): Create ephemeral databases
1 parent 68d27bf commit b489f5a

File tree

2 files changed

+124
-0
lines changed

2 files changed

+124
-0
lines changed

internal/cmd/cmd.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import (
2525
)
2626

2727
func init() {
28+
createDBCmd.Flags().StringP("env", "e", "DATABASE_URL", "environment variable to set (default: DATABASE_URL)")
2829
uploadCmd.Flags().BoolP("dry-run", "", false, "dump upload request (default: false)")
2930
initCmd.Flags().BoolP("v1", "", false, "generate v1 config yaml file")
3031
initCmd.Flags().BoolP("v2", "", true, "generate v2 config yaml file")
@@ -41,6 +42,7 @@ func Do(args []string, stdin io.Reader, stdout io.Writer, stderr io.Writer) int
4142
rootCmd.PersistentFlags().Bool("no-database", false, "disable database connections (default: false)")
4243

4344
rootCmd.AddCommand(checkCmd)
45+
rootCmd.AddCommand(createDBCmd)
4446
rootCmd.AddCommand(diffCmd)
4547
rootCmd.AddCommand(genCmd)
4648
rootCmd.AddCommand(initCmd)

internal/cmd/createdb.go

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
package cmd
2+
3+
import (
4+
"context"
5+
"errors"
6+
"fmt"
7+
"os"
8+
"os/exec"
9+
"runtime/trace"
10+
"strings"
11+
12+
"github.com/spf13/cobra"
13+
"github.com/sqlc-dev/sqlc/internal/config"
14+
"github.com/sqlc-dev/sqlc/internal/opts"
15+
"github.com/sqlc-dev/sqlc/internal/quickdb"
16+
pb "github.com/sqlc-dev/sqlc/internal/quickdb/v1"
17+
"github.com/sqlc-dev/sqlc/internal/sql/sqlpath"
18+
)
19+
20+
var createDBCmd = &cobra.Command{
21+
Use: "createdb",
22+
Short: "Create an ephemeral database",
23+
Args: cobra.MinimumNArgs(1),
24+
RunE: func(cmd *cobra.Command, args []string) error {
25+
defer trace.StartRegion(cmd.Context(), "createdb").End()
26+
stderr := cmd.ErrOrStderr()
27+
dir, name := getConfigPath(stderr, cmd.Flag("file"))
28+
env, err := cmd.Flags().GetString("env")
29+
if err != nil {
30+
return err
31+
}
32+
code, err := CreateDB(cmd.Context(), dir, name, args, env, &Options{
33+
Env: ParseEnv(cmd),
34+
Stderr: stderr,
35+
})
36+
if err != nil {
37+
fmt.Fprintln(stderr, err.Error())
38+
os.Exit(code)
39+
}
40+
return nil
41+
},
42+
}
43+
44+
func CreateDB(ctx context.Context, dir, filename string, args []string, env string, o *Options) (int, error) {
45+
dbg := opts.DebugFromEnv()
46+
if !dbg.ProcessPlugins {
47+
return 1, fmt.Errorf("process-plugins disabled")
48+
}
49+
_, conf, err := o.ReadConfig(dir, filename)
50+
if err != nil {
51+
return 1, err
52+
}
53+
// Find the first SQL with a managed database
54+
var pkg *config.SQL
55+
for _, sql := range conf.SQL {
56+
sql := sql
57+
if sql.Database != nil && sql.Database.Managed {
58+
pkg = &sql
59+
break
60+
}
61+
}
62+
if pkg == nil {
63+
return 1, fmt.Errorf("no managed database found")
64+
}
65+
if pkg.Engine != config.EnginePostgreSQL {
66+
return 1, fmt.Errorf("managed: only PostgreSQL currently")
67+
}
68+
69+
var migrations []string
70+
files, err := sqlpath.Glob(pkg.Schema)
71+
if err != nil {
72+
return 1, err
73+
}
74+
for _, schema := range files {
75+
contents, err := os.ReadFile(schema)
76+
if err != nil {
77+
return 1, fmt.Errorf("read file: %w", err)
78+
}
79+
migrations = append(migrations, string(contents))
80+
}
81+
client, err := quickdb.NewClientFromConfig(conf.Cloud)
82+
if err != nil {
83+
return 1, fmt.Errorf("client error: %w", err)
84+
}
85+
86+
resp, err := client.CreateEphemeralDatabase(ctx, &pb.CreateEphemeralDatabaseRequest{
87+
Engine: "postgresql",
88+
Region: quickdb.GetClosestRegion(),
89+
Migrations: migrations,
90+
})
91+
if err != nil {
92+
return 1, fmt.Errorf("managed: create database: %w", err)
93+
}
94+
95+
defer func() {
96+
client.DropEphemeralDatabase(ctx, &pb.DropEphemeralDatabaseRequest{
97+
DatabaseId: resp.DatabaseId,
98+
})
99+
}()
100+
101+
cmd := exec.Command(args[0], args[1:]...)
102+
cmd.Env = append(os.Environ(), fmt.Sprintf("%s=%s", env, resp.Uri))
103+
cmd.Stdout = os.Stdout
104+
cmd.Stderr = os.Stderr
105+
cmd.Env = []string{fmt.Sprintf("%s=%s", env, resp.Uri)}
106+
for _, val := range os.Environ() {
107+
if strings.HasPrefix(val, "SQLC_AUTH_TOKEN") {
108+
continue
109+
}
110+
cmd.Env = append(cmd.Env, val)
111+
}
112+
113+
if err := cmd.Run(); err != nil {
114+
var exitErr *exec.ExitError
115+
if errors.As(err, &exitErr) {
116+
return exitErr.ExitCode(), err
117+
}
118+
return 1, err
119+
}
120+
121+
return 0, nil
122+
}

0 commit comments

Comments
 (0)