Skip to content

Commit 5e3d938

Browse files
committed
feat(plugins): Use wazero instead of wasmtime
1 parent eb8d97f commit 5e3d938

File tree

3 files changed

+49
-106
lines changed

3 files changed

+49
-106
lines changed

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ go 1.21
44

55
require (
66
github.com/antlr/antlr4/runtime/Go/antlr/v4 v4.0.0-20230321174746-8dcc6526cfb1
7-
github.com/bytecodealliance/wasmtime-go/v14 v14.0.0
87
github.com/cubicdaiya/gonp v1.0.4
98
github.com/davecgh/go-spew v1.1.1
109
github.com/fatih/structtag v1.2.0
@@ -20,6 +19,7 @@ require (
2019
github.com/riza-io/grpc-go v0.2.0
2120
github.com/spf13/cobra v1.8.0
2221
github.com/spf13/pflag v1.0.5
22+
github.com/tetratelabs/wazero v1.5.0
2323
github.com/wasilibs/go-pgquery v0.0.0-20231205013331-96e794bb074e
2424
github.com/xeipuuv/gojsonschema v1.2.0
2525
golang.org/x/sync v0.5.0

go.sum

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@ github.com/antlr/antlr4/runtime/Go/antlr/v4 v4.0.0-20230321174746-8dcc6526cfb1/g
55
github.com/antlr4-go/antlr/v4 v4.13.0 h1:lxCg3LAv+EUK6t1i0y1V6/SLeUi0eKEKdhQAlS8TVTI=
66
github.com/antlr4-go/antlr/v4 v4.13.0/go.mod h1:pfChB/xh/Unjila75QW7+VU4TSnWnnk9UTnmpPaOR2g=
77
github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
8-
github.com/bytecodealliance/wasmtime-go/v14 v14.0.0 h1:ur7S3P+PAeJmgllhSrKnGQOAmmtUbLQxb/nw2NZiaEM=
9-
github.com/bytecodealliance/wasmtime-go/v14 v14.0.0/go.mod h1:tqOVEUjnXY6aGpSfM9qdVRR6G//Yc513fFYUdzZb/DY=
108
github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I=
119
github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ=
1210
github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=
@@ -185,6 +183,8 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO
185183
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
186184
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
187185
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
186+
github.com/tetratelabs/wazero v1.5.0 h1:Yz3fZHivfDiZFUXnWMPUoiW7s8tC1sjdBtlJn08qYa0=
187+
github.com/tetratelabs/wazero v1.5.0/go.mod h1:0U0G41+ochRKoPKCJlh0jMg1CHkyfK8kDqiirMmKY8A=
188188
github.com/wasilibs/go-pgquery v0.0.0-20231205013331-96e794bb074e h1:sGIC6/D0KqpA+qBSDSVDQswU/IJVYkbnUXnipgTLQWk=
189189
github.com/wasilibs/go-pgquery v0.0.0-20231205013331-96e794bb074e/go.mod h1:KW0azBSWqkPZ71r+3O4qt8h6A/NisFLp0rbjZ3py4OE=
190190
github.com/wasilibs/wazerox v0.0.0-20231117065139-b3503f4aeff6 h1:jwbU8u5TuXModzdEG4wI0g4FyuD7ROSttU86go5sPdU=

internal/ext/wasm/wasm.go

Lines changed: 46 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
package wasm
66

77
import (
8+
"bytes"
89
"context"
910
"crypto/sha256"
1011
"errors"
@@ -15,10 +16,11 @@ import (
1516
"os"
1617
"path/filepath"
1718
"runtime"
18-
"runtime/trace"
1919
"strings"
2020

21-
wasmtime "github.com/bytecodealliance/wasmtime-go/v14"
21+
"github.com/tetratelabs/wazero"
22+
"github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1"
23+
"github.com/tetratelabs/wazero/sys"
2224
"golang.org/x/sync/singleflight"
2325
"google.golang.org/grpc"
2426
"google.golang.org/grpc/codes"
@@ -70,13 +72,17 @@ func (r *Runner) getChecksum(ctx context.Context) (string, error) {
7072
return sum, nil
7173
}
7274

73-
func (r *Runner) loadModule(ctx context.Context, engine *wasmtime.Engine) (*wasmtime.Module, error) {
75+
func (r *Runner) loadBytes(ctx context.Context) ([]byte, error) {
7476
expected, err := r.getChecksum(ctx)
7577
if err != nil {
7678
return nil, err
7779
}
80+
cacheDir, err := cache.PluginsDir()
81+
if err != nil {
82+
return nil, err
83+
}
7884
value, err, _ := flight.Do(expected, func() (interface{}, error) {
79-
return r.loadSerializedModule(ctx, engine, expected)
85+
return r.loadWASM(ctx, cacheDir, expected)
8086
})
8187
if err != nil {
8288
return nil, err
@@ -85,52 +91,7 @@ func (r *Runner) loadModule(ctx context.Context, engine *wasmtime.Engine) (*wasm
8591
if !ok {
8692
return nil, fmt.Errorf("returned value was not a byte slice")
8793
}
88-
return wasmtime.NewModuleDeserialize(engine, data)
89-
}
90-
91-
func (r *Runner) loadSerializedModule(ctx context.Context, engine *wasmtime.Engine, expectedSha string) ([]byte, error) {
92-
cacheDir, err := cache.PluginsDir()
93-
if err != nil {
94-
return nil, err
95-
}
96-
97-
pluginDir := filepath.Join(cacheDir, expectedSha)
98-
modName := fmt.Sprintf("plugin_%s_%s_%s.module", runtime.GOOS, runtime.GOARCH, wasmtimeVersion)
99-
modPath := filepath.Join(pluginDir, modName)
100-
_, staterr := os.Stat(modPath)
101-
if staterr == nil {
102-
data, err := os.ReadFile(modPath)
103-
if err != nil {
104-
return nil, err
105-
}
106-
return data, nil
107-
}
108-
109-
wmod, err := r.loadWASM(ctx, cacheDir, expectedSha)
110-
if err != nil {
111-
return nil, err
112-
}
113-
114-
moduRegion := trace.StartRegion(ctx, "wasmtime.NewModule")
115-
module, err := wasmtime.NewModule(engine, wmod)
116-
moduRegion.End()
117-
if err != nil {
118-
return nil, fmt.Errorf("define wasi: %w", err)
119-
}
120-
121-
err = os.Mkdir(pluginDir, 0755)
122-
if err != nil && !os.IsExist(err) {
123-
return nil, fmt.Errorf("mkdirall: %w", err)
124-
}
125-
out, err := module.Serialize()
126-
if err != nil {
127-
return nil, fmt.Errorf("serialize: %w", err)
128-
}
129-
if err := os.WriteFile(modPath, out, 0444); err != nil {
130-
return nil, fmt.Errorf("cache wasm: %w", err)
131-
}
132-
133-
return out, nil
94+
return data, nil
13495
}
13596

13697
func (r *Runner) fetch(ctx context.Context, uri string) ([]byte, string, error) {
@@ -245,72 +206,56 @@ func (r *Runner) Invoke(ctx context.Context, method string, args any, reply any,
245206
return fmt.Errorf("failed to encode codegen request: %w", err)
246207
}
247208

248-
engine := wasmtime.NewEngine()
249-
module, err := r.loadModule(ctx, engine)
209+
cacheDir, err := cache.PluginsDir()
250210
if err != nil {
251-
return fmt.Errorf("loadModule: %w", err)
211+
return err
252212
}
253213

254-
linker := wasmtime.NewLinker(engine)
255-
if err := linker.DefineWasi(); err != nil {
214+
cache, err := wazero.NewCompilationCacheWithDir(filepath.Join(cacheDir, "wazero"))
215+
if err != nil {
256216
return err
257217
}
258218

259-
dir, err := os.MkdirTemp(os.Getenv("SQLCTMPDIR"), "out")
219+
wasmBytes, err := r.loadBytes(ctx)
260220
if err != nil {
261-
return fmt.Errorf("temp dir: %w", err)
221+
return fmt.Errorf("loadModule: %w", err)
262222
}
263223

264-
defer os.RemoveAll(dir)
265-
stdinPath := filepath.Join(dir, "stdin")
266-
stderrPath := filepath.Join(dir, "stderr")
267-
stdoutPath := filepath.Join(dir, "stdout")
224+
config := wazero.NewRuntimeConfig().WithCompilationCache(cache)
225+
rt := wazero.NewRuntimeWithConfig(ctx, config)
226+
defer rt.Close(ctx)
268227

269-
if err := os.WriteFile(stdinPath, stdinBlob, 0755); err != nil {
270-
return fmt.Errorf("write file: %w", err)
271-
}
272-
273-
// Configure WASI imports to write stdout into a file.
274-
wasiConfig := wasmtime.NewWasiConfig()
275-
wasiConfig.SetArgv([]string{"plugin.wasm", method})
276-
wasiConfig.SetStdinFile(stdinPath)
277-
wasiConfig.SetStdoutFile(stdoutPath)
278-
wasiConfig.SetStderrFile(stderrPath)
228+
// TODO: Handle error
229+
wasi_snapshot_preview1.MustInstantiate(ctx, rt)
279230

280-
keys := []string{"SQLC_VERSION"}
281-
vals := []string{info.Version}
282-
for _, key := range r.Env {
283-
keys = append(keys, key)
284-
vals = append(vals, os.Getenv(key))
231+
// Compile the Wasm binary once so that we can skip the entire compilation time during instantiation.
232+
mod, err := rt.CompileModule(ctx, wasmBytes)
233+
if err != nil {
234+
return err
285235
}
286-
wasiConfig.SetEnv(keys, vals)
287236

288-
store := wasmtime.NewStore(engine)
289-
store.SetWasi(wasiConfig)
237+
var stderr, stdout bytes.Buffer
290238

291-
linkRegion := trace.StartRegion(ctx, "linker.DefineModule")
292-
err = linker.DefineModule(store, "", module)
293-
linkRegion.End()
294-
if err != nil {
295-
return fmt.Errorf("define wasi: %w", err)
239+
conf := wazero.NewModuleConfig()
240+
conf = conf.WithArgs("plugin.wasm", method)
241+
conf = conf.WithEnv("SQLC_VERSION", info.Version)
242+
for _, key := range r.Env {
243+
conf = conf.WithEnv(key, os.Getenv(key))
296244
}
245+
conf = conf.WithStdin(bytes.NewReader(stdinBlob))
246+
conf = conf.WithStdout(&stdout)
247+
conf = conf.WithStderr(&stderr)
297248

298-
// Run the function
299-
fn, err := linker.GetDefault(store, "")
300-
if err != nil {
301-
return fmt.Errorf("wasi: get default: %w", err)
249+
result, err := rt.InstantiateModule(ctx, mod, conf)
250+
if result != nil {
251+
defer result.Close(ctx)
302252
}
303-
304-
callRegion := trace.StartRegion(ctx, "call _start")
305-
_, err = fn.Call(store)
306-
callRegion.End()
307-
308-
if cerr := checkError(err, stderrPath); cerr != nil {
253+
if cerr := checkError(err, &stderr); cerr != nil {
309254
return cerr
310255
}
311256

312257
// Print WASM stdout
313-
stdoutBlob, err := os.ReadFile(stdoutPath)
258+
stdoutBlob, err := io.ReadAll(&stdout)
314259
if err != nil {
315260
return fmt.Errorf("read file: %w", err)
316261
}
@@ -331,21 +276,19 @@ func (r *Runner) NewStream(ctx context.Context, desc *grpc.StreamDesc, method st
331276
return nil, status.Error(codes.Unimplemented, "")
332277
}
333278

334-
func checkError(err error, stderrPath string) error {
279+
func checkError(err error, stderr io.Reader) error {
335280
if err == nil {
336281
return err
337282
}
338283

339-
var wtError *wasmtime.Error
340-
if errors.As(err, &wtError) {
341-
if code, ok := wtError.ExitStatus(); ok {
342-
if code == 0 {
343-
return nil
344-
}
284+
if exitErr, ok := err.(*sys.ExitError); ok {
285+
if exitErr.ExitCode() == 0 {
286+
return nil
345287
}
346288
}
289+
347290
// Print WASM stdout
348-
stderrBlob, rferr := os.ReadFile(stderrPath)
291+
stderrBlob, rferr := io.ReadAll(stderr)
349292
if rferr == nil && len(stderrBlob) > 0 {
350293
return errors.New(string(stderrBlob))
351294
}

0 commit comments

Comments
 (0)