From f04fd446c44e67c353c27dd5486e23eaa3c47af8 Mon Sep 17 00:00:00 2001 From: Anuraag Agrawal Date: Tue, 26 Dec 2023 14:43:11 +0900 Subject: [PATCH 1/3] Only compile wasm once per process --- internal/ext/wasm/runner.go | 5 +++ internal/ext/wasm/wasm.go | 78 +++++++++++++++++++------------------ 2 files changed, 45 insertions(+), 38 deletions(-) diff --git a/internal/ext/wasm/runner.go b/internal/ext/wasm/runner.go index 352d827f9f..04e90a15d3 100644 --- a/internal/ext/wasm/runner.go +++ b/internal/ext/wasm/runner.go @@ -1,7 +1,12 @@ package wasm +import "github.com/tetratelabs/wazero" + type Runner struct { URL string SHA256 string Env []string + + rt wazero.Runtime + code wazero.CompiledModule } diff --git a/internal/ext/wasm/wasm.go b/internal/ext/wasm/wasm.go index eddd4eb9a9..c0e2451429 100644 --- a/internal/ext/wasm/wasm.go +++ b/internal/ext/wasm/wasm.go @@ -29,6 +29,15 @@ import ( "github.com/sqlc-dev/sqlc/internal/plugin" ) +func NewRunner(url string, checksum string, env []string) *Runner { + return &Runner{ + URL: url, + SHA256: checksum, + Env: env, + rt: wazero.NewRuntime(context.Background()), + } +} + func cacheDir() (string, error) { cache := os.Getenv("SQLCCACHE") if cache != "" { @@ -61,7 +70,7 @@ func (r *Runner) getChecksum(ctx context.Context) (string, error) { return sum, nil } -func (r *Runner) loadBytes(ctx context.Context) ([]byte, error) { +func (r *Runner) loadAndCompile(ctx context.Context) (wazero.CompiledModule, error) { expected, err := r.getChecksum(ctx) if err != nil { return nil, err @@ -71,14 +80,14 @@ func (r *Runner) loadBytes(ctx context.Context) ([]byte, error) { return nil, err } value, err, _ := flight.Do(expected, func() (interface{}, error) { - return r.loadWASM(ctx, cacheDir, expected) + return r.loadAndCompileWASM(ctx, cacheDir, expected) }) if err != nil { return nil, err } - data, ok := value.([]byte) + data, ok := value.(wazero.CompiledModule) if !ok { - return nil, fmt.Errorf("returned value was not a byte slice") + return nil, fmt.Errorf("returned value was not a compiled module") } return data, nil } @@ -124,7 +133,7 @@ func (r *Runner) fetch(ctx context.Context, uri string) ([]byte, string, error) return wmod, actual, nil } -func (r *Runner) loadWASM(ctx context.Context, cache string, expected string) ([]byte, error) { +func (r *Runner) loadAndCompileWASM(ctx context.Context, cache string, expected string) (wazero.CompiledModule, error) { pluginDir := filepath.Join(cache, expected) pluginPath := filepath.Join(pluginDir, "plugin.wasm") _, staterr := os.Stat(pluginPath) @@ -153,7 +162,22 @@ func (r *Runner) loadWASM(ctx context.Context, cache string, expected string) ([ } } - return wmod, nil + wazeroCache, err := wazero.NewCompilationCacheWithDir(filepath.Join(cache, "wazero")) + if err != nil { + return nil, fmt.Errorf("wazero.NewCompilationCacheWithDir: %w", err) + } + config := wazero.NewRuntimeConfig().WithCompilationCache(wazeroCache) + r.rt = wazero.NewRuntimeWithConfig(ctx, config) + // TODO: Handle error + wasi_snapshot_preview1.MustInstantiate(ctx, r.rt) + + // Compile the Wasm binary once so that we can skip the entire compilation time during instantiation. + code, err := r.rt.CompileModule(ctx, wmod) + if err != nil { + return nil, fmt.Errorf("compile module: %w", err) + } + + return code, nil } // removePGCatalog removes the pg_catalog schema from the request. There is a @@ -195,47 +219,25 @@ func (r *Runner) Invoke(ctx context.Context, method string, args any, reply any, return fmt.Errorf("failed to encode codegen request: %w", err) } - cacheDir, err := cache.PluginsDir() - if err != nil { - return err - } - - cache, err := wazero.NewCompilationCacheWithDir(filepath.Join(cacheDir, "wazero")) - if err != nil { - return err - } - - wasmBytes, err := r.loadBytes(ctx) - if err != nil { - return fmt.Errorf("loadModule: %w", err) - } - - config := wazero.NewRuntimeConfig().WithCompilationCache(cache) - rt := wazero.NewRuntimeWithConfig(ctx, config) - defer rt.Close(ctx) - - // TODO: Handle error - wasi_snapshot_preview1.MustInstantiate(ctx, rt) - - // Compile the Wasm binary once so that we can skip the entire compilation time during instantiation. - mod, err := rt.CompileModule(ctx, wasmBytes) + wasmCompiled, err := r.loadAndCompile(ctx) if err != nil { - return err + return fmt.Errorf("loadBytes: %w", err) } var stderr, stdout bytes.Buffer - conf := wazero.NewModuleConfig() - conf = conf.WithArgs("plugin.wasm", method) - conf = conf.WithEnv("SQLC_VERSION", info.Version) + conf := wazero.NewModuleConfig(). + WithName(""). + WithArgs("plugin.wasm", method). + WithStdin(bytes.NewReader(stdinBlob)). + WithStdout(&stdout). + WithStderr(&stderr). + WithEnv("SQLC_VERSION", info.Version) for _, key := range r.Env { conf = conf.WithEnv(key, os.Getenv(key)) } - conf = conf.WithStdin(bytes.NewReader(stdinBlob)) - conf = conf.WithStdout(&stdout) - conf = conf.WithStderr(&stderr) - result, err := rt.InstantiateModule(ctx, mod, conf) + result, err := r.rt.InstantiateModule(ctx, wasmCompiled, conf) if result != nil { defer result.Close(ctx) } From 1948ec963e179c8ce55e034bff4d0bcc1a61c990 Mon Sep 17 00:00:00 2001 From: Anuraag Agrawal Date: Tue, 26 Dec 2023 14:44:43 +0900 Subject: [PATCH 2/3] Remove unused --- internal/ext/wasm/runner.go | 3 +-- internal/ext/wasm/wasm.go | 25 ------------------------- 2 files changed, 1 insertion(+), 27 deletions(-) diff --git a/internal/ext/wasm/runner.go b/internal/ext/wasm/runner.go index 04e90a15d3..e03e88c536 100644 --- a/internal/ext/wasm/runner.go +++ b/internal/ext/wasm/runner.go @@ -7,6 +7,5 @@ type Runner struct { SHA256 string Env []string - rt wazero.Runtime - code wazero.CompiledModule + rt wazero.Runtime } diff --git a/internal/ext/wasm/wasm.go b/internal/ext/wasm/wasm.go index c0e2451429..eea1c79a44 100644 --- a/internal/ext/wasm/wasm.go +++ b/internal/ext/wasm/wasm.go @@ -29,31 +29,6 @@ import ( "github.com/sqlc-dev/sqlc/internal/plugin" ) -func NewRunner(url string, checksum string, env []string) *Runner { - return &Runner{ - URL: url, - SHA256: checksum, - Env: env, - rt: wazero.NewRuntime(context.Background()), - } -} - -func cacheDir() (string, error) { - cache := os.Getenv("SQLCCACHE") - if cache != "" { - return cache, nil - } - cacheHome := os.Getenv("XDG_CACHE_HOME") - if cacheHome == "" { - home, err := os.UserHomeDir() - if err != nil { - return "", err - } - cacheHome = filepath.Join(home, ".cache") - } - return filepath.Join(cacheHome, "sqlc"), nil -} - var flight singleflight.Group // Verify the provided sha256 is valid. From 3d71c2151db5884fba2a81833d5d951c65987c80 Mon Sep 17 00:00:00 2001 From: Anuraag Agrawal Date: Tue, 26 Dec 2023 14:51:54 +0900 Subject: [PATCH 3/3] Store runtime in flightgroup as well --- internal/ext/wasm/runner.go | 4 ---- internal/ext/wasm/wasm.go | 23 ++++++++++++++--------- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/internal/ext/wasm/runner.go b/internal/ext/wasm/runner.go index e03e88c536..352d827f9f 100644 --- a/internal/ext/wasm/runner.go +++ b/internal/ext/wasm/runner.go @@ -1,11 +1,7 @@ package wasm -import "github.com/tetratelabs/wazero" - type Runner struct { URL string SHA256 string Env []string - - rt wazero.Runtime } diff --git a/internal/ext/wasm/wasm.go b/internal/ext/wasm/wasm.go index eea1c79a44..5e838c81a7 100644 --- a/internal/ext/wasm/wasm.go +++ b/internal/ext/wasm/wasm.go @@ -31,6 +31,11 @@ import ( var flight singleflight.Group +type runtimeAndCode struct { + rt wazero.Runtime + code wazero.CompiledModule +} + // Verify the provided sha256 is valid. func (r *Runner) getChecksum(ctx context.Context) (string, error) { if r.SHA256 != "" { @@ -45,7 +50,7 @@ func (r *Runner) getChecksum(ctx context.Context) (string, error) { return sum, nil } -func (r *Runner) loadAndCompile(ctx context.Context) (wazero.CompiledModule, error) { +func (r *Runner) loadAndCompile(ctx context.Context) (*runtimeAndCode, error) { expected, err := r.getChecksum(ctx) if err != nil { return nil, err @@ -60,7 +65,7 @@ func (r *Runner) loadAndCompile(ctx context.Context) (wazero.CompiledModule, err if err != nil { return nil, err } - data, ok := value.(wazero.CompiledModule) + data, ok := value.(*runtimeAndCode) if !ok { return nil, fmt.Errorf("returned value was not a compiled module") } @@ -108,7 +113,7 @@ func (r *Runner) fetch(ctx context.Context, uri string) ([]byte, string, error) return wmod, actual, nil } -func (r *Runner) loadAndCompileWASM(ctx context.Context, cache string, expected string) (wazero.CompiledModule, error) { +func (r *Runner) loadAndCompileWASM(ctx context.Context, cache string, expected string) (*runtimeAndCode, error) { pluginDir := filepath.Join(cache, expected) pluginPath := filepath.Join(pluginDir, "plugin.wasm") _, staterr := os.Stat(pluginPath) @@ -142,17 +147,17 @@ func (r *Runner) loadAndCompileWASM(ctx context.Context, cache string, expected return nil, fmt.Errorf("wazero.NewCompilationCacheWithDir: %w", err) } config := wazero.NewRuntimeConfig().WithCompilationCache(wazeroCache) - r.rt = wazero.NewRuntimeWithConfig(ctx, config) + rt := wazero.NewRuntimeWithConfig(ctx, config) // TODO: Handle error - wasi_snapshot_preview1.MustInstantiate(ctx, r.rt) + wasi_snapshot_preview1.MustInstantiate(ctx, rt) // Compile the Wasm binary once so that we can skip the entire compilation time during instantiation. - code, err := r.rt.CompileModule(ctx, wmod) + code, err := rt.CompileModule(ctx, wmod) if err != nil { return nil, fmt.Errorf("compile module: %w", err) } - return code, nil + return &runtimeAndCode{rt: rt, code: code}, nil } // removePGCatalog removes the pg_catalog schema from the request. There is a @@ -194,7 +199,7 @@ func (r *Runner) Invoke(ctx context.Context, method string, args any, reply any, return fmt.Errorf("failed to encode codegen request: %w", err) } - wasmCompiled, err := r.loadAndCompile(ctx) + runtimeAndCode, err := r.loadAndCompile(ctx) if err != nil { return fmt.Errorf("loadBytes: %w", err) } @@ -212,7 +217,7 @@ func (r *Runner) Invoke(ctx context.Context, method string, args any, reply any, conf = conf.WithEnv(key, os.Getenv(key)) } - result, err := r.rt.InstantiateModule(ctx, wasmCompiled, conf) + result, err := runtimeAndCode.rt.InstantiateModule(ctx, runtimeAndCode.code, conf) if result != nil { defer result.Close(ctx) }