From 5e3d938a3faba144b6de71eb8c446bef9290d1d9 Mon Sep 17 00:00:00 2001 From: Kyle Conroy Date: Tue, 5 Dec 2023 15:03:45 -0800 Subject: [PATCH 1/7] feat(plugins): Use wazero instead of wasmtime --- go.mod | 2 +- go.sum | 4 +- internal/ext/wasm/wasm.go | 149 ++++++++++++-------------------------- 3 files changed, 49 insertions(+), 106 deletions(-) diff --git a/go.mod b/go.mod index 2ac600c4a7..5f5eb30a62 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,6 @@ go 1.21 require ( github.com/antlr/antlr4/runtime/Go/antlr/v4 v4.0.0-20230321174746-8dcc6526cfb1 - github.com/bytecodealliance/wasmtime-go/v14 v14.0.0 github.com/cubicdaiya/gonp v1.0.4 github.com/davecgh/go-spew v1.1.1 github.com/fatih/structtag v1.2.0 @@ -20,6 +19,7 @@ require ( github.com/riza-io/grpc-go v0.2.0 github.com/spf13/cobra v1.8.0 github.com/spf13/pflag v1.0.5 + github.com/tetratelabs/wazero v1.5.0 github.com/wasilibs/go-pgquery v0.0.0-20231205013331-96e794bb074e github.com/xeipuuv/gojsonschema v1.2.0 golang.org/x/sync v0.5.0 diff --git a/go.sum b/go.sum index f3329e8f94..13fb24260b 100644 --- a/go.sum +++ b/go.sum @@ -5,8 +5,6 @@ github.com/antlr/antlr4/runtime/Go/antlr/v4 v4.0.0-20230321174746-8dcc6526cfb1/g github.com/antlr4-go/antlr/v4 v4.13.0 h1:lxCg3LAv+EUK6t1i0y1V6/SLeUi0eKEKdhQAlS8TVTI= github.com/antlr4-go/antlr/v4 v4.13.0/go.mod h1:pfChB/xh/Unjila75QW7+VU4TSnWnnk9UTnmpPaOR2g= github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= -github.com/bytecodealliance/wasmtime-go/v14 v14.0.0 h1:ur7S3P+PAeJmgllhSrKnGQOAmmtUbLQxb/nw2NZiaEM= -github.com/bytecodealliance/wasmtime-go/v14 v14.0.0/go.mod h1:tqOVEUjnXY6aGpSfM9qdVRR6G//Yc513fFYUdzZb/DY= github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I= github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= @@ -185,6 +183,8 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/tetratelabs/wazero v1.5.0 h1:Yz3fZHivfDiZFUXnWMPUoiW7s8tC1sjdBtlJn08qYa0= +github.com/tetratelabs/wazero v1.5.0/go.mod h1:0U0G41+ochRKoPKCJlh0jMg1CHkyfK8kDqiirMmKY8A= github.com/wasilibs/go-pgquery v0.0.0-20231205013331-96e794bb074e h1:sGIC6/D0KqpA+qBSDSVDQswU/IJVYkbnUXnipgTLQWk= github.com/wasilibs/go-pgquery v0.0.0-20231205013331-96e794bb074e/go.mod h1:KW0azBSWqkPZ71r+3O4qt8h6A/NisFLp0rbjZ3py4OE= github.com/wasilibs/wazerox v0.0.0-20231117065139-b3503f4aeff6 h1:jwbU8u5TuXModzdEG4wI0g4FyuD7ROSttU86go5sPdU= diff --git a/internal/ext/wasm/wasm.go b/internal/ext/wasm/wasm.go index c096ec9844..fe42c33b0e 100644 --- a/internal/ext/wasm/wasm.go +++ b/internal/ext/wasm/wasm.go @@ -5,6 +5,7 @@ package wasm import ( + "bytes" "context" "crypto/sha256" "errors" @@ -15,10 +16,11 @@ import ( "os" "path/filepath" "runtime" - "runtime/trace" "strings" - wasmtime "github.com/bytecodealliance/wasmtime-go/v14" + "github.com/tetratelabs/wazero" + "github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1" + "github.com/tetratelabs/wazero/sys" "golang.org/x/sync/singleflight" "google.golang.org/grpc" "google.golang.org/grpc/codes" @@ -70,13 +72,17 @@ func (r *Runner) getChecksum(ctx context.Context) (string, error) { return sum, nil } -func (r *Runner) loadModule(ctx context.Context, engine *wasmtime.Engine) (*wasmtime.Module, error) { +func (r *Runner) loadBytes(ctx context.Context) ([]byte, error) { expected, err := r.getChecksum(ctx) if err != nil { return nil, err } + cacheDir, err := cache.PluginsDir() + if err != nil { + return nil, err + } value, err, _ := flight.Do(expected, func() (interface{}, error) { - return r.loadSerializedModule(ctx, engine, expected) + return r.loadWASM(ctx, cacheDir, expected) }) if err != nil { return nil, err @@ -85,52 +91,7 @@ func (r *Runner) loadModule(ctx context.Context, engine *wasmtime.Engine) (*wasm if !ok { return nil, fmt.Errorf("returned value was not a byte slice") } - return wasmtime.NewModuleDeserialize(engine, data) -} - -func (r *Runner) loadSerializedModule(ctx context.Context, engine *wasmtime.Engine, expectedSha string) ([]byte, error) { - cacheDir, err := cache.PluginsDir() - if err != nil { - return nil, err - } - - pluginDir := filepath.Join(cacheDir, expectedSha) - modName := fmt.Sprintf("plugin_%s_%s_%s.module", runtime.GOOS, runtime.GOARCH, wasmtimeVersion) - modPath := filepath.Join(pluginDir, modName) - _, staterr := os.Stat(modPath) - if staterr == nil { - data, err := os.ReadFile(modPath) - if err != nil { - return nil, err - } - return data, nil - } - - wmod, err := r.loadWASM(ctx, cacheDir, expectedSha) - if err != nil { - return nil, err - } - - moduRegion := trace.StartRegion(ctx, "wasmtime.NewModule") - module, err := wasmtime.NewModule(engine, wmod) - moduRegion.End() - if err != nil { - return nil, fmt.Errorf("define wasi: %w", err) - } - - err = os.Mkdir(pluginDir, 0755) - if err != nil && !os.IsExist(err) { - return nil, fmt.Errorf("mkdirall: %w", err) - } - out, err := module.Serialize() - if err != nil { - return nil, fmt.Errorf("serialize: %w", err) - } - if err := os.WriteFile(modPath, out, 0444); err != nil { - return nil, fmt.Errorf("cache wasm: %w", err) - } - - return out, nil + return data, nil } func (r *Runner) fetch(ctx context.Context, uri string) ([]byte, string, error) { @@ -245,72 +206,56 @@ func (r *Runner) Invoke(ctx context.Context, method string, args any, reply any, return fmt.Errorf("failed to encode codegen request: %w", err) } - engine := wasmtime.NewEngine() - module, err := r.loadModule(ctx, engine) + cacheDir, err := cache.PluginsDir() if err != nil { - return fmt.Errorf("loadModule: %w", err) + return err } - linker := wasmtime.NewLinker(engine) - if err := linker.DefineWasi(); err != nil { + cache, err := wazero.NewCompilationCacheWithDir(filepath.Join(cacheDir, "wazero")) + if err != nil { return err } - dir, err := os.MkdirTemp(os.Getenv("SQLCTMPDIR"), "out") + wasmBytes, err := r.loadBytes(ctx) if err != nil { - return fmt.Errorf("temp dir: %w", err) + return fmt.Errorf("loadModule: %w", err) } - defer os.RemoveAll(dir) - stdinPath := filepath.Join(dir, "stdin") - stderrPath := filepath.Join(dir, "stderr") - stdoutPath := filepath.Join(dir, "stdout") + config := wazero.NewRuntimeConfig().WithCompilationCache(cache) + rt := wazero.NewRuntimeWithConfig(ctx, config) + defer rt.Close(ctx) - if err := os.WriteFile(stdinPath, stdinBlob, 0755); err != nil { - return fmt.Errorf("write file: %w", err) - } - - // Configure WASI imports to write stdout into a file. - wasiConfig := wasmtime.NewWasiConfig() - wasiConfig.SetArgv([]string{"plugin.wasm", method}) - wasiConfig.SetStdinFile(stdinPath) - wasiConfig.SetStdoutFile(stdoutPath) - wasiConfig.SetStderrFile(stderrPath) + // TODO: Handle error + wasi_snapshot_preview1.MustInstantiate(ctx, rt) - keys := []string{"SQLC_VERSION"} - vals := []string{info.Version} - for _, key := range r.Env { - keys = append(keys, key) - vals = append(vals, os.Getenv(key)) + // Compile the Wasm binary once so that we can skip the entire compilation time during instantiation. + mod, err := rt.CompileModule(ctx, wasmBytes) + if err != nil { + return err } - wasiConfig.SetEnv(keys, vals) - store := wasmtime.NewStore(engine) - store.SetWasi(wasiConfig) + var stderr, stdout bytes.Buffer - linkRegion := trace.StartRegion(ctx, "linker.DefineModule") - err = linker.DefineModule(store, "", module) - linkRegion.End() - if err != nil { - return fmt.Errorf("define wasi: %w", err) + conf := wazero.NewModuleConfig() + conf = conf.WithArgs("plugin.wasm", method) + conf = conf.WithEnv("SQLC_VERSION", info.Version) + for _, key := range r.Env { + conf = conf.WithEnv(key, os.Getenv(key)) } + conf = conf.WithStdin(bytes.NewReader(stdinBlob)) + conf = conf.WithStdout(&stdout) + conf = conf.WithStderr(&stderr) - // Run the function - fn, err := linker.GetDefault(store, "") - if err != nil { - return fmt.Errorf("wasi: get default: %w", err) + result, err := rt.InstantiateModule(ctx, mod, conf) + if result != nil { + defer result.Close(ctx) } - - callRegion := trace.StartRegion(ctx, "call _start") - _, err = fn.Call(store) - callRegion.End() - - if cerr := checkError(err, stderrPath); cerr != nil { + if cerr := checkError(err, &stderr); cerr != nil { return cerr } // Print WASM stdout - stdoutBlob, err := os.ReadFile(stdoutPath) + stdoutBlob, err := io.ReadAll(&stdout) if err != nil { return fmt.Errorf("read file: %w", err) } @@ -331,21 +276,19 @@ func (r *Runner) NewStream(ctx context.Context, desc *grpc.StreamDesc, method st return nil, status.Error(codes.Unimplemented, "") } -func checkError(err error, stderrPath string) error { +func checkError(err error, stderr io.Reader) error { if err == nil { return err } - var wtError *wasmtime.Error - if errors.As(err, &wtError) { - if code, ok := wtError.ExitStatus(); ok { - if code == 0 { - return nil - } + if exitErr, ok := err.(*sys.ExitError); ok { + if exitErr.ExitCode() == 0 { + return nil } } + // Print WASM stdout - stderrBlob, rferr := os.ReadFile(stderrPath) + stderrBlob, rferr := io.ReadAll(stderr) if rferr == nil && len(stderrBlob) > 0 { return errors.New(string(stderrBlob)) } From 2878527f82036821c020f5f6c7238c483d3b463d Mon Sep 17 00:00:00 2001 From: Kyle Conroy Date: Tue, 5 Dec 2023 15:11:25 -0800 Subject: [PATCH 2/7] Remove wasm build tags --- internal/endtoend/case_test.go | 1 - internal/endtoend/endtoend_test.go | 5 ---- .../wasm_plugin_sqlc_gen_greeter/exec.json | 3 --- .../wasm_plugin_sqlc_gen_test/exec.json | 3 --- internal/ext/wasm/nowasm.go | 23 ------------------- internal/ext/wasm/wasm.go | 11 --------- 6 files changed, 46 deletions(-) delete mode 100644 internal/endtoend/testdata/wasm_plugin_sqlc_gen_greeter/exec.json delete mode 100644 internal/endtoend/testdata/wasm_plugin_sqlc_gen_test/exec.json delete mode 100644 internal/ext/wasm/nowasm.go diff --git a/internal/endtoend/case_test.go b/internal/endtoend/case_test.go index 208b3fb9fa..50dcc57ec5 100644 --- a/internal/endtoend/case_test.go +++ b/internal/endtoend/case_test.go @@ -22,7 +22,6 @@ type Exec struct { Contexts []string `json:"contexts"` Process string `json:"process"` OS []string `json:"os"` - WASM bool `json:"wasm"` Env map[string]string `json:"env"` } diff --git a/internal/endtoend/endtoend_test.go b/internal/endtoend/endtoend_test.go index 2054baeee3..5753ce6d3a 100644 --- a/internal/endtoend/endtoend_test.go +++ b/internal/endtoend/endtoend_test.go @@ -16,7 +16,6 @@ import ( "github.com/sqlc-dev/sqlc/internal/cmd" "github.com/sqlc-dev/sqlc/internal/config" - "github.com/sqlc-dev/sqlc/internal/ext/wasm" "github.com/sqlc-dev/sqlc/internal/opts" ) @@ -177,10 +176,6 @@ func TestReplay(t *testing.T) { } } - if args.WASM && !wasm.Enabled() { - t.Skipf("wasm support not enabled") - } - if len(args.OS) > 0 { if !slices.Contains(args.OS, runtime.GOOS) { t.Skipf("unsupported os: %s", runtime.GOOS) diff --git a/internal/endtoend/testdata/wasm_plugin_sqlc_gen_greeter/exec.json b/internal/endtoend/testdata/wasm_plugin_sqlc_gen_greeter/exec.json deleted file mode 100644 index efe8bbc9aa..0000000000 --- a/internal/endtoend/testdata/wasm_plugin_sqlc_gen_greeter/exec.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "wasm": true -} diff --git a/internal/endtoend/testdata/wasm_plugin_sqlc_gen_test/exec.json b/internal/endtoend/testdata/wasm_plugin_sqlc_gen_test/exec.json deleted file mode 100644 index efe8bbc9aa..0000000000 --- a/internal/endtoend/testdata/wasm_plugin_sqlc_gen_test/exec.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "wasm": true -} diff --git a/internal/ext/wasm/nowasm.go b/internal/ext/wasm/nowasm.go deleted file mode 100644 index 14af0b54a2..0000000000 --- a/internal/ext/wasm/nowasm.go +++ /dev/null @@ -1,23 +0,0 @@ -//go:build nowasm || !(cgo && ((linux && amd64) || (linux && arm64) || (darwin && amd64) || (darwin && arm64) || (windows && amd64))) - -package wasm - -import ( - "context" - - "google.golang.org/grpc" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" -) - -func Enabled() bool { - return false -} - -func (r *Runner) Invoke(ctx context.Context, method string, args any, reply any, opts ...grpc.CallOption) error { - return status.Error(codes.FailedPrecondition, "sqlc built without wasmtime support") -} - -func (r *Runner) NewStream(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { - return nil, status.Error(codes.Unimplemented, codes.Unimplemented.String()) -} diff --git a/internal/ext/wasm/wasm.go b/internal/ext/wasm/wasm.go index fe42c33b0e..a910a8afd9 100644 --- a/internal/ext/wasm/wasm.go +++ b/internal/ext/wasm/wasm.go @@ -1,7 +1,3 @@ -//go:build !nowasm && cgo && ((linux && amd64) || (linux && arm64) || (darwin && amd64) || (darwin && arm64) || (windows && amd64)) - -// The above build constraint is based of the cgo directives in this file: -// https://github.com/bytecodealliance/wasmtime-go/blob/main/ffi.go package wasm import ( @@ -33,13 +29,6 @@ import ( "github.com/sqlc-dev/sqlc/internal/plugin" ) -func Enabled() bool { - return true -} - -// This version must be updated whenever the wasmtime-go dependency is updated -const wasmtimeVersion = `v14.0.0` - func cacheDir() (string, error) { cache := os.Getenv("SQLCCACHE") if cache != "" { From 7517d032de6f04dbba8475be20c702ee9e9cce86 Mon Sep 17 00:00:00 2001 From: Kyle Gray Date: Tue, 5 Dec 2023 18:09:53 -0800 Subject: [PATCH 3/7] Update internal/ext/wasm/wasm.go Co-authored-by: Anuraag Agrawal --- internal/ext/wasm/wasm.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/internal/ext/wasm/wasm.go b/internal/ext/wasm/wasm.go index a910a8afd9..ba95b12aeb 100644 --- a/internal/ext/wasm/wasm.go +++ b/internal/ext/wasm/wasm.go @@ -277,9 +277,9 @@ func checkError(err error, stderr io.Reader) error { } // Print WASM stdout - stderrBlob, rferr := io.ReadAll(stderr) - if rferr == nil && len(stderrBlob) > 0 { - return errors.New(string(stderrBlob)) + stderrBlob := stderr.String() + if len(stderrBlob) > 0 { + return errors.New(stderrBlob) } return fmt.Errorf("call: %w", err) } From 89f4d8b3b5c20b074803a54088c1c827a16e8413 Mon Sep 17 00:00:00 2001 From: Kyle Gray Date: Tue, 5 Dec 2023 18:09:58 -0800 Subject: [PATCH 4/7] Update internal/ext/wasm/wasm.go Co-authored-by: Anuraag Agrawal --- internal/ext/wasm/wasm.go | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/internal/ext/wasm/wasm.go b/internal/ext/wasm/wasm.go index ba95b12aeb..28afc7c887 100644 --- a/internal/ext/wasm/wasm.go +++ b/internal/ext/wasm/wasm.go @@ -244,10 +244,7 @@ func (r *Runner) Invoke(ctx context.Context, method string, args any, reply any, } // Print WASM stdout - stdoutBlob, err := io.ReadAll(&stdout) - if err != nil { - return fmt.Errorf("read file: %w", err) - } + stdoutBlob := stdout.Bytes() resp, ok := reply.(protoreflect.ProtoMessage) if !ok { From 0914cad96baea8c89843de08b7ab3da77f627e88 Mon Sep 17 00:00:00 2001 From: Kyle Conroy Date: Tue, 5 Dec 2023 18:13:18 -0800 Subject: [PATCH 5/7] Fix build --- internal/ext/wasm/wasm.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/ext/wasm/wasm.go b/internal/ext/wasm/wasm.go index 28afc7c887..eddd4eb9a9 100644 --- a/internal/ext/wasm/wasm.go +++ b/internal/ext/wasm/wasm.go @@ -239,7 +239,7 @@ func (r *Runner) Invoke(ctx context.Context, method string, args any, reply any, if result != nil { defer result.Close(ctx) } - if cerr := checkError(err, &stderr); cerr != nil { + if cerr := checkError(err, stderr); cerr != nil { return cerr } @@ -262,7 +262,7 @@ func (r *Runner) NewStream(ctx context.Context, desc *grpc.StreamDesc, method st return nil, status.Error(codes.Unimplemented, "") } -func checkError(err error, stderr io.Reader) error { +func checkError(err error, stderr bytes.Buffer) error { if err == nil { return err } From 245563de5913c57e4b932fbd767ada79339f6a66 Mon Sep 17 00:00:00 2001 From: Anuraag Agrawal Date: Wed, 3 Jan 2024 02:58:02 +0900 Subject: [PATCH 6/7] Suggestions for PR #3042 (#3082) * Only compile wasm once per process * Remove unused * Store runtime in flightgroup as well --- internal/ext/wasm/wasm.go | 90 ++++++++++++++++----------------------- 1 file changed, 36 insertions(+), 54 deletions(-) diff --git a/internal/ext/wasm/wasm.go b/internal/ext/wasm/wasm.go index eddd4eb9a9..5e838c81a7 100644 --- a/internal/ext/wasm/wasm.go +++ b/internal/ext/wasm/wasm.go @@ -29,24 +29,13 @@ import ( "github.com/sqlc-dev/sqlc/internal/plugin" ) -func cacheDir() (string, error) { - cache := os.Getenv("SQLCCACHE") - if cache != "" { - return cache, nil - } - cacheHome := os.Getenv("XDG_CACHE_HOME") - if cacheHome == "" { - home, err := os.UserHomeDir() - if err != nil { - return "", err - } - cacheHome = filepath.Join(home, ".cache") - } - return filepath.Join(cacheHome, "sqlc"), nil -} - var flight singleflight.Group +type runtimeAndCode struct { + rt wazero.Runtime + code wazero.CompiledModule +} + // Verify the provided sha256 is valid. func (r *Runner) getChecksum(ctx context.Context) (string, error) { if r.SHA256 != "" { @@ -61,7 +50,7 @@ func (r *Runner) getChecksum(ctx context.Context) (string, error) { return sum, nil } -func (r *Runner) loadBytes(ctx context.Context) ([]byte, error) { +func (r *Runner) loadAndCompile(ctx context.Context) (*runtimeAndCode, error) { expected, err := r.getChecksum(ctx) if err != nil { return nil, err @@ -71,14 +60,14 @@ func (r *Runner) loadBytes(ctx context.Context) ([]byte, error) { return nil, err } value, err, _ := flight.Do(expected, func() (interface{}, error) { - return r.loadWASM(ctx, cacheDir, expected) + return r.loadAndCompileWASM(ctx, cacheDir, expected) }) if err != nil { return nil, err } - data, ok := value.([]byte) + data, ok := value.(*runtimeAndCode) if !ok { - return nil, fmt.Errorf("returned value was not a byte slice") + return nil, fmt.Errorf("returned value was not a compiled module") } return data, nil } @@ -124,7 +113,7 @@ func (r *Runner) fetch(ctx context.Context, uri string) ([]byte, string, error) return wmod, actual, nil } -func (r *Runner) loadWASM(ctx context.Context, cache string, expected string) ([]byte, error) { +func (r *Runner) loadAndCompileWASM(ctx context.Context, cache string, expected string) (*runtimeAndCode, error) { pluginDir := filepath.Join(cache, expected) pluginPath := filepath.Join(pluginDir, "plugin.wasm") _, staterr := os.Stat(pluginPath) @@ -153,7 +142,22 @@ func (r *Runner) loadWASM(ctx context.Context, cache string, expected string) ([ } } - return wmod, nil + wazeroCache, err := wazero.NewCompilationCacheWithDir(filepath.Join(cache, "wazero")) + if err != nil { + return nil, fmt.Errorf("wazero.NewCompilationCacheWithDir: %w", err) + } + config := wazero.NewRuntimeConfig().WithCompilationCache(wazeroCache) + rt := wazero.NewRuntimeWithConfig(ctx, config) + // TODO: Handle error + wasi_snapshot_preview1.MustInstantiate(ctx, rt) + + // Compile the Wasm binary once so that we can skip the entire compilation time during instantiation. + code, err := rt.CompileModule(ctx, wmod) + if err != nil { + return nil, fmt.Errorf("compile module: %w", err) + } + + return &runtimeAndCode{rt: rt, code: code}, nil } // removePGCatalog removes the pg_catalog schema from the request. There is a @@ -195,47 +199,25 @@ func (r *Runner) Invoke(ctx context.Context, method string, args any, reply any, return fmt.Errorf("failed to encode codegen request: %w", err) } - cacheDir, err := cache.PluginsDir() + runtimeAndCode, err := r.loadAndCompile(ctx) if err != nil { - return err - } - - cache, err := wazero.NewCompilationCacheWithDir(filepath.Join(cacheDir, "wazero")) - if err != nil { - return err - } - - wasmBytes, err := r.loadBytes(ctx) - if err != nil { - return fmt.Errorf("loadModule: %w", err) - } - - config := wazero.NewRuntimeConfig().WithCompilationCache(cache) - rt := wazero.NewRuntimeWithConfig(ctx, config) - defer rt.Close(ctx) - - // TODO: Handle error - wasi_snapshot_preview1.MustInstantiate(ctx, rt) - - // Compile the Wasm binary once so that we can skip the entire compilation time during instantiation. - mod, err := rt.CompileModule(ctx, wasmBytes) - if err != nil { - return err + return fmt.Errorf("loadBytes: %w", err) } var stderr, stdout bytes.Buffer - conf := wazero.NewModuleConfig() - conf = conf.WithArgs("plugin.wasm", method) - conf = conf.WithEnv("SQLC_VERSION", info.Version) + conf := wazero.NewModuleConfig(). + WithName(""). + WithArgs("plugin.wasm", method). + WithStdin(bytes.NewReader(stdinBlob)). + WithStdout(&stdout). + WithStderr(&stderr). + WithEnv("SQLC_VERSION", info.Version) for _, key := range r.Env { conf = conf.WithEnv(key, os.Getenv(key)) } - conf = conf.WithStdin(bytes.NewReader(stdinBlob)) - conf = conf.WithStdout(&stdout) - conf = conf.WithStderr(&stderr) - result, err := rt.InstantiateModule(ctx, mod, conf) + result, err := runtimeAndCode.rt.InstantiateModule(ctx, runtimeAndCode.code, conf) if result != nil { defer result.Close(ctx) } From 18fbb557167bd9907677bb8c197278ec135c690f Mon Sep 17 00:00:00 2001 From: Kyle Conroy Date: Tue, 2 Jan 2024 10:18:18 -0800 Subject: [PATCH 7/7] Handle error from instantiate --- internal/ext/wasm/wasm.go | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/internal/ext/wasm/wasm.go b/internal/ext/wasm/wasm.go index 5e838c81a7..a14c71d8a4 100644 --- a/internal/ext/wasm/wasm.go +++ b/internal/ext/wasm/wasm.go @@ -146,12 +146,16 @@ func (r *Runner) loadAndCompileWASM(ctx context.Context, cache string, expected if err != nil { return nil, fmt.Errorf("wazero.NewCompilationCacheWithDir: %w", err) } + config := wazero.NewRuntimeConfig().WithCompilationCache(wazeroCache) rt := wazero.NewRuntimeWithConfig(ctx, config) - // TODO: Handle error - wasi_snapshot_preview1.MustInstantiate(ctx, rt) - // Compile the Wasm binary once so that we can skip the entire compilation time during instantiation. + if _, err := wasi_snapshot_preview1.Instantiate(ctx, rt); err != nil { + return nil, fmt.Errorf("wasi_snapshot_preview1 instantiate: %w", err) + } + + // Compile the Wasm binary once so that we can skip the entire compilation + // time during instantiation. code, err := rt.CompileModule(ctx, wmod) if err != nil { return nil, fmt.Errorf("compile module: %w", err)