diff --git a/internal/ext/wasm/wasm.go b/internal/ext/wasm/wasm.go index eddd4eb9a9..5e838c81a7 100644 --- a/internal/ext/wasm/wasm.go +++ b/internal/ext/wasm/wasm.go @@ -29,24 +29,13 @@ import ( "github.com/sqlc-dev/sqlc/internal/plugin" ) -func cacheDir() (string, error) { - cache := os.Getenv("SQLCCACHE") - if cache != "" { - return cache, nil - } - cacheHome := os.Getenv("XDG_CACHE_HOME") - if cacheHome == "" { - home, err := os.UserHomeDir() - if err != nil { - return "", err - } - cacheHome = filepath.Join(home, ".cache") - } - return filepath.Join(cacheHome, "sqlc"), nil -} - var flight singleflight.Group +type runtimeAndCode struct { + rt wazero.Runtime + code wazero.CompiledModule +} + // Verify the provided sha256 is valid. func (r *Runner) getChecksum(ctx context.Context) (string, error) { if r.SHA256 != "" { @@ -61,7 +50,7 @@ func (r *Runner) getChecksum(ctx context.Context) (string, error) { return sum, nil } -func (r *Runner) loadBytes(ctx context.Context) ([]byte, error) { +func (r *Runner) loadAndCompile(ctx context.Context) (*runtimeAndCode, error) { expected, err := r.getChecksum(ctx) if err != nil { return nil, err @@ -71,14 +60,14 @@ func (r *Runner) loadBytes(ctx context.Context) ([]byte, error) { return nil, err } value, err, _ := flight.Do(expected, func() (interface{}, error) { - return r.loadWASM(ctx, cacheDir, expected) + return r.loadAndCompileWASM(ctx, cacheDir, expected) }) if err != nil { return nil, err } - data, ok := value.([]byte) + data, ok := value.(*runtimeAndCode) if !ok { - return nil, fmt.Errorf("returned value was not a byte slice") + return nil, fmt.Errorf("returned value was not a compiled module") } return data, nil } @@ -124,7 +113,7 @@ func (r *Runner) fetch(ctx context.Context, uri string) ([]byte, string, error) return wmod, actual, nil } -func (r *Runner) loadWASM(ctx context.Context, cache string, expected string) ([]byte, error) { +func (r *Runner) loadAndCompileWASM(ctx context.Context, cache string, expected string) (*runtimeAndCode, error) { pluginDir := filepath.Join(cache, expected) pluginPath := filepath.Join(pluginDir, "plugin.wasm") _, staterr := os.Stat(pluginPath) @@ -153,7 +142,22 @@ func (r *Runner) loadWASM(ctx context.Context, cache string, expected string) ([ } } - return wmod, nil + wazeroCache, err := wazero.NewCompilationCacheWithDir(filepath.Join(cache, "wazero")) + if err != nil { + return nil, fmt.Errorf("wazero.NewCompilationCacheWithDir: %w", err) + } + config := wazero.NewRuntimeConfig().WithCompilationCache(wazeroCache) + rt := wazero.NewRuntimeWithConfig(ctx, config) + // TODO: Handle error + wasi_snapshot_preview1.MustInstantiate(ctx, rt) + + // Compile the Wasm binary once so that we can skip the entire compilation time during instantiation. + code, err := rt.CompileModule(ctx, wmod) + if err != nil { + return nil, fmt.Errorf("compile module: %w", err) + } + + return &runtimeAndCode{rt: rt, code: code}, nil } // removePGCatalog removes the pg_catalog schema from the request. There is a @@ -195,47 +199,25 @@ func (r *Runner) Invoke(ctx context.Context, method string, args any, reply any, return fmt.Errorf("failed to encode codegen request: %w", err) } - cacheDir, err := cache.PluginsDir() + runtimeAndCode, err := r.loadAndCompile(ctx) if err != nil { - return err - } - - cache, err := wazero.NewCompilationCacheWithDir(filepath.Join(cacheDir, "wazero")) - if err != nil { - return err - } - - wasmBytes, err := r.loadBytes(ctx) - if err != nil { - return fmt.Errorf("loadModule: %w", err) - } - - config := wazero.NewRuntimeConfig().WithCompilationCache(cache) - rt := wazero.NewRuntimeWithConfig(ctx, config) - defer rt.Close(ctx) - - // TODO: Handle error - wasi_snapshot_preview1.MustInstantiate(ctx, rt) - - // Compile the Wasm binary once so that we can skip the entire compilation time during instantiation. - mod, err := rt.CompileModule(ctx, wasmBytes) - if err != nil { - return err + return fmt.Errorf("loadBytes: %w", err) } var stderr, stdout bytes.Buffer - conf := wazero.NewModuleConfig() - conf = conf.WithArgs("plugin.wasm", method) - conf = conf.WithEnv("SQLC_VERSION", info.Version) + conf := wazero.NewModuleConfig(). + WithName(""). + WithArgs("plugin.wasm", method). + WithStdin(bytes.NewReader(stdinBlob)). + WithStdout(&stdout). + WithStderr(&stderr). + WithEnv("SQLC_VERSION", info.Version) for _, key := range r.Env { conf = conf.WithEnv(key, os.Getenv(key)) } - conf = conf.WithStdin(bytes.NewReader(stdinBlob)) - conf = conf.WithStdout(&stdout) - conf = conf.WithStderr(&stderr) - result, err := rt.InstantiateModule(ctx, mod, conf) + result, err := runtimeAndCode.rt.InstantiateModule(ctx, runtimeAndCode.code, conf) if result != nil { defer result.Close(ctx) }