@@ -10,6 +10,7 @@ import (
10
10
"errors"
11
11
"fmt"
12
12
"io"
13
+ "log/slog"
13
14
"net/http"
14
15
"os"
15
16
"path/filepath"
@@ -52,20 +53,26 @@ func cacheDir() (string, error) {
52
53
var flight singleflight.Group
53
54
54
55
// Verify the provided sha256 is valid.
55
- func (r * Runner ) parseChecksum ( ) (string , error ) {
56
- if r .SHA256 = = "" {
57
- return "" , fmt . Errorf ( "missing SHA-256 checksum" )
56
+ func (r * Runner ) getChecksum ( ctx context. Context ) (string , error ) {
57
+ if r .SHA256 ! = "" {
58
+ return r . SHA256 , nil
58
59
}
59
- return r .SHA256 , nil
60
+ // TODO: Add a log line here about something
61
+ _ , sum , err := r .fetch (ctx , r .URL )
62
+ if err != nil {
63
+ return "" , err
64
+ }
65
+ slog .Warn ("fetching WASM binary to calculate sha256. Set this value in sqlc.yaml to prevent unneeded work" , "sha256" , sum )
66
+ return sum , nil
60
67
}
61
68
62
69
func (r * Runner ) loadModule (ctx context.Context , engine * wasmtime.Engine ) (* wasmtime.Module , error ) {
63
- expected , err := r .parseChecksum ( )
70
+ expected , err := r .getChecksum ( ctx )
64
71
if err != nil {
65
72
return nil , err
66
73
}
67
74
value , err , _ := flight .Do (expected , func () (interface {}, error ) {
68
- return r .loadSerializedModule (ctx , engine )
75
+ return r .loadSerializedModule (ctx , engine , expected )
69
76
})
70
77
if err != nil {
71
78
return nil , err
@@ -77,17 +84,13 @@ func (r *Runner) loadModule(ctx context.Context, engine *wasmtime.Engine) (*wasm
77
84
return wasmtime .NewModuleDeserialize (engine , data )
78
85
}
79
86
80
- func (r * Runner ) loadSerializedModule (ctx context.Context , engine * wasmtime.Engine ) ([]byte , error ) {
81
- expected , err := r .parseChecksum ()
82
- if err != nil {
83
- return nil , err
84
- }
87
+ func (r * Runner ) loadSerializedModule (ctx context.Context , engine * wasmtime.Engine , expectedSha string ) ([]byte , error ) {
85
88
cacheDir , err := cache .PluginsDir ()
86
89
if err != nil {
87
90
return nil , err
88
91
}
89
92
90
- pluginDir := filepath .Join (cacheDir , expected )
93
+ pluginDir := filepath .Join (cacheDir , expectedSha )
91
94
modName := fmt .Sprintf ("plugin_%s_%s_%s.module" , runtime .GOOS , runtime .GOARCH , wasmtimeVersion )
92
95
modPath := filepath .Join (pluginDir , modName )
93
96
_ , staterr := os .Stat (modPath )
@@ -99,7 +102,7 @@ func (r *Runner) loadSerializedModule(ctx context.Context, engine *wasmtime.Engi
99
102
return data , nil
100
103
}
101
104
102
- wmod , err := r .loadWASM (ctx , cacheDir , expected )
105
+ wmod , err := r .loadWASM (ctx , cacheDir , expectedSha )
103
106
if err != nil {
104
107
return nil , err
105
108
}
@@ -126,53 +129,62 @@ func (r *Runner) loadSerializedModule(ctx context.Context, engine *wasmtime.Engi
126
129
return out , nil
127
130
}
128
131
129
- func (r * Runner ) loadWASM (ctx context.Context , cache string , expected string ) ([]byte , error ) {
130
- pluginDir := filepath .Join (cache , expected )
131
- pluginPath := filepath .Join (pluginDir , "plugin.wasm" )
132
- _ , staterr := os .Stat (pluginPath )
133
-
132
+ func (r * Runner ) fetch (ctx context.Context , uri string ) ([]byte , string , error ) {
134
133
var body io.ReadCloser
134
+
135
135
switch {
136
- case staterr == nil :
137
- file , err := os .Open (pluginPath )
138
- if err != nil {
139
- return nil , fmt .Errorf ("os.Open: %s %w" , pluginPath , err )
140
- }
141
- body = file
142
136
143
- case strings .HasPrefix (r . URL , "file://" ):
144
- file , err := os .Open (strings .TrimPrefix (r . URL , "file://" ))
137
+ case strings .HasPrefix (uri , "file://" ):
138
+ file , err := os .Open (strings .TrimPrefix (uri , "file://" ))
145
139
if err != nil {
146
- return nil , fmt .Errorf ("os.Open: %s %w" , r . URL , err )
140
+ return nil , "" , fmt .Errorf ("os.Open: %s %w" , uri , err )
147
141
}
148
142
body = file
149
143
150
- case strings .HasPrefix (r . URL , "https://" ):
151
- req , err := http .NewRequestWithContext (ctx , "GET" , r . URL , nil )
144
+ case strings .HasPrefix (uri , "https://" ):
145
+ req , err := http .NewRequestWithContext (ctx , "GET" , uri , nil )
152
146
if err != nil {
153
- return nil , fmt .Errorf ("http.Get: %s %w" , r . URL , err )
147
+ return nil , "" , fmt .Errorf ("http.Get: %s %w" , uri , err )
154
148
}
155
149
req .Header .Set ("User-Agent" , fmt .Sprintf ("sqlc/%s Go/%s (%s %s)" , info .Version , runtime .Version (), runtime .GOOS , runtime .GOARCH ))
156
150
resp , err := http .DefaultClient .Do (req )
157
151
if err != nil {
158
- return nil , fmt .Errorf ("http.Get: %s %w" , r .URL , err )
152
+ return nil , "" , fmt .Errorf ("http.Get: %s %w" , r .URL , err )
159
153
}
160
154
body = resp .Body
161
155
162
156
default :
163
- return nil , fmt .Errorf ("unknown scheme: %s" , r .URL )
157
+ return nil , "" , fmt .Errorf ("unknown scheme: %s" , r .URL )
164
158
}
165
159
166
160
defer body .Close ()
167
161
168
162
wmod , err := io .ReadAll (body )
169
163
if err != nil {
170
- return nil , fmt .Errorf ("readall: %w" , err )
164
+ return nil , "" , fmt .Errorf ("readall: %w" , err )
171
165
}
172
166
173
167
sum := sha256 .Sum256 (wmod )
174
168
actual := fmt .Sprintf ("%x" , sum )
175
169
170
+ return wmod , actual , nil
171
+ }
172
+
173
+ func (r * Runner ) loadWASM (ctx context.Context , cache string , expected string ) ([]byte , error ) {
174
+ pluginDir := filepath .Join (cache , expected )
175
+ pluginPath := filepath .Join (pluginDir , "plugin.wasm" )
176
+ _ , staterr := os .Stat (pluginPath )
177
+
178
+ uri := r .URL
179
+ if staterr == nil {
180
+ uri = "file://" + pluginPath
181
+ }
182
+
183
+ wmod , actual , err := r .fetch (ctx , uri )
184
+ if err != nil {
185
+ return nil , err
186
+ }
187
+
176
188
if expected != actual {
177
189
return nil , fmt .Errorf ("invalid checksum: expected %s, got %s" , expected , actual )
178
190
}
0 commit comments