Skip to content

Commit f80cee1

Browse files
authored
feat(plugin): Calculate SHA256 if it does not exist (#2935)
* feat(plugin): Calculate SHA256 if it does not exist * Add logging
1 parent 4507ede commit f80cee1

File tree

1 file changed

+45
-33
lines changed

1 file changed

+45
-33
lines changed

internal/ext/wasm/wasm.go

Lines changed: 45 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"errors"
1111
"fmt"
1212
"io"
13+
"log/slog"
1314
"net/http"
1415
"os"
1516
"path/filepath"
@@ -52,20 +53,26 @@ func cacheDir() (string, error) {
5253
var flight singleflight.Group
5354

5455
// Verify the provided sha256 is valid.
55-
func (r *Runner) parseChecksum() (string, error) {
56-
if r.SHA256 == "" {
57-
return "", fmt.Errorf("missing SHA-256 checksum")
56+
func (r *Runner) getChecksum(ctx context.Context) (string, error) {
57+
if r.SHA256 != "" {
58+
return r.SHA256, nil
5859
}
59-
return r.SHA256, nil
60+
// TODO: Add a log line here about something
61+
_, sum, err := r.fetch(ctx, r.URL)
62+
if err != nil {
63+
return "", err
64+
}
65+
slog.Warn("fetching WASM binary to calculate sha256. Set this value in sqlc.yaml to prevent unneeded work", "sha256", sum)
66+
return sum, nil
6067
}
6168

6269
func (r *Runner) loadModule(ctx context.Context, engine *wasmtime.Engine) (*wasmtime.Module, error) {
63-
expected, err := r.parseChecksum()
70+
expected, err := r.getChecksum(ctx)
6471
if err != nil {
6572
return nil, err
6673
}
6774
value, err, _ := flight.Do(expected, func() (interface{}, error) {
68-
return r.loadSerializedModule(ctx, engine)
75+
return r.loadSerializedModule(ctx, engine, expected)
6976
})
7077
if err != nil {
7178
return nil, err
@@ -77,17 +84,13 @@ func (r *Runner) loadModule(ctx context.Context, engine *wasmtime.Engine) (*wasm
7784
return wasmtime.NewModuleDeserialize(engine, data)
7885
}
7986

80-
func (r *Runner) loadSerializedModule(ctx context.Context, engine *wasmtime.Engine) ([]byte, error) {
81-
expected, err := r.parseChecksum()
82-
if err != nil {
83-
return nil, err
84-
}
87+
func (r *Runner) loadSerializedModule(ctx context.Context, engine *wasmtime.Engine, expectedSha string) ([]byte, error) {
8588
cacheDir, err := cache.PluginsDir()
8689
if err != nil {
8790
return nil, err
8891
}
8992

90-
pluginDir := filepath.Join(cacheDir, expected)
93+
pluginDir := filepath.Join(cacheDir, expectedSha)
9194
modName := fmt.Sprintf("plugin_%s_%s_%s.module", runtime.GOOS, runtime.GOARCH, wasmtimeVersion)
9295
modPath := filepath.Join(pluginDir, modName)
9396
_, staterr := os.Stat(modPath)
@@ -99,7 +102,7 @@ func (r *Runner) loadSerializedModule(ctx context.Context, engine *wasmtime.Engi
99102
return data, nil
100103
}
101104

102-
wmod, err := r.loadWASM(ctx, cacheDir, expected)
105+
wmod, err := r.loadWASM(ctx, cacheDir, expectedSha)
103106
if err != nil {
104107
return nil, err
105108
}
@@ -126,53 +129,62 @@ func (r *Runner) loadSerializedModule(ctx context.Context, engine *wasmtime.Engi
126129
return out, nil
127130
}
128131

129-
func (r *Runner) loadWASM(ctx context.Context, cache string, expected string) ([]byte, error) {
130-
pluginDir := filepath.Join(cache, expected)
131-
pluginPath := filepath.Join(pluginDir, "plugin.wasm")
132-
_, staterr := os.Stat(pluginPath)
133-
132+
func (r *Runner) fetch(ctx context.Context, uri string) ([]byte, string, error) {
134133
var body io.ReadCloser
134+
135135
switch {
136-
case staterr == nil:
137-
file, err := os.Open(pluginPath)
138-
if err != nil {
139-
return nil, fmt.Errorf("os.Open: %s %w", pluginPath, err)
140-
}
141-
body = file
142136

143-
case strings.HasPrefix(r.URL, "file://"):
144-
file, err := os.Open(strings.TrimPrefix(r.URL, "file://"))
137+
case strings.HasPrefix(uri, "file://"):
138+
file, err := os.Open(strings.TrimPrefix(uri, "file://"))
145139
if err != nil {
146-
return nil, fmt.Errorf("os.Open: %s %w", r.URL, err)
140+
return nil, "", fmt.Errorf("os.Open: %s %w", uri, err)
147141
}
148142
body = file
149143

150-
case strings.HasPrefix(r.URL, "https://"):
151-
req, err := http.NewRequestWithContext(ctx, "GET", r.URL, nil)
144+
case strings.HasPrefix(uri, "https://"):
145+
req, err := http.NewRequestWithContext(ctx, "GET", uri, nil)
152146
if err != nil {
153-
return nil, fmt.Errorf("http.Get: %s %w", r.URL, err)
147+
return nil, "", fmt.Errorf("http.Get: %s %w", uri, err)
154148
}
155149
req.Header.Set("User-Agent", fmt.Sprintf("sqlc/%s Go/%s (%s %s)", info.Version, runtime.Version(), runtime.GOOS, runtime.GOARCH))
156150
resp, err := http.DefaultClient.Do(req)
157151
if err != nil {
158-
return nil, fmt.Errorf("http.Get: %s %w", r.URL, err)
152+
return nil, "", fmt.Errorf("http.Get: %s %w", r.URL, err)
159153
}
160154
body = resp.Body
161155

162156
default:
163-
return nil, fmt.Errorf("unknown scheme: %s", r.URL)
157+
return nil, "", fmt.Errorf("unknown scheme: %s", r.URL)
164158
}
165159

166160
defer body.Close()
167161

168162
wmod, err := io.ReadAll(body)
169163
if err != nil {
170-
return nil, fmt.Errorf("readall: %w", err)
164+
return nil, "", fmt.Errorf("readall: %w", err)
171165
}
172166

173167
sum := sha256.Sum256(wmod)
174168
actual := fmt.Sprintf("%x", sum)
175169

170+
return wmod, actual, nil
171+
}
172+
173+
func (r *Runner) loadWASM(ctx context.Context, cache string, expected string) ([]byte, error) {
174+
pluginDir := filepath.Join(cache, expected)
175+
pluginPath := filepath.Join(pluginDir, "plugin.wasm")
176+
_, staterr := os.Stat(pluginPath)
177+
178+
uri := r.URL
179+
if staterr == nil {
180+
uri = "file://" + pluginPath
181+
}
182+
183+
wmod, actual, err := r.fetch(ctx, uri)
184+
if err != nil {
185+
return nil, err
186+
}
187+
176188
if expected != actual {
177189
return nil, fmt.Errorf("invalid checksum: expected %s, got %s", expected, actual)
178190
}

0 commit comments

Comments
 (0)