Skip to content

Commit 4188d23

Browse files
kyleconroyanuraaga
andauthored
feat(plugins): Use wazero instead of wasmtime (#3042)
* feat(plugins): Use wazero instead of wasmtime * Remove wasm build tags * Update internal/ext/wasm/wasm.go Co-authored-by: Anuraag Agrawal <anuraaga@gmail.com> * Update internal/ext/wasm/wasm.go Co-authored-by: Anuraag Agrawal <anuraaga@gmail.com> * Fix build * Suggestions for PR #3042 (#3082) * Only compile wasm once per process * Remove unused * Store runtime in flightgroup as well * Handle error from instantiate --------- Co-authored-by: Anuraag Agrawal <anuraaga@gmail.com>
1 parent 4f7fca7 commit 4188d23

File tree

8 files changed

+68
-184
lines changed

8 files changed

+68
-184
lines changed

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ go 1.21
44

55
require (
66
github.com/antlr4-go/antlr/v4 v4.13.0
7-
github.com/bytecodealliance/wasmtime-go/v14 v14.0.0
87
github.com/cubicdaiya/gonp v1.0.4
98
github.com/davecgh/go-spew v1.1.1
109
github.com/fatih/structtag v1.2.0
@@ -20,6 +19,7 @@ require (
2019
github.com/riza-io/grpc-go v0.2.0
2120
github.com/spf13/cobra v1.8.0
2221
github.com/spf13/pflag v1.0.5
22+
github.com/tetratelabs/wazero v1.5.0
2323
github.com/wasilibs/go-pgquery v0.0.0-20231208014744-de63626a1e99
2424
github.com/xeipuuv/gojsonschema v1.2.0
2525
golang.org/x/sync v0.5.0

go.sum

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@ github.com/Masterminds/semver/v3 v3.1.1/go.mod h1:VPu/7SZ7ePZ3QOrcuXROw5FAcLl4a0
33
github.com/antlr4-go/antlr/v4 v4.13.0 h1:lxCg3LAv+EUK6t1i0y1V6/SLeUi0eKEKdhQAlS8TVTI=
44
github.com/antlr4-go/antlr/v4 v4.13.0/go.mod h1:pfChB/xh/Unjila75QW7+VU4TSnWnnk9UTnmpPaOR2g=
55
github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
6-
github.com/bytecodealliance/wasmtime-go/v14 v14.0.0 h1:ur7S3P+PAeJmgllhSrKnGQOAmmtUbLQxb/nw2NZiaEM=
7-
github.com/bytecodealliance/wasmtime-go/v14 v14.0.0/go.mod h1:tqOVEUjnXY6aGpSfM9qdVRR6G//Yc513fFYUdzZb/DY=
86
github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I=
97
github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ=
108
github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=
@@ -183,6 +181,12 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO
183181
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
184182
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
185183
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
184+
github.com/tetratelabs/wazero v1.5.0 h1:Yz3fZHivfDiZFUXnWMPUoiW7s8tC1sjdBtlJn08qYa0=
185+
github.com/tetratelabs/wazero v1.5.0/go.mod h1:0U0G41+ochRKoPKCJlh0jMg1CHkyfK8kDqiirMmKY8A=
186+
github.com/wasilibs/go-pgquery v0.0.0-20231205013331-96e794bb074e h1:sGIC6/D0KqpA+qBSDSVDQswU/IJVYkbnUXnipgTLQWk=
187+
github.com/wasilibs/go-pgquery v0.0.0-20231205013331-96e794bb074e/go.mod h1:KW0azBSWqkPZ71r+3O4qt8h6A/NisFLp0rbjZ3py4OE=
188+
github.com/wasilibs/wazerox v0.0.0-20231117065139-b3503f4aeff6 h1:jwbU8u5TuXModzdEG4wI0g4FyuD7ROSttU86go5sPdU=
189+
github.com/wasilibs/wazerox v0.0.0-20231117065139-b3503f4aeff6/go.mod h1:IQNVyA4d1hWIe23mlMMuqXjyWMdndgSlNx6FqBkwPsM=
186190
github.com/wasilibs/go-pgquery v0.0.0-20231208014744-de63626a1e99 h1:HFee1ByN4FrqNVd53Mo28ccGO+g5gxqUV/gdvKMe4b8=
187191
github.com/wasilibs/go-pgquery v0.0.0-20231208014744-de63626a1e99/go.mod h1:f2JMhFocVxY3VKMd9ykUxMnX4EVew9WOgjnfaNBB6C8=
188192
github.com/wasilibs/wazerox v0.0.0-20231208014050-e6b725634531 h1:zVJ4SZgaEE9sEH2L9k1+eAvCNa/WAAnT9UiMa3/tQrI=

internal/endtoend/case_test.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ type Exec struct {
2222
Contexts []string `json:"contexts"`
2323
Process string `json:"process"`
2424
OS []string `json:"os"`
25-
WASM bool `json:"wasm"`
2625
Env map[string]string `json:"env"`
2726
}
2827

internal/endtoend/endtoend_test.go

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ import (
1616

1717
"github.com/sqlc-dev/sqlc/internal/cmd"
1818
"github.com/sqlc-dev/sqlc/internal/config"
19-
"github.com/sqlc-dev/sqlc/internal/ext/wasm"
2019
"github.com/sqlc-dev/sqlc/internal/opts"
2120
)
2221

@@ -177,10 +176,6 @@ func TestReplay(t *testing.T) {
177176
}
178177
}
179178

180-
if args.WASM && !wasm.Enabled() {
181-
t.Skipf("wasm support not enabled")
182-
}
183-
184179
if len(args.OS) > 0 {
185180
if !slices.Contains(args.OS, runtime.GOOS) {
186181
t.Skipf("unsupported os: %s", runtime.GOOS)

internal/endtoend/testdata/wasm_plugin_sqlc_gen_greeter/exec.json

Lines changed: 0 additions & 3 deletions
This file was deleted.

internal/endtoend/testdata/wasm_plugin_sqlc_gen_test/exec.json

Lines changed: 0 additions & 3 deletions
This file was deleted.

internal/ext/wasm/nowasm.go

Lines changed: 0 additions & 23 deletions
This file was deleted.

internal/ext/wasm/wasm.go

Lines changed: 61 additions & 146 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
1-
//go:build !nowasm && cgo && ((linux && amd64) || (linux && arm64) || (darwin && amd64) || (darwin && arm64) || (windows && amd64))
2-
3-
// The above build constraint is based of the cgo directives in this file:
4-
// https://github.com/bytecodealliance/wasmtime-go/blob/main/ffi.go
51
package wasm
62

73
import (
4+
"bytes"
85
"context"
96
"crypto/sha256"
107
"errors"
@@ -15,10 +12,11 @@ import (
1512
"os"
1613
"path/filepath"
1714
"runtime"
18-
"runtime/trace"
1915
"strings"
2016

21-
wasmtime "github.com/bytecodealliance/wasmtime-go/v14"
17+
"github.com/tetratelabs/wazero"
18+
"github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1"
19+
"github.com/tetratelabs/wazero/sys"
2220
"golang.org/x/sync/singleflight"
2321
"google.golang.org/grpc"
2422
"google.golang.org/grpc/codes"
@@ -31,31 +29,13 @@ import (
3129
"github.com/sqlc-dev/sqlc/internal/plugin"
3230
)
3331

34-
func Enabled() bool {
35-
return true
36-
}
37-
38-
// This version must be updated whenever the wasmtime-go dependency is updated
39-
const wasmtimeVersion = `v14.0.0`
32+
var flight singleflight.Group
4033

41-
func cacheDir() (string, error) {
42-
cache := os.Getenv("SQLCCACHE")
43-
if cache != "" {
44-
return cache, nil
45-
}
46-
cacheHome := os.Getenv("XDG_CACHE_HOME")
47-
if cacheHome == "" {
48-
home, err := os.UserHomeDir()
49-
if err != nil {
50-
return "", err
51-
}
52-
cacheHome = filepath.Join(home, ".cache")
53-
}
54-
return filepath.Join(cacheHome, "sqlc"), nil
34+
type runtimeAndCode struct {
35+
rt wazero.Runtime
36+
code wazero.CompiledModule
5537
}
5638

57-
var flight singleflight.Group
58-
5939
// Verify the provided sha256 is valid.
6040
func (r *Runner) getChecksum(ctx context.Context) (string, error) {
6141
if r.SHA256 != "" {
@@ -70,67 +50,26 @@ func (r *Runner) getChecksum(ctx context.Context) (string, error) {
7050
return sum, nil
7151
}
7252

73-
func (r *Runner) loadModule(ctx context.Context, engine *wasmtime.Engine) (*wasmtime.Module, error) {
53+
func (r *Runner) loadAndCompile(ctx context.Context) (*runtimeAndCode, error) {
7454
expected, err := r.getChecksum(ctx)
7555
if err != nil {
7656
return nil, err
7757
}
78-
value, err, _ := flight.Do(expected, func() (interface{}, error) {
79-
return r.loadSerializedModule(ctx, engine, expected)
80-
})
81-
if err != nil {
82-
return nil, err
83-
}
84-
data, ok := value.([]byte)
85-
if !ok {
86-
return nil, fmt.Errorf("returned value was not a byte slice")
87-
}
88-
return wasmtime.NewModuleDeserialize(engine, data)
89-
}
90-
91-
func (r *Runner) loadSerializedModule(ctx context.Context, engine *wasmtime.Engine, expectedSha string) ([]byte, error) {
9258
cacheDir, err := cache.PluginsDir()
9359
if err != nil {
9460
return nil, err
9561
}
96-
97-
pluginDir := filepath.Join(cacheDir, expectedSha)
98-
modName := fmt.Sprintf("plugin_%s_%s_%s.module", runtime.GOOS, runtime.GOARCH, wasmtimeVersion)
99-
modPath := filepath.Join(pluginDir, modName)
100-
_, staterr := os.Stat(modPath)
101-
if staterr == nil {
102-
data, err := os.ReadFile(modPath)
103-
if err != nil {
104-
return nil, err
105-
}
106-
return data, nil
107-
}
108-
109-
wmod, err := r.loadWASM(ctx, cacheDir, expectedSha)
62+
value, err, _ := flight.Do(expected, func() (interface{}, error) {
63+
return r.loadAndCompileWASM(ctx, cacheDir, expected)
64+
})
11065
if err != nil {
11166
return nil, err
11267
}
113-
114-
moduRegion := trace.StartRegion(ctx, "wasmtime.NewModule")
115-
module, err := wasmtime.NewModule(engine, wmod)
116-
moduRegion.End()
117-
if err != nil {
118-
return nil, fmt.Errorf("define wasi: %w", err)
119-
}
120-
121-
err = os.Mkdir(pluginDir, 0755)
122-
if err != nil && !os.IsExist(err) {
123-
return nil, fmt.Errorf("mkdirall: %w", err)
124-
}
125-
out, err := module.Serialize()
126-
if err != nil {
127-
return nil, fmt.Errorf("serialize: %w", err)
128-
}
129-
if err := os.WriteFile(modPath, out, 0444); err != nil {
130-
return nil, fmt.Errorf("cache wasm: %w", err)
68+
data, ok := value.(*runtimeAndCode)
69+
if !ok {
70+
return nil, fmt.Errorf("returned value was not a compiled module")
13171
}
132-
133-
return out, nil
72+
return data, nil
13473
}
13574

13675
func (r *Runner) fetch(ctx context.Context, uri string) ([]byte, string, error) {
@@ -174,7 +113,7 @@ func (r *Runner) fetch(ctx context.Context, uri string) ([]byte, string, error)
174113
return wmod, actual, nil
175114
}
176115

177-
func (r *Runner) loadWASM(ctx context.Context, cache string, expected string) ([]byte, error) {
116+
func (r *Runner) loadAndCompileWASM(ctx context.Context, cache string, expected string) (*runtimeAndCode, error) {
178117
pluginDir := filepath.Join(cache, expected)
179118
pluginPath := filepath.Join(pluginDir, "plugin.wasm")
180119
_, staterr := os.Stat(pluginPath)
@@ -203,7 +142,26 @@ func (r *Runner) loadWASM(ctx context.Context, cache string, expected string) ([
203142
}
204143
}
205144

206-
return wmod, nil
145+
wazeroCache, err := wazero.NewCompilationCacheWithDir(filepath.Join(cache, "wazero"))
146+
if err != nil {
147+
return nil, fmt.Errorf("wazero.NewCompilationCacheWithDir: %w", err)
148+
}
149+
150+
config := wazero.NewRuntimeConfig().WithCompilationCache(wazeroCache)
151+
rt := wazero.NewRuntimeWithConfig(ctx, config)
152+
153+
if _, err := wasi_snapshot_preview1.Instantiate(ctx, rt); err != nil {
154+
return nil, fmt.Errorf("wasi_snapshot_preview1 instantiate: %w", err)
155+
}
156+
157+
// Compile the Wasm binary once so that we can skip the entire compilation
158+
// time during instantiation.
159+
code, err := rt.CompileModule(ctx, wmod)
160+
if err != nil {
161+
return nil, fmt.Errorf("compile module: %w", err)
162+
}
163+
164+
return &runtimeAndCode{rt: rt, code: code}, nil
207165
}
208166

209167
// removePGCatalog removes the pg_catalog schema from the request. There is a
@@ -245,75 +203,34 @@ func (r *Runner) Invoke(ctx context.Context, method string, args any, reply any,
245203
return fmt.Errorf("failed to encode codegen request: %w", err)
246204
}
247205

248-
engine := wasmtime.NewEngine()
249-
module, err := r.loadModule(ctx, engine)
250-
if err != nil {
251-
return fmt.Errorf("loadModule: %w", err)
252-
}
253-
254-
linker := wasmtime.NewLinker(engine)
255-
if err := linker.DefineWasi(); err != nil {
256-
return err
257-
}
258-
259-
dir, err := os.MkdirTemp(os.Getenv("SQLCTMPDIR"), "out")
206+
runtimeAndCode, err := r.loadAndCompile(ctx)
260207
if err != nil {
261-
return fmt.Errorf("temp dir: %w", err)
262-
}
263-
264-
defer os.RemoveAll(dir)
265-
stdinPath := filepath.Join(dir, "stdin")
266-
stderrPath := filepath.Join(dir, "stderr")
267-
stdoutPath := filepath.Join(dir, "stdout")
268-
269-
if err := os.WriteFile(stdinPath, stdinBlob, 0755); err != nil {
270-
return fmt.Errorf("write file: %w", err)
208+
return fmt.Errorf("loadBytes: %w", err)
271209
}
272210

273-
// Configure WASI imports to write stdout into a file.
274-
wasiConfig := wasmtime.NewWasiConfig()
275-
wasiConfig.SetArgv([]string{"plugin.wasm", method})
276-
wasiConfig.SetStdinFile(stdinPath)
277-
wasiConfig.SetStdoutFile(stdoutPath)
278-
wasiConfig.SetStderrFile(stderrPath)
211+
var stderr, stdout bytes.Buffer
279212

280-
keys := []string{"SQLC_VERSION"}
281-
vals := []string{info.Version}
213+
conf := wazero.NewModuleConfig().
214+
WithName("").
215+
WithArgs("plugin.wasm", method).
216+
WithStdin(bytes.NewReader(stdinBlob)).
217+
WithStdout(&stdout).
218+
WithStderr(&stderr).
219+
WithEnv("SQLC_VERSION", info.Version)
282220
for _, key := range r.Env {
283-
keys = append(keys, key)
284-
vals = append(vals, os.Getenv(key))
285-
}
286-
wasiConfig.SetEnv(keys, vals)
287-
288-
store := wasmtime.NewStore(engine)
289-
store.SetWasi(wasiConfig)
290-
291-
linkRegion := trace.StartRegion(ctx, "linker.DefineModule")
292-
err = linker.DefineModule(store, "", module)
293-
linkRegion.End()
294-
if err != nil {
295-
return fmt.Errorf("define wasi: %w", err)
221+
conf = conf.WithEnv(key, os.Getenv(key))
296222
}
297223

298-
// Run the function
299-
fn, err := linker.GetDefault(store, "")
300-
if err != nil {
301-
return fmt.Errorf("wasi: get default: %w", err)
224+
result, err := runtimeAndCode.rt.InstantiateModule(ctx, runtimeAndCode.code, conf)
225+
if result != nil {
226+
defer result.Close(ctx)
302227
}
303-
304-
callRegion := trace.StartRegion(ctx, "call _start")
305-
_, err = fn.Call(store)
306-
callRegion.End()
307-
308-
if cerr := checkError(err, stderrPath); cerr != nil {
228+
if cerr := checkError(err, stderr); cerr != nil {
309229
return cerr
310230
}
311231

312232
// Print WASM stdout
313-
stdoutBlob, err := os.ReadFile(stdoutPath)
314-
if err != nil {
315-
return fmt.Errorf("read file: %w", err)
316-
}
233+
stdoutBlob := stdout.Bytes()
317234

318235
resp, ok := reply.(protoreflect.ProtoMessage)
319236
if !ok {
@@ -331,23 +248,21 @@ func (r *Runner) NewStream(ctx context.Context, desc *grpc.StreamDesc, method st
331248
return nil, status.Error(codes.Unimplemented, "")
332249
}
333250

334-
func checkError(err error, stderrPath string) error {
251+
func checkError(err error, stderr bytes.Buffer) error {
335252
if err == nil {
336253
return err
337254
}
338255

339-
var wtError *wasmtime.Error
340-
if errors.As(err, &wtError) {
341-
if code, ok := wtError.ExitStatus(); ok {
342-
if code == 0 {
343-
return nil
344-
}
256+
if exitErr, ok := err.(*sys.ExitError); ok {
257+
if exitErr.ExitCode() == 0 {
258+
return nil
345259
}
346260
}
261+
347262
// Print WASM stdout
348-
stderrBlob, rferr := os.ReadFile(stderrPath)
349-
if rferr == nil && len(stderrBlob) > 0 {
350-
return errors.New(string(stderrBlob))
263+
stderrBlob := stderr.String()
264+
if len(stderrBlob) > 0 {
265+
return errors.New(stderrBlob)
351266
}
352267
return fmt.Errorf("call: %w", err)
353268
}

0 commit comments

Comments
 (0)