5
5
package wasm
6
6
7
7
import (
8
+ "bytes"
8
9
"context"
9
10
"crypto/sha256"
10
11
"errors"
@@ -15,10 +16,11 @@ import (
15
16
"os"
16
17
"path/filepath"
17
18
"runtime"
18
- "runtime/trace"
19
19
"strings"
20
20
21
- wasmtime "github.com/bytecodealliance/wasmtime-go/v14"
21
+ "github.com/tetratelabs/wazero"
22
+ "github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1"
23
+ "github.com/tetratelabs/wazero/sys"
22
24
"golang.org/x/sync/singleflight"
23
25
"google.golang.org/grpc"
24
26
"google.golang.org/grpc/codes"
@@ -70,13 +72,17 @@ func (r *Runner) getChecksum(ctx context.Context) (string, error) {
70
72
return sum , nil
71
73
}
72
74
73
- func (r * Runner ) loadModule (ctx context.Context , engine * wasmtime. Engine ) (* wasmtime. Module , error ) {
75
+ func (r * Runner ) loadBytes (ctx context.Context ) ([] byte , error ) {
74
76
expected , err := r .getChecksum (ctx )
75
77
if err != nil {
76
78
return nil , err
77
79
}
80
+ cacheDir , err := cache .PluginsDir ()
81
+ if err != nil {
82
+ return nil , err
83
+ }
78
84
value , err , _ := flight .Do (expected , func () (interface {}, error ) {
79
- return r .loadSerializedModule (ctx , engine , expected )
85
+ return r .loadWASM (ctx , cacheDir , expected )
80
86
})
81
87
if err != nil {
82
88
return nil , err
@@ -85,52 +91,7 @@ func (r *Runner) loadModule(ctx context.Context, engine *wasmtime.Engine) (*wasm
85
91
if ! ok {
86
92
return nil , fmt .Errorf ("returned value was not a byte slice" )
87
93
}
88
- return wasmtime .NewModuleDeserialize (engine , data )
89
- }
90
-
91
- func (r * Runner ) loadSerializedModule (ctx context.Context , engine * wasmtime.Engine , expectedSha string ) ([]byte , error ) {
92
- cacheDir , err := cache .PluginsDir ()
93
- if err != nil {
94
- return nil , err
95
- }
96
-
97
- pluginDir := filepath .Join (cacheDir , expectedSha )
98
- modName := fmt .Sprintf ("plugin_%s_%s_%s.module" , runtime .GOOS , runtime .GOARCH , wasmtimeVersion )
99
- modPath := filepath .Join (pluginDir , modName )
100
- _ , staterr := os .Stat (modPath )
101
- if staterr == nil {
102
- data , err := os .ReadFile (modPath )
103
- if err != nil {
104
- return nil , err
105
- }
106
- return data , nil
107
- }
108
-
109
- wmod , err := r .loadWASM (ctx , cacheDir , expectedSha )
110
- if err != nil {
111
- return nil , err
112
- }
113
-
114
- moduRegion := trace .StartRegion (ctx , "wasmtime.NewModule" )
115
- module , err := wasmtime .NewModule (engine , wmod )
116
- moduRegion .End ()
117
- if err != nil {
118
- return nil , fmt .Errorf ("define wasi: %w" , err )
119
- }
120
-
121
- err = os .Mkdir (pluginDir , 0755 )
122
- if err != nil && ! os .IsExist (err ) {
123
- return nil , fmt .Errorf ("mkdirall: %w" , err )
124
- }
125
- out , err := module .Serialize ()
126
- if err != nil {
127
- return nil , fmt .Errorf ("serialize: %w" , err )
128
- }
129
- if err := os .WriteFile (modPath , out , 0444 ); err != nil {
130
- return nil , fmt .Errorf ("cache wasm: %w" , err )
131
- }
132
-
133
- return out , nil
94
+ return data , nil
134
95
}
135
96
136
97
func (r * Runner ) fetch (ctx context.Context , uri string ) ([]byte , string , error ) {
@@ -245,72 +206,56 @@ func (r *Runner) Invoke(ctx context.Context, method string, args any, reply any,
245
206
return fmt .Errorf ("failed to encode codegen request: %w" , err )
246
207
}
247
208
248
- engine := wasmtime .NewEngine ()
249
- module , err := r .loadModule (ctx , engine )
209
+ cacheDir , err := cache .PluginsDir ()
250
210
if err != nil {
251
- return fmt . Errorf ( "loadModule: %w" , err )
211
+ return err
252
212
}
253
213
254
- linker := wasmtime . NewLinker ( engine )
255
- if err := linker . DefineWasi (); err != nil {
214
+ cache , err := wazero . NewCompilationCacheWithDir ( filepath . Join ( cacheDir , "wazero" ) )
215
+ if err != nil {
256
216
return err
257
217
}
258
218
259
- dir , err := os . MkdirTemp ( os . Getenv ( "SQLCTMPDIR" ), "out" )
219
+ wasmBytes , err := r . loadBytes ( ctx )
260
220
if err != nil {
261
- return fmt .Errorf ("temp dir : %w" , err )
221
+ return fmt .Errorf ("loadModule : %w" , err )
262
222
}
263
223
264
- defer os .RemoveAll (dir )
265
- stdinPath := filepath .Join (dir , "stdin" )
266
- stderrPath := filepath .Join (dir , "stderr" )
267
- stdoutPath := filepath .Join (dir , "stdout" )
224
+ config := wazero .NewRuntimeConfig ().WithCompilationCache (cache )
225
+ rt := wazero .NewRuntimeWithConfig (ctx , config )
226
+ defer rt .Close (ctx )
268
227
269
- if err := os .WriteFile (stdinPath , stdinBlob , 0755 ); err != nil {
270
- return fmt .Errorf ("write file: %w" , err )
271
- }
272
-
273
- // Configure WASI imports to write stdout into a file.
274
- wasiConfig := wasmtime .NewWasiConfig ()
275
- wasiConfig .SetArgv ([]string {"plugin.wasm" , method })
276
- wasiConfig .SetStdinFile (stdinPath )
277
- wasiConfig .SetStdoutFile (stdoutPath )
278
- wasiConfig .SetStderrFile (stderrPath )
228
+ // TODO: Handle error
229
+ wasi_snapshot_preview1 .MustInstantiate (ctx , rt )
279
230
280
- keys := []string {"SQLC_VERSION" }
281
- vals := []string {info .Version }
282
- for _ , key := range r .Env {
283
- keys = append (keys , key )
284
- vals = append (vals , os .Getenv (key ))
231
+ // Compile the Wasm binary once so that we can skip the entire compilation time during instantiation.
232
+ mod , err := rt .CompileModule (ctx , wasmBytes )
233
+ if err != nil {
234
+ return err
285
235
}
286
- wasiConfig .SetEnv (keys , vals )
287
236
288
- store := wasmtime .NewStore (engine )
289
- store .SetWasi (wasiConfig )
237
+ var stderr , stdout bytes.Buffer
290
238
291
- linkRegion := trace . StartRegion ( ctx , "linker.DefineModule" )
292
- err = linker . DefineModule ( store , " " , module )
293
- linkRegion . End ( )
294
- if err != nil {
295
- return fmt . Errorf ( "define wasi: %w" , err )
239
+ conf := wazero . NewModuleConfig ( )
240
+ conf = conf . WithArgs ( "plugin.wasm " , method )
241
+ conf = conf . WithEnv ( "SQLC_VERSION" , info . Version )
242
+ for _ , key := range r . Env {
243
+ conf = conf . WithEnv ( key , os . Getenv ( key ) )
296
244
}
245
+ conf = conf .WithStdin (bytes .NewReader (stdinBlob ))
246
+ conf = conf .WithStdout (& stdout )
247
+ conf = conf .WithStderr (& stderr )
297
248
298
- // Run the function
299
- fn , err := linker .GetDefault (store , "" )
300
- if err != nil {
301
- return fmt .Errorf ("wasi: get default: %w" , err )
249
+ result , err := rt .InstantiateModule (ctx , mod , conf )
250
+ if result != nil {
251
+ defer result .Close (ctx )
302
252
}
303
-
304
- callRegion := trace .StartRegion (ctx , "call _start" )
305
- _ , err = fn .Call (store )
306
- callRegion .End ()
307
-
308
- if cerr := checkError (err , stderrPath ); cerr != nil {
253
+ if cerr := checkError (err , & stderr ); cerr != nil {
309
254
return cerr
310
255
}
311
256
312
257
// Print WASM stdout
313
- stdoutBlob , err := os . ReadFile ( stdoutPath )
258
+ stdoutBlob , err := io . ReadAll ( & stdout )
314
259
if err != nil {
315
260
return fmt .Errorf ("read file: %w" , err )
316
261
}
@@ -331,21 +276,19 @@ func (r *Runner) NewStream(ctx context.Context, desc *grpc.StreamDesc, method st
331
276
return nil , status .Error (codes .Unimplemented , "" )
332
277
}
333
278
334
- func checkError (err error , stderrPath string ) error {
279
+ func checkError (err error , stderr io. Reader ) error {
335
280
if err == nil {
336
281
return err
337
282
}
338
283
339
- var wtError * wasmtime.Error
340
- if errors .As (err , & wtError ) {
341
- if code , ok := wtError .ExitStatus (); ok {
342
- if code == 0 {
343
- return nil
344
- }
284
+ if exitErr , ok := err .(* sys.ExitError ); ok {
285
+ if exitErr .ExitCode () == 0 {
286
+ return nil
345
287
}
346
288
}
289
+
347
290
// Print WASM stdout
348
- stderrBlob , rferr := os . ReadFile ( stderrPath )
291
+ stderrBlob , rferr := io . ReadAll ( stderr )
349
292
if rferr == nil && len (stderrBlob ) > 0 {
350
293
return errors .New (string (stderrBlob ))
351
294
}
0 commit comments