Skip to content

feat(plugin): Calculate SHA256 if it does not exist #2935

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 2, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 45 additions & 33 deletions internal/ext/wasm/wasm.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"errors"
"fmt"
"io"
"log/slog"
"net/http"
"os"
"path/filepath"
Expand Down Expand Up @@ -47,20 +48,26 @@ func cacheDir() (string, error) {
var flight singleflight.Group

// Verify the provided sha256 is valid.
func (r *Runner) parseChecksum() (string, error) {
if r.SHA256 == "" {
return "", fmt.Errorf("missing SHA-256 checksum")
func (r *Runner) getChecksum(ctx context.Context) (string, error) {
if r.SHA256 != "" {
return r.SHA256, nil
}
return r.SHA256, nil
// TODO: Add a log line here about something
_, sum, err := r.fetch(ctx, r.URL)
if err != nil {
return "", err
}
slog.Warn("fetching WASM binary to calculate sha256. Set this value in sqlc.yaml to prevent unneeded work", "sha256", sum)
return sum, nil
}

func (r *Runner) loadModule(ctx context.Context, engine *wasmtime.Engine) (*wasmtime.Module, error) {
expected, err := r.parseChecksum()
expected, err := r.getChecksum(ctx)
if err != nil {
return nil, err
}
value, err, _ := flight.Do(expected, func() (interface{}, error) {
return r.loadSerializedModule(ctx, engine)
return r.loadSerializedModule(ctx, engine, expected)
})
if err != nil {
return nil, err
Expand All @@ -72,17 +79,13 @@ func (r *Runner) loadModule(ctx context.Context, engine *wasmtime.Engine) (*wasm
return wasmtime.NewModuleDeserialize(engine, data)
}

func (r *Runner) loadSerializedModule(ctx context.Context, engine *wasmtime.Engine) ([]byte, error) {
expected, err := r.parseChecksum()
if err != nil {
return nil, err
}
func (r *Runner) loadSerializedModule(ctx context.Context, engine *wasmtime.Engine, expectedSha string) ([]byte, error) {
cacheDir, err := cache.PluginsDir()
if err != nil {
return nil, err
}

pluginDir := filepath.Join(cacheDir, expected)
pluginDir := filepath.Join(cacheDir, expectedSha)
modName := fmt.Sprintf("plugin_%s_%s_%s.module", runtime.GOOS, runtime.GOARCH, wasmtimeVersion)
modPath := filepath.Join(pluginDir, modName)
_, staterr := os.Stat(modPath)
Expand All @@ -94,7 +97,7 @@ func (r *Runner) loadSerializedModule(ctx context.Context, engine *wasmtime.Engi
return data, nil
}

wmod, err := r.loadWASM(ctx, cacheDir, expected)
wmod, err := r.loadWASM(ctx, cacheDir, expectedSha)
if err != nil {
return nil, err
}
Expand All @@ -121,53 +124,62 @@ func (r *Runner) loadSerializedModule(ctx context.Context, engine *wasmtime.Engi
return out, nil
}

func (r *Runner) loadWASM(ctx context.Context, cache string, expected string) ([]byte, error) {
pluginDir := filepath.Join(cache, expected)
pluginPath := filepath.Join(pluginDir, "plugin.wasm")
_, staterr := os.Stat(pluginPath)

func (r *Runner) fetch(ctx context.Context, uri string) ([]byte, string, error) {
var body io.ReadCloser

switch {
case staterr == nil:
file, err := os.Open(pluginPath)
if err != nil {
return nil, fmt.Errorf("os.Open: %s %w", pluginPath, err)
}
body = file

case strings.HasPrefix(r.URL, "file://"):
file, err := os.Open(strings.TrimPrefix(r.URL, "file://"))
case strings.HasPrefix(uri, "file://"):
file, err := os.Open(strings.TrimPrefix(uri, "file://"))
if err != nil {
return nil, fmt.Errorf("os.Open: %s %w", r.URL, err)
return nil, "", fmt.Errorf("os.Open: %s %w", uri, err)
}
body = file

case strings.HasPrefix(r.URL, "https://"):
req, err := http.NewRequestWithContext(ctx, "GET", r.URL, nil)
case strings.HasPrefix(uri, "https://"):
req, err := http.NewRequestWithContext(ctx, "GET", uri, nil)
if err != nil {
return nil, fmt.Errorf("http.Get: %s %w", r.URL, err)
return nil, "", fmt.Errorf("http.Get: %s %w", uri, err)
}
req.Header.Set("User-Agent", fmt.Sprintf("sqlc/%s Go/%s (%s %s)", info.Version, runtime.Version(), runtime.GOOS, runtime.GOARCH))
resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, fmt.Errorf("http.Get: %s %w", r.URL, err)
return nil, "", fmt.Errorf("http.Get: %s %w", r.URL, err)
}
body = resp.Body

default:
return nil, fmt.Errorf("unknown scheme: %s", r.URL)
return nil, "", fmt.Errorf("unknown scheme: %s", r.URL)
}

defer body.Close()

wmod, err := io.ReadAll(body)
if err != nil {
return nil, fmt.Errorf("readall: %w", err)
return nil, "", fmt.Errorf("readall: %w", err)
}

sum := sha256.Sum256(wmod)
actual := fmt.Sprintf("%x", sum)

return wmod, actual, nil
}

func (r *Runner) loadWASM(ctx context.Context, cache string, expected string) ([]byte, error) {
pluginDir := filepath.Join(cache, expected)
pluginPath := filepath.Join(pluginDir, "plugin.wasm")
_, staterr := os.Stat(pluginPath)

uri := r.URL
if staterr == nil {
uri = "file://" + pluginPath
}

wmod, actual, err := r.fetch(ctx, uri)
if err != nil {
return nil, err
}

if expected != actual {
return nil, fmt.Errorf("invalid checksum: expected %s, got %s", expected, actual)
}
Expand Down