diff --git a/internal/ext/wasm/wasm.go b/internal/ext/wasm/wasm.go index f626dbfd14..db64152d00 100644 --- a/internal/ext/wasm/wasm.go +++ b/internal/ext/wasm/wasm.go @@ -10,6 +10,7 @@ import ( "errors" "fmt" "io" + "log/slog" "net/http" "os" "path/filepath" @@ -47,20 +48,26 @@ func cacheDir() (string, error) { var flight singleflight.Group // Verify the provided sha256 is valid. -func (r *Runner) parseChecksum() (string, error) { - if r.SHA256 == "" { - return "", fmt.Errorf("missing SHA-256 checksum") +func (r *Runner) getChecksum(ctx context.Context) (string, error) { + if r.SHA256 != "" { + return r.SHA256, nil } - return r.SHA256, nil + // TODO: Add a log line here about something + _, sum, err := r.fetch(ctx, r.URL) + if err != nil { + return "", err + } + slog.Warn("fetching WASM binary to calculate sha256. Set this value in sqlc.yaml to prevent unneeded work", "sha256", sum) + return sum, nil } func (r *Runner) loadModule(ctx context.Context, engine *wasmtime.Engine) (*wasmtime.Module, error) { - expected, err := r.parseChecksum() + expected, err := r.getChecksum(ctx) if err != nil { return nil, err } value, err, _ := flight.Do(expected, func() (interface{}, error) { - return r.loadSerializedModule(ctx, engine) + return r.loadSerializedModule(ctx, engine, expected) }) if err != nil { return nil, err @@ -72,17 +79,13 @@ func (r *Runner) loadModule(ctx context.Context, engine *wasmtime.Engine) (*wasm return wasmtime.NewModuleDeserialize(engine, data) } -func (r *Runner) loadSerializedModule(ctx context.Context, engine *wasmtime.Engine) ([]byte, error) { - expected, err := r.parseChecksum() - if err != nil { - return nil, err - } +func (r *Runner) loadSerializedModule(ctx context.Context, engine *wasmtime.Engine, expectedSha string) ([]byte, error) { cacheDir, err := cache.PluginsDir() if err != nil { return nil, err } - pluginDir := filepath.Join(cacheDir, expected) + pluginDir := filepath.Join(cacheDir, expectedSha) modName := fmt.Sprintf("plugin_%s_%s_%s.module", runtime.GOOS, runtime.GOARCH, wasmtimeVersion) modPath := filepath.Join(pluginDir, modName) _, staterr := os.Stat(modPath) @@ -94,7 +97,7 @@ func (r *Runner) loadSerializedModule(ctx context.Context, engine *wasmtime.Engi return data, nil } - wmod, err := r.loadWASM(ctx, cacheDir, expected) + wmod, err := r.loadWASM(ctx, cacheDir, expectedSha) if err != nil { return nil, err } @@ -121,53 +124,62 @@ func (r *Runner) loadSerializedModule(ctx context.Context, engine *wasmtime.Engi return out, nil } -func (r *Runner) loadWASM(ctx context.Context, cache string, expected string) ([]byte, error) { - pluginDir := filepath.Join(cache, expected) - pluginPath := filepath.Join(pluginDir, "plugin.wasm") - _, staterr := os.Stat(pluginPath) - +func (r *Runner) fetch(ctx context.Context, uri string) ([]byte, string, error) { var body io.ReadCloser + switch { - case staterr == nil: - file, err := os.Open(pluginPath) - if err != nil { - return nil, fmt.Errorf("os.Open: %s %w", pluginPath, err) - } - body = file - case strings.HasPrefix(r.URL, "file://"): - file, err := os.Open(strings.TrimPrefix(r.URL, "file://")) + case strings.HasPrefix(uri, "file://"): + file, err := os.Open(strings.TrimPrefix(uri, "file://")) if err != nil { - return nil, fmt.Errorf("os.Open: %s %w", r.URL, err) + return nil, "", fmt.Errorf("os.Open: %s %w", uri, err) } body = file - case strings.HasPrefix(r.URL, "https://"): - req, err := http.NewRequestWithContext(ctx, "GET", r.URL, nil) + case strings.HasPrefix(uri, "https://"): + req, err := http.NewRequestWithContext(ctx, "GET", uri, nil) if err != nil { - return nil, fmt.Errorf("http.Get: %s %w", r.URL, err) + return nil, "", fmt.Errorf("http.Get: %s %w", uri, err) } req.Header.Set("User-Agent", fmt.Sprintf("sqlc/%s Go/%s (%s %s)", info.Version, runtime.Version(), runtime.GOOS, runtime.GOARCH)) resp, err := http.DefaultClient.Do(req) if err != nil { - return nil, fmt.Errorf("http.Get: %s %w", r.URL, err) + return nil, "", fmt.Errorf("http.Get: %s %w", r.URL, err) } body = resp.Body default: - return nil, fmt.Errorf("unknown scheme: %s", r.URL) + return nil, "", fmt.Errorf("unknown scheme: %s", r.URL) } defer body.Close() wmod, err := io.ReadAll(body) if err != nil { - return nil, fmt.Errorf("readall: %w", err) + return nil, "", fmt.Errorf("readall: %w", err) } sum := sha256.Sum256(wmod) actual := fmt.Sprintf("%x", sum) + return wmod, actual, nil +} + +func (r *Runner) loadWASM(ctx context.Context, cache string, expected string) ([]byte, error) { + pluginDir := filepath.Join(cache, expected) + pluginPath := filepath.Join(pluginDir, "plugin.wasm") + _, staterr := os.Stat(pluginPath) + + uri := r.URL + if staterr == nil { + uri = "file://" + pluginPath + } + + wmod, actual, err := r.fetch(ctx, uri) + if err != nil { + return nil, err + } + if expected != actual { return nil, fmt.Errorf("invalid checksum: expected %s, got %s", expected, actual) }