diff --git a/internal/ext/wasm/wasm.go b/internal/ext/wasm/wasm.go index 1d04ea176a..ee4b1841d7 100644 --- a/internal/ext/wasm/wasm.go +++ b/internal/ext/wasm/wasm.go @@ -258,25 +258,25 @@ func (r *Runner) Generate(ctx context.Context, req *plugin.CodeGenRequest) (*plu store := wasmtime.NewStore(engine) store.SetWasi(wasiConfig) - linkRegion := trace.StartRegion(ctx, "linker.Instantiate") - instance, err := linker.Instantiate(store, module) + linkRegion := trace.StartRegion(ctx, "linker.DefineModule") + err = linker.DefineModule(store, "", module) linkRegion.End() if err != nil { return nil, fmt.Errorf("define wasi: %w", err) } // Run the function + fn, err := linker.GetDefault(store, "") + if err != nil { + return nil, fmt.Errorf("wasi: get default: %w", err) + } + callRegion := trace.StartRegion(ctx, "call _start") - nom := instance.GetExport(store, "_start").Func() - _, err = nom.Call(store) + _, err = fn.Call(store) callRegion.End() - if err != nil { - // Print WASM stdout - stderrBlob, err := os.ReadFile(stderrPath) - if err == nil && len(stderrBlob) > 0 { - return nil, errors.New(string(stderrBlob)) - } - return nil, fmt.Errorf("call: %w", err) + + if cerr := checkError(err, stderrPath); cerr != nil { + return nil, cerr } // Print WASM stdout @@ -284,6 +284,28 @@ func (r *Runner) Generate(ctx context.Context, req *plugin.CodeGenRequest) (*plu if err != nil { return nil, fmt.Errorf("read file: %w", err) } + var resp plugin.CodeGenResponse return &resp, resp.UnmarshalVT(stdoutBlob) } + +func checkError(err error, stderrPath string) error { + if err == nil { + return err + } + + var wtError *wasmtime.Error + if errors.As(err, &wtError) { + if code, ok := wtError.ExitStatus(); ok { + if code == 0 { + return nil + } + } + } + // Print WASM stdout + stderrBlob, rferr := os.ReadFile(stderrPath) + if rferr == nil && len(stderrBlob) > 0 { + return errors.New(string(stderrBlob)) + } + return fmt.Errorf("call: %w", err) +}