Skip to content

Commit 245563d

Browse files
authored
Suggestions for PR #3042 (#3082)
* Only compile wasm once per process * Remove unused * Store runtime in flightgroup as well
1 parent 0914cad commit 245563d

File tree

1 file changed

+36
-54
lines changed

1 file changed

+36
-54
lines changed

internal/ext/wasm/wasm.go

Lines changed: 36 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -29,24 +29,13 @@ import (
2929
"github.com/sqlc-dev/sqlc/internal/plugin"
3030
)
3131

32-
func cacheDir() (string, error) {
33-
cache := os.Getenv("SQLCCACHE")
34-
if cache != "" {
35-
return cache, nil
36-
}
37-
cacheHome := os.Getenv("XDG_CACHE_HOME")
38-
if cacheHome == "" {
39-
home, err := os.UserHomeDir()
40-
if err != nil {
41-
return "", err
42-
}
43-
cacheHome = filepath.Join(home, ".cache")
44-
}
45-
return filepath.Join(cacheHome, "sqlc"), nil
46-
}
47-
4832
var flight singleflight.Group
4933

34+
type runtimeAndCode struct {
35+
rt wazero.Runtime
36+
code wazero.CompiledModule
37+
}
38+
5039
// Verify the provided sha256 is valid.
5140
func (r *Runner) getChecksum(ctx context.Context) (string, error) {
5241
if r.SHA256 != "" {
@@ -61,7 +50,7 @@ func (r *Runner) getChecksum(ctx context.Context) (string, error) {
6150
return sum, nil
6251
}
6352

64-
func (r *Runner) loadBytes(ctx context.Context) ([]byte, error) {
53+
func (r *Runner) loadAndCompile(ctx context.Context) (*runtimeAndCode, error) {
6554
expected, err := r.getChecksum(ctx)
6655
if err != nil {
6756
return nil, err
@@ -71,14 +60,14 @@ func (r *Runner) loadBytes(ctx context.Context) ([]byte, error) {
7160
return nil, err
7261
}
7362
value, err, _ := flight.Do(expected, func() (interface{}, error) {
74-
return r.loadWASM(ctx, cacheDir, expected)
63+
return r.loadAndCompileWASM(ctx, cacheDir, expected)
7564
})
7665
if err != nil {
7766
return nil, err
7867
}
79-
data, ok := value.([]byte)
68+
data, ok := value.(*runtimeAndCode)
8069
if !ok {
81-
return nil, fmt.Errorf("returned value was not a byte slice")
70+
return nil, fmt.Errorf("returned value was not a compiled module")
8271
}
8372
return data, nil
8473
}
@@ -124,7 +113,7 @@ func (r *Runner) fetch(ctx context.Context, uri string) ([]byte, string, error)
124113
return wmod, actual, nil
125114
}
126115

127-
func (r *Runner) loadWASM(ctx context.Context, cache string, expected string) ([]byte, error) {
116+
func (r *Runner) loadAndCompileWASM(ctx context.Context, cache string, expected string) (*runtimeAndCode, error) {
128117
pluginDir := filepath.Join(cache, expected)
129118
pluginPath := filepath.Join(pluginDir, "plugin.wasm")
130119
_, staterr := os.Stat(pluginPath)
@@ -153,7 +142,22 @@ func (r *Runner) loadWASM(ctx context.Context, cache string, expected string) ([
153142
}
154143
}
155144

156-
return wmod, nil
145+
wazeroCache, err := wazero.NewCompilationCacheWithDir(filepath.Join(cache, "wazero"))
146+
if err != nil {
147+
return nil, fmt.Errorf("wazero.NewCompilationCacheWithDir: %w", err)
148+
}
149+
config := wazero.NewRuntimeConfig().WithCompilationCache(wazeroCache)
150+
rt := wazero.NewRuntimeWithConfig(ctx, config)
151+
// TODO: Handle error
152+
wasi_snapshot_preview1.MustInstantiate(ctx, rt)
153+
154+
// Compile the Wasm binary once so that we can skip the entire compilation time during instantiation.
155+
code, err := rt.CompileModule(ctx, wmod)
156+
if err != nil {
157+
return nil, fmt.Errorf("compile module: %w", err)
158+
}
159+
160+
return &runtimeAndCode{rt: rt, code: code}, nil
157161
}
158162

159163
// removePGCatalog removes the pg_catalog schema from the request. There is a
@@ -195,47 +199,25 @@ func (r *Runner) Invoke(ctx context.Context, method string, args any, reply any,
195199
return fmt.Errorf("failed to encode codegen request: %w", err)
196200
}
197201

198-
cacheDir, err := cache.PluginsDir()
202+
runtimeAndCode, err := r.loadAndCompile(ctx)
199203
if err != nil {
200-
return err
201-
}
202-
203-
cache, err := wazero.NewCompilationCacheWithDir(filepath.Join(cacheDir, "wazero"))
204-
if err != nil {
205-
return err
206-
}
207-
208-
wasmBytes, err := r.loadBytes(ctx)
209-
if err != nil {
210-
return fmt.Errorf("loadModule: %w", err)
211-
}
212-
213-
config := wazero.NewRuntimeConfig().WithCompilationCache(cache)
214-
rt := wazero.NewRuntimeWithConfig(ctx, config)
215-
defer rt.Close(ctx)
216-
217-
// TODO: Handle error
218-
wasi_snapshot_preview1.MustInstantiate(ctx, rt)
219-
220-
// Compile the Wasm binary once so that we can skip the entire compilation time during instantiation.
221-
mod, err := rt.CompileModule(ctx, wasmBytes)
222-
if err != nil {
223-
return err
204+
return fmt.Errorf("loadBytes: %w", err)
224205
}
225206

226207
var stderr, stdout bytes.Buffer
227208

228-
conf := wazero.NewModuleConfig()
229-
conf = conf.WithArgs("plugin.wasm", method)
230-
conf = conf.WithEnv("SQLC_VERSION", info.Version)
209+
conf := wazero.NewModuleConfig().
210+
WithName("").
211+
WithArgs("plugin.wasm", method).
212+
WithStdin(bytes.NewReader(stdinBlob)).
213+
WithStdout(&stdout).
214+
WithStderr(&stderr).
215+
WithEnv("SQLC_VERSION", info.Version)
231216
for _, key := range r.Env {
232217
conf = conf.WithEnv(key, os.Getenv(key))
233218
}
234-
conf = conf.WithStdin(bytes.NewReader(stdinBlob))
235-
conf = conf.WithStdout(&stdout)
236-
conf = conf.WithStderr(&stderr)
237219

238-
result, err := rt.InstantiateModule(ctx, mod, conf)
220+
result, err := runtimeAndCode.rt.InstantiateModule(ctx, runtimeAndCode.code, conf)
239221
if result != nil {
240222
defer result.Close(ctx)
241223
}

0 commit comments

Comments
 (0)