Skip to content

Commit d64a68b

Browse files
authored
cmd: Generate packages in parallel (#2026)
* compiler: Speed up generate * wasm: Load serialized modules using singleflight
1 parent c4ceb0e commit d64a68b

File tree

5 files changed

+118
-70
lines changed

5 files changed

+118
-70
lines changed

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ require (
1717
github.com/pganalyze/pg_query_go/v2 v2.2.0
1818
github.com/spf13/cobra v1.6.1
1919
github.com/spf13/pflag v1.0.5
20+
golang.org/x/sync v0.1.0
2021
google.golang.org/protobuf v1.28.1
2122
gopkg.in/yaml.v3 v3.0.1
2223
)

go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,8 @@ golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96b
212212
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
213213
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
214214
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
215+
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
216+
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
215217
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
216218
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
217219
golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=

internal/cmd/cmd.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ func Do(args []string, stdin io.Reader, stdout io.Writer, stderr io.Writer) int
4444
rootCmd.SetIn(stdin)
4545
rootCmd.SetOut(stdout)
4646
rootCmd.SetErr(stderr)
47+
rootCmd.SilenceErrors = true
4748

4849
ctx := context.Background()
4950
if debug.Debug.Trace != "" {
@@ -55,9 +56,7 @@ func Do(args []string, stdin io.Reader, stdout io.Writer, stderr io.Writer) int
5556
ctx = tracectx
5657
defer cleanup()
5758
}
58-
5959
if err := rootCmd.ExecuteContext(ctx); err != nil {
60-
fmt.Fprintf(stderr, "%v\n", err)
6160
if exitError, ok := err.(*exec.ExitError); ok {
6261
return exitError.ExitCode()
6362
} else {

internal/cmd/generate.go

Lines changed: 81 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,12 @@ import (
88
"io"
99
"os"
1010
"path/filepath"
11+
"runtime"
1112
"runtime/trace"
1213
"strings"
14+
"sync"
15+
16+
"golang.org/x/sync/errgroup"
1317

1418
"github.com/kyleconroy/sqlc/internal/codegen/golang"
1519
"github.com/kyleconroy/sqlc/internal/codegen/json"
@@ -159,71 +163,94 @@ func Generate(ctx context.Context, e Env, dir, filename string, stderr io.Writer
159163
}
160164
}
161165

162-
for _, sql := range pairs {
163-
combo := config.Combine(*conf, sql.SQL)
164-
if sql.Plugin != nil {
165-
combo.Codegen = *sql.Plugin
166-
}
166+
var m sync.Mutex
167+
grp, gctx := errgroup.WithContext(ctx)
168+
grp.SetLimit(runtime.GOMAXPROCS(0))
167169

168-
// TODO: This feels like a hack that will bite us later
169-
joined := make([]string, 0, len(sql.Schema))
170-
for _, s := range sql.Schema {
171-
joined = append(joined, filepath.Join(dir, s))
172-
}
173-
sql.Schema = joined
170+
stderrs := make([]bytes.Buffer, len(pairs))
174171

175-
joined = make([]string, 0, len(sql.Queries))
176-
for _, q := range sql.Queries {
177-
joined = append(joined, filepath.Join(dir, q))
178-
}
179-
sql.Queries = joined
172+
for i, pair := range pairs {
173+
sql := pair
174+
errout := &stderrs[i]
180175

181-
var name, lang string
182-
parseOpts := opts.Parser{
183-
Debug: debug.Debug,
184-
}
176+
grp.Go(func() error {
177+
combo := config.Combine(*conf, sql.SQL)
178+
if sql.Plugin != nil {
179+
combo.Codegen = *sql.Plugin
180+
}
185181

186-
switch {
187-
case sql.Gen.Go != nil:
188-
name = combo.Go.Package
189-
lang = "golang"
182+
// TODO: This feels like a hack that will bite us later
183+
joined := make([]string, 0, len(sql.Schema))
184+
for _, s := range sql.Schema {
185+
joined = append(joined, filepath.Join(dir, s))
186+
}
187+
sql.Schema = joined
190188

191-
case sql.Plugin != nil:
192-
lang = fmt.Sprintf("process:%s", sql.Plugin.Plugin)
193-
name = sql.Plugin.Plugin
194-
}
189+
joined = make([]string, 0, len(sql.Queries))
190+
for _, q := range sql.Queries {
191+
joined = append(joined, filepath.Join(dir, q))
192+
}
193+
sql.Queries = joined
195194

196-
packageRegion := trace.StartRegion(ctx, "package")
197-
trace.Logf(ctx, "", "name=%s dir=%s plugin=%s", name, dir, lang)
195+
var name, lang string
196+
parseOpts := opts.Parser{
197+
Debug: debug.Debug,
198+
}
198199

199-
result, failed := parse(ctx, name, dir, sql.SQL, combo, parseOpts, stderr)
200-
if failed {
201-
packageRegion.End()
202-
errored = true
203-
break
204-
}
200+
switch {
201+
case sql.Gen.Go != nil:
202+
name = combo.Go.Package
203+
lang = "golang"
205204

206-
out, resp, err := codegen(ctx, combo, sql, result)
207-
if err != nil {
208-
fmt.Fprintf(stderr, "# package %s\n", name)
209-
fmt.Fprintf(stderr, "error generating code: %s\n", err)
210-
errored = true
211-
packageRegion.End()
212-
continue
213-
}
205+
case sql.Plugin != nil:
206+
lang = fmt.Sprintf("process:%s", sql.Plugin.Plugin)
207+
name = sql.Plugin.Plugin
208+
}
214209

215-
files := map[string]string{}
216-
for _, file := range resp.Files {
217-
files[file.Name] = string(file.Contents)
218-
}
219-
for n, source := range files {
220-
filename := filepath.Join(dir, out, n)
221-
output[filename] = source
222-
}
223-
packageRegion.End()
224-
}
210+
packageRegion := trace.StartRegion(gctx, "package")
211+
trace.Logf(gctx, "", "name=%s dir=%s plugin=%s", name, dir, lang)
212+
213+
result, failed := parse(gctx, name, dir, sql.SQL, combo, parseOpts, errout)
214+
if failed {
215+
packageRegion.End()
216+
errored = true
217+
return nil
218+
}
219+
220+
out, resp, err := codegen(gctx, combo, sql, result)
221+
if err != nil {
222+
fmt.Fprintf(errout, "# package %s\n", name)
223+
fmt.Fprintf(errout, "error generating code: %s\n", err)
224+
errored = true
225+
packageRegion.End()
226+
return nil
227+
}
228+
229+
files := map[string]string{}
230+
for _, file := range resp.Files {
231+
files[file.Name] = string(file.Contents)
232+
}
225233

234+
m.Lock()
235+
for n, source := range files {
236+
filename := filepath.Join(dir, out, n)
237+
output[filename] = source
238+
}
239+
m.Unlock()
240+
241+
packageRegion.End()
242+
return nil
243+
})
244+
}
245+
if err := grp.Wait(); err != nil {
246+
return nil, err
247+
}
226248
if errored {
249+
for i, _ := range stderrs {
250+
if _, err := io.Copy(stderr, &stderrs[i]); err != nil {
251+
return nil, err
252+
}
253+
}
227254
return nil, fmt.Errorf("errored")
228255
}
229256
return output, nil

internal/ext/wasm/wasm.go

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"strings"
2121

2222
wasmtime "github.com/bytecodealliance/wasmtime-go/v3"
23+
"golang.org/x/sync/singleflight"
2324

2425
"github.com/kyleconroy/sqlc/internal/info"
2526
"github.com/kyleconroy/sqlc/internal/plugin"
@@ -49,6 +50,8 @@ type Runner struct {
4950
SHA256 string
5051
}
5152

53+
var flight singleflight.Group
54+
5255
// Verify the provided sha256 is valid.
5356
func (r *Runner) parseChecksum() (string, error) {
5457
if r.SHA256 == "" {
@@ -58,6 +61,24 @@ func (r *Runner) parseChecksum() (string, error) {
5861
}
5962

6063
func (r *Runner) loadModule(ctx context.Context, engine *wasmtime.Engine) (*wasmtime.Module, error) {
64+
expected, err := r.parseChecksum()
65+
if err != nil {
66+
return nil, err
67+
}
68+
value, err, _ := flight.Do(expected, func() (interface{}, error) {
69+
return r.loadSerializedModule(ctx, engine)
70+
})
71+
if err != nil {
72+
return nil, err
73+
}
74+
data, ok := value.([]byte)
75+
if !ok {
76+
return nil, fmt.Errorf("returned value was not a byte slice")
77+
}
78+
return wasmtime.NewModuleDeserialize(engine, data)
79+
}
80+
81+
func (r *Runner) loadSerializedModule(ctx context.Context, engine *wasmtime.Engine) ([]byte, error) {
6182
expected, err := r.parseChecksum()
6283
if err != nil {
6384
return nil, err
@@ -80,7 +101,7 @@ func (r *Runner) loadModule(ctx context.Context, engine *wasmtime.Engine) (*wasm
80101
if err != nil {
81102
return nil, err
82103
}
83-
return wasmtime.NewModuleDeserialize(engine, data)
104+
return data, nil
84105
}
85106

86107
wmod, err := r.loadWASM(ctx, cache, expected)
@@ -95,21 +116,19 @@ func (r *Runner) loadModule(ctx context.Context, engine *wasmtime.Engine) (*wasm
95116
return nil, fmt.Errorf("define wasi: %w", err)
96117
}
97118

98-
if staterr != nil {
99-
err := os.Mkdir(pluginDir, 0755)
100-
if err != nil && !os.IsExist(err) {
101-
return nil, fmt.Errorf("mkdirall: %w", err)
102-
}
103-
out, err := module.Serialize()
104-
if err != nil {
105-
return nil, fmt.Errorf("serialize: %w", err)
106-
}
107-
if err := os.WriteFile(modPath, out, 0444); err != nil {
108-
return nil, fmt.Errorf("cache wasm: %w", err)
109-
}
119+
err = os.Mkdir(pluginDir, 0755)
120+
if err != nil && !os.IsExist(err) {
121+
return nil, fmt.Errorf("mkdirall: %w", err)
122+
}
123+
out, err := module.Serialize()
124+
if err != nil {
125+
return nil, fmt.Errorf("serialize: %w", err)
126+
}
127+
if err := os.WriteFile(modPath, out, 0444); err != nil {
128+
return nil, fmt.Errorf("cache wasm: %w", err)
110129
}
111130

112-
return module, nil
131+
return out, nil
113132
}
114133

115134
func (r *Runner) loadWASM(ctx context.Context, cache string, expected string) ([]byte, error) {

0 commit comments

Comments
 (0)