From 76c698bedd581edfb2ea9754a23dd7e5a2be012b Mon Sep 17 00:00:00 2001 From: Kyle Conroy Date: Wed, 18 Jan 2023 10:22:42 -0800 Subject: [PATCH 1/2] compiler: Speed up generate --- go.mod | 1 + go.sum | 2 + internal/cmd/cmd.go | 3 +- internal/cmd/generate.go | 135 +++++++++++++++++++++++---------------- 4 files changed, 85 insertions(+), 56 deletions(-) diff --git a/go.mod b/go.mod index 801bf959eb..f79ba17993 100644 --- a/go.mod +++ b/go.mod @@ -17,6 +17,7 @@ require ( github.com/pganalyze/pg_query_go/v2 v2.2.0 github.com/spf13/cobra v1.6.1 github.com/spf13/pflag v1.0.5 + golang.org/x/sync v0.1.0 google.golang.org/protobuf v1.28.1 gopkg.in/yaml.v3 v3.0.1 ) diff --git a/go.sum b/go.sum index b815f4c2b8..37d5c0802e 100644 --- a/go.sum +++ b/go.sum @@ -212,6 +212,8 @@ golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96b golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= +golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= diff --git a/internal/cmd/cmd.go b/internal/cmd/cmd.go index 204da3212d..9376bb72db 100644 --- a/internal/cmd/cmd.go +++ b/internal/cmd/cmd.go @@ -44,6 +44,7 @@ func Do(args []string, stdin io.Reader, stdout io.Writer, stderr io.Writer) int rootCmd.SetIn(stdin) rootCmd.SetOut(stdout) rootCmd.SetErr(stderr) + rootCmd.SilenceErrors = true ctx := context.Background() if debug.Debug.Trace != "" { @@ -55,9 +56,7 @@ func Do(args []string, stdin io.Reader, stdout io.Writer, stderr io.Writer) int ctx = tracectx defer cleanup() } - if err := rootCmd.ExecuteContext(ctx); err != nil { - fmt.Fprintf(stderr, "%v\n", err) if exitError, ok := err.(*exec.ExitError); ok { return exitError.ExitCode() } else { diff --git a/internal/cmd/generate.go b/internal/cmd/generate.go index 450cb9a60e..a7eff7dffa 100644 --- a/internal/cmd/generate.go +++ b/internal/cmd/generate.go @@ -8,8 +8,12 @@ import ( "io" "os" "path/filepath" + "runtime" "runtime/trace" "strings" + "sync" + + "golang.org/x/sync/errgroup" "github.com/kyleconroy/sqlc/internal/codegen/golang" "github.com/kyleconroy/sqlc/internal/codegen/json" @@ -159,71 +163,94 @@ func Generate(ctx context.Context, e Env, dir, filename string, stderr io.Writer } } - for _, sql := range pairs { - combo := config.Combine(*conf, sql.SQL) - if sql.Plugin != nil { - combo.Codegen = *sql.Plugin - } + var m sync.Mutex + grp, gctx := errgroup.WithContext(ctx) + grp.SetLimit(runtime.GOMAXPROCS(0)) - // TODO: This feels like a hack that will bite us later - joined := make([]string, 0, len(sql.Schema)) - for _, s := range sql.Schema { - joined = append(joined, filepath.Join(dir, s)) - } - sql.Schema = joined + stderrs := make([]bytes.Buffer, len(pairs)) - joined = make([]string, 0, len(sql.Queries)) - for _, q := range sql.Queries { - joined = append(joined, filepath.Join(dir, q)) - } - sql.Queries = joined + for i, pair := range pairs { + sql := pair + errout := &stderrs[i] - var name, lang string - parseOpts := opts.Parser{ - Debug: debug.Debug, - } + grp.Go(func() error { + combo := config.Combine(*conf, sql.SQL) + if sql.Plugin != nil { + combo.Codegen = *sql.Plugin + } - switch { - case sql.Gen.Go != nil: - name = combo.Go.Package - lang = "golang" + // TODO: This feels like a hack that will bite us later + joined := make([]string, 0, len(sql.Schema)) + for _, s := range sql.Schema { + joined = append(joined, filepath.Join(dir, s)) + } + sql.Schema = joined - case sql.Plugin != nil: - lang = fmt.Sprintf("process:%s", sql.Plugin.Plugin) - name = sql.Plugin.Plugin - } + joined = make([]string, 0, len(sql.Queries)) + for _, q := range sql.Queries { + joined = append(joined, filepath.Join(dir, q)) + } + sql.Queries = joined - packageRegion := trace.StartRegion(ctx, "package") - trace.Logf(ctx, "", "name=%s dir=%s plugin=%s", name, dir, lang) + var name, lang string + parseOpts := opts.Parser{ + Debug: debug.Debug, + } - result, failed := parse(ctx, name, dir, sql.SQL, combo, parseOpts, stderr) - if failed { - packageRegion.End() - errored = true - break - } + switch { + case sql.Gen.Go != nil: + name = combo.Go.Package + lang = "golang" - out, resp, err := codegen(ctx, combo, sql, result) - if err != nil { - fmt.Fprintf(stderr, "# package %s\n", name) - fmt.Fprintf(stderr, "error generating code: %s\n", err) - errored = true - packageRegion.End() - continue - } + case sql.Plugin != nil: + lang = fmt.Sprintf("process:%s", sql.Plugin.Plugin) + name = sql.Plugin.Plugin + } - files := map[string]string{} - for _, file := range resp.Files { - files[file.Name] = string(file.Contents) - } - for n, source := range files { - filename := filepath.Join(dir, out, n) - output[filename] = source - } - packageRegion.End() - } + packageRegion := trace.StartRegion(gctx, "package") + trace.Logf(gctx, "", "name=%s dir=%s plugin=%s", name, dir, lang) + + result, failed := parse(gctx, name, dir, sql.SQL, combo, parseOpts, errout) + if failed { + packageRegion.End() + errored = true + return nil + } + + out, resp, err := codegen(gctx, combo, sql, result) + if err != nil { + fmt.Fprintf(errout, "# package %s\n", name) + fmt.Fprintf(errout, "error generating code: %s\n", err) + errored = true + packageRegion.End() + return nil + } + + files := map[string]string{} + for _, file := range resp.Files { + files[file.Name] = string(file.Contents) + } + m.Lock() + for n, source := range files { + filename := filepath.Join(dir, out, n) + output[filename] = source + } + m.Unlock() + + packageRegion.End() + return nil + }) + } + if err := grp.Wait(); err != nil { + return nil, err + } if errored { + for i, _ := range stderrs { + if _, err := io.Copy(stderr, &stderrs[i]); err != nil { + return nil, err + } + } return nil, fmt.Errorf("errored") } return output, nil From 20879738f20d579835c74f43c6d34cda474bc21a Mon Sep 17 00:00:00 2001 From: Kyle Conroy Date: Wed, 18 Jan 2023 16:16:35 -0800 Subject: [PATCH 2/2] wasm: Load serialized modules using singleflight --- internal/ext/wasm/wasm.go | 47 +++++++++++++++++++++++++++------------ 1 file changed, 33 insertions(+), 14 deletions(-) diff --git a/internal/ext/wasm/wasm.go b/internal/ext/wasm/wasm.go index 4b3314e64f..d75f879e33 100644 --- a/internal/ext/wasm/wasm.go +++ b/internal/ext/wasm/wasm.go @@ -20,6 +20,7 @@ import ( "strings" wasmtime "github.com/bytecodealliance/wasmtime-go/v3" + "golang.org/x/sync/singleflight" "github.com/kyleconroy/sqlc/internal/info" "github.com/kyleconroy/sqlc/internal/plugin" @@ -49,6 +50,8 @@ type Runner struct { SHA256 string } +var flight singleflight.Group + // Verify the provided sha256 is valid. func (r *Runner) parseChecksum() (string, error) { if r.SHA256 == "" { @@ -58,6 +61,24 @@ func (r *Runner) parseChecksum() (string, error) { } func (r *Runner) loadModule(ctx context.Context, engine *wasmtime.Engine) (*wasmtime.Module, error) { + expected, err := r.parseChecksum() + if err != nil { + return nil, err + } + value, err, _ := flight.Do(expected, func() (interface{}, error) { + return r.loadSerializedModule(ctx, engine) + }) + if err != nil { + return nil, err + } + data, ok := value.([]byte) + if !ok { + return nil, fmt.Errorf("returned value was not a byte slice") + } + return wasmtime.NewModuleDeserialize(engine, data) +} + +func (r *Runner) loadSerializedModule(ctx context.Context, engine *wasmtime.Engine) ([]byte, error) { expected, err := r.parseChecksum() if err != nil { return nil, err @@ -80,7 +101,7 @@ func (r *Runner) loadModule(ctx context.Context, engine *wasmtime.Engine) (*wasm if err != nil { return nil, err } - return wasmtime.NewModuleDeserialize(engine, data) + return data, nil } wmod, err := r.loadWASM(ctx, cache, expected) @@ -95,21 +116,19 @@ func (r *Runner) loadModule(ctx context.Context, engine *wasmtime.Engine) (*wasm return nil, fmt.Errorf("define wasi: %w", err) } - if staterr != nil { - err := os.Mkdir(pluginDir, 0755) - if err != nil && !os.IsExist(err) { - return nil, fmt.Errorf("mkdirall: %w", err) - } - out, err := module.Serialize() - if err != nil { - return nil, fmt.Errorf("serialize: %w", err) - } - if err := os.WriteFile(modPath, out, 0444); err != nil { - return nil, fmt.Errorf("cache wasm: %w", err) - } + err = os.Mkdir(pluginDir, 0755) + if err != nil && !os.IsExist(err) { + return nil, fmt.Errorf("mkdirall: %w", err) + } + out, err := module.Serialize() + if err != nil { + return nil, fmt.Errorf("serialize: %w", err) + } + if err := os.WriteFile(modPath, out, 0444); err != nil { + return nil, fmt.Errorf("cache wasm: %w", err) } - return module, nil + return out, nil } func (r *Runner) loadWASM(ctx context.Context, cache string, expected string) ([]byte, error) {