1
- //go:build !nowasm && cgo && ((linux && amd64) || (linux && arm64) || (darwin && amd64) || (darwin && arm64) || (windows && amd64))
2
-
3
- // The above build constraint is based of the cgo directives in this file:
4
- // https://github.com/bytecodealliance/wasmtime-go/blob/main/ffi.go
5
1
package wasm
6
2
7
3
import (
4
+ "bytes"
8
5
"context"
9
6
"crypto/sha256"
10
7
"errors"
@@ -15,10 +12,11 @@ import (
15
12
"os"
16
13
"path/filepath"
17
14
"runtime"
18
- "runtime/trace"
19
15
"strings"
20
16
21
- wasmtime "github.com/bytecodealliance/wasmtime-go/v14"
17
+ "github.com/tetratelabs/wazero"
18
+ "github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1"
19
+ "github.com/tetratelabs/wazero/sys"
22
20
"golang.org/x/sync/singleflight"
23
21
"google.golang.org/grpc"
24
22
"google.golang.org/grpc/codes"
@@ -31,31 +29,13 @@ import (
31
29
"github.com/sqlc-dev/sqlc/internal/plugin"
32
30
)
33
31
34
- func Enabled () bool {
35
- return true
36
- }
37
-
38
- // This version must be updated whenever the wasmtime-go dependency is updated
39
- const wasmtimeVersion = `v14.0.0`
32
+ var flight singleflight.Group
40
33
41
- func cacheDir () (string , error ) {
42
- cache := os .Getenv ("SQLCCACHE" )
43
- if cache != "" {
44
- return cache , nil
45
- }
46
- cacheHome := os .Getenv ("XDG_CACHE_HOME" )
47
- if cacheHome == "" {
48
- home , err := os .UserHomeDir ()
49
- if err != nil {
50
- return "" , err
51
- }
52
- cacheHome = filepath .Join (home , ".cache" )
53
- }
54
- return filepath .Join (cacheHome , "sqlc" ), nil
34
+ type runtimeAndCode struct {
35
+ rt wazero.Runtime
36
+ code wazero.CompiledModule
55
37
}
56
38
57
- var flight singleflight.Group
58
-
59
39
// Verify the provided sha256 is valid.
60
40
func (r * Runner ) getChecksum (ctx context.Context ) (string , error ) {
61
41
if r .SHA256 != "" {
@@ -70,67 +50,26 @@ func (r *Runner) getChecksum(ctx context.Context) (string, error) {
70
50
return sum , nil
71
51
}
72
52
73
- func (r * Runner ) loadModule (ctx context.Context , engine * wasmtime. Engine ) (* wasmtime. Module , error ) {
53
+ func (r * Runner ) loadAndCompile (ctx context.Context ) (* runtimeAndCode , error ) {
74
54
expected , err := r .getChecksum (ctx )
75
55
if err != nil {
76
56
return nil , err
77
57
}
78
- value , err , _ := flight .Do (expected , func () (interface {}, error ) {
79
- return r .loadSerializedModule (ctx , engine , expected )
80
- })
81
- if err != nil {
82
- return nil , err
83
- }
84
- data , ok := value .([]byte )
85
- if ! ok {
86
- return nil , fmt .Errorf ("returned value was not a byte slice" )
87
- }
88
- return wasmtime .NewModuleDeserialize (engine , data )
89
- }
90
-
91
- func (r * Runner ) loadSerializedModule (ctx context.Context , engine * wasmtime.Engine , expectedSha string ) ([]byte , error ) {
92
58
cacheDir , err := cache .PluginsDir ()
93
59
if err != nil {
94
60
return nil , err
95
61
}
96
-
97
- pluginDir := filepath .Join (cacheDir , expectedSha )
98
- modName := fmt .Sprintf ("plugin_%s_%s_%s.module" , runtime .GOOS , runtime .GOARCH , wasmtimeVersion )
99
- modPath := filepath .Join (pluginDir , modName )
100
- _ , staterr := os .Stat (modPath )
101
- if staterr == nil {
102
- data , err := os .ReadFile (modPath )
103
- if err != nil {
104
- return nil , err
105
- }
106
- return data , nil
107
- }
108
-
109
- wmod , err := r .loadWASM (ctx , cacheDir , expectedSha )
62
+ value , err , _ := flight .Do (expected , func () (interface {}, error ) {
63
+ return r .loadAndCompileWASM (ctx , cacheDir , expected )
64
+ })
110
65
if err != nil {
111
66
return nil , err
112
67
}
113
-
114
- moduRegion := trace .StartRegion (ctx , "wasmtime.NewModule" )
115
- module , err := wasmtime .NewModule (engine , wmod )
116
- moduRegion .End ()
117
- if err != nil {
118
- return nil , fmt .Errorf ("define wasi: %w" , err )
119
- }
120
-
121
- err = os .Mkdir (pluginDir , 0755 )
122
- if err != nil && ! os .IsExist (err ) {
123
- return nil , fmt .Errorf ("mkdirall: %w" , err )
124
- }
125
- out , err := module .Serialize ()
126
- if err != nil {
127
- return nil , fmt .Errorf ("serialize: %w" , err )
128
- }
129
- if err := os .WriteFile (modPath , out , 0444 ); err != nil {
130
- return nil , fmt .Errorf ("cache wasm: %w" , err )
68
+ data , ok := value .(* runtimeAndCode )
69
+ if ! ok {
70
+ return nil , fmt .Errorf ("returned value was not a compiled module" )
131
71
}
132
-
133
- return out , nil
72
+ return data , nil
134
73
}
135
74
136
75
func (r * Runner ) fetch (ctx context.Context , uri string ) ([]byte , string , error ) {
@@ -174,7 +113,7 @@ func (r *Runner) fetch(ctx context.Context, uri string) ([]byte, string, error)
174
113
return wmod , actual , nil
175
114
}
176
115
177
- func (r * Runner ) loadWASM (ctx context.Context , cache string , expected string ) ([] byte , error ) {
116
+ func (r * Runner ) loadAndCompileWASM (ctx context.Context , cache string , expected string ) (* runtimeAndCode , error ) {
178
117
pluginDir := filepath .Join (cache , expected )
179
118
pluginPath := filepath .Join (pluginDir , "plugin.wasm" )
180
119
_ , staterr := os .Stat (pluginPath )
@@ -203,7 +142,26 @@ func (r *Runner) loadWASM(ctx context.Context, cache string, expected string) ([
203
142
}
204
143
}
205
144
206
- return wmod , nil
145
+ wazeroCache , err := wazero .NewCompilationCacheWithDir (filepath .Join (cache , "wazero" ))
146
+ if err != nil {
147
+ return nil , fmt .Errorf ("wazero.NewCompilationCacheWithDir: %w" , err )
148
+ }
149
+
150
+ config := wazero .NewRuntimeConfig ().WithCompilationCache (wazeroCache )
151
+ rt := wazero .NewRuntimeWithConfig (ctx , config )
152
+
153
+ if _ , err := wasi_snapshot_preview1 .Instantiate (ctx , rt ); err != nil {
154
+ return nil , fmt .Errorf ("wasi_snapshot_preview1 instantiate: %w" , err )
155
+ }
156
+
157
+ // Compile the Wasm binary once so that we can skip the entire compilation
158
+ // time during instantiation.
159
+ code , err := rt .CompileModule (ctx , wmod )
160
+ if err != nil {
161
+ return nil , fmt .Errorf ("compile module: %w" , err )
162
+ }
163
+
164
+ return & runtimeAndCode {rt : rt , code : code }, nil
207
165
}
208
166
209
167
// removePGCatalog removes the pg_catalog schema from the request. There is a
@@ -245,75 +203,34 @@ func (r *Runner) Invoke(ctx context.Context, method string, args any, reply any,
245
203
return fmt .Errorf ("failed to encode codegen request: %w" , err )
246
204
}
247
205
248
- engine := wasmtime .NewEngine ()
249
- module , err := r .loadModule (ctx , engine )
250
- if err != nil {
251
- return fmt .Errorf ("loadModule: %w" , err )
252
- }
253
-
254
- linker := wasmtime .NewLinker (engine )
255
- if err := linker .DefineWasi (); err != nil {
256
- return err
257
- }
258
-
259
- dir , err := os .MkdirTemp (os .Getenv ("SQLCTMPDIR" ), "out" )
206
+ runtimeAndCode , err := r .loadAndCompile (ctx )
260
207
if err != nil {
261
- return fmt .Errorf ("temp dir: %w" , err )
262
- }
263
-
264
- defer os .RemoveAll (dir )
265
- stdinPath := filepath .Join (dir , "stdin" )
266
- stderrPath := filepath .Join (dir , "stderr" )
267
- stdoutPath := filepath .Join (dir , "stdout" )
268
-
269
- if err := os .WriteFile (stdinPath , stdinBlob , 0755 ); err != nil {
270
- return fmt .Errorf ("write file: %w" , err )
208
+ return fmt .Errorf ("loadBytes: %w" , err )
271
209
}
272
210
273
- // Configure WASI imports to write stdout into a file.
274
- wasiConfig := wasmtime .NewWasiConfig ()
275
- wasiConfig .SetArgv ([]string {"plugin.wasm" , method })
276
- wasiConfig .SetStdinFile (stdinPath )
277
- wasiConfig .SetStdoutFile (stdoutPath )
278
- wasiConfig .SetStderrFile (stderrPath )
211
+ var stderr , stdout bytes.Buffer
279
212
280
- keys := []string {"SQLC_VERSION" }
281
- vals := []string {info .Version }
213
+ conf := wazero .NewModuleConfig ().
214
+ WithName ("" ).
215
+ WithArgs ("plugin.wasm" , method ).
216
+ WithStdin (bytes .NewReader (stdinBlob )).
217
+ WithStdout (& stdout ).
218
+ WithStderr (& stderr ).
219
+ WithEnv ("SQLC_VERSION" , info .Version )
282
220
for _ , key := range r .Env {
283
- keys = append (keys , key )
284
- vals = append (vals , os .Getenv (key ))
285
- }
286
- wasiConfig .SetEnv (keys , vals )
287
-
288
- store := wasmtime .NewStore (engine )
289
- store .SetWasi (wasiConfig )
290
-
291
- linkRegion := trace .StartRegion (ctx , "linker.DefineModule" )
292
- err = linker .DefineModule (store , "" , module )
293
- linkRegion .End ()
294
- if err != nil {
295
- return fmt .Errorf ("define wasi: %w" , err )
221
+ conf = conf .WithEnv (key , os .Getenv (key ))
296
222
}
297
223
298
- // Run the function
299
- fn , err := linker .GetDefault (store , "" )
300
- if err != nil {
301
- return fmt .Errorf ("wasi: get default: %w" , err )
224
+ result , err := runtimeAndCode .rt .InstantiateModule (ctx , runtimeAndCode .code , conf )
225
+ if result != nil {
226
+ defer result .Close (ctx )
302
227
}
303
-
304
- callRegion := trace .StartRegion (ctx , "call _start" )
305
- _ , err = fn .Call (store )
306
- callRegion .End ()
307
-
308
- if cerr := checkError (err , stderrPath ); cerr != nil {
228
+ if cerr := checkError (err , stderr ); cerr != nil {
309
229
return cerr
310
230
}
311
231
312
232
// Print WASM stdout
313
- stdoutBlob , err := os .ReadFile (stdoutPath )
314
- if err != nil {
315
- return fmt .Errorf ("read file: %w" , err )
316
- }
233
+ stdoutBlob := stdout .Bytes ()
317
234
318
235
resp , ok := reply .(protoreflect.ProtoMessage )
319
236
if ! ok {
@@ -331,23 +248,21 @@ func (r *Runner) NewStream(ctx context.Context, desc *grpc.StreamDesc, method st
331
248
return nil , status .Error (codes .Unimplemented , "" )
332
249
}
333
250
334
- func checkError (err error , stderrPath string ) error {
251
+ func checkError (err error , stderr bytes. Buffer ) error {
335
252
if err == nil {
336
253
return err
337
254
}
338
255
339
- var wtError * wasmtime.Error
340
- if errors .As (err , & wtError ) {
341
- if code , ok := wtError .ExitStatus (); ok {
342
- if code == 0 {
343
- return nil
344
- }
256
+ if exitErr , ok := err .(* sys.ExitError ); ok {
257
+ if exitErr .ExitCode () == 0 {
258
+ return nil
345
259
}
346
260
}
261
+
347
262
// Print WASM stdout
348
- stderrBlob , rferr := os . ReadFile ( stderrPath )
349
- if rferr == nil && len (stderrBlob ) > 0 {
350
- return errors .New (string ( stderrBlob ) )
263
+ stderrBlob := stderr . String ( )
264
+ if len (stderrBlob ) > 0 {
265
+ return errors .New (stderrBlob )
351
266
}
352
267
return fmt .Errorf ("call: %w" , err )
353
268
}
0 commit comments