Skip to content

cmd: Generate packages in parallel #2026

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ require (
github.com/pganalyze/pg_query_go/v2 v2.2.0
github.com/spf13/cobra v1.6.1
github.com/spf13/pflag v1.0.5
golang.org/x/sync v0.1.0
google.golang.org/protobuf v1.28.1
gopkg.in/yaml.v3 v3.0.1
)
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,8 @@ golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96b
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
Expand Down
3 changes: 1 addition & 2 deletions internal/cmd/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ func Do(args []string, stdin io.Reader, stdout io.Writer, stderr io.Writer) int
rootCmd.SetIn(stdin)
rootCmd.SetOut(stdout)
rootCmd.SetErr(stderr)
rootCmd.SilenceErrors = true
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When we moved to RunE, we accidentally added some additional error output. Turning that off for now.


ctx := context.Background()
if debug.Debug.Trace != "" {
Expand All @@ -55,9 +56,7 @@ func Do(args []string, stdin io.Reader, stdout io.Writer, stderr io.Writer) int
ctx = tracectx
defer cleanup()
}

if err := rootCmd.ExecuteContext(ctx); err != nil {
fmt.Fprintf(stderr, "%v\n", err)
if exitError, ok := err.(*exec.ExitError); ok {
return exitError.ExitCode()
} else {
Expand Down
135 changes: 81 additions & 54 deletions internal/cmd/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,12 @@ import (
"io"
"os"
"path/filepath"
"runtime"
"runtime/trace"
"strings"
"sync"

"golang.org/x/sync/errgroup"

"github.com/kyleconroy/sqlc/internal/codegen/golang"
"github.com/kyleconroy/sqlc/internal/codegen/json"
Expand Down Expand Up @@ -159,71 +163,94 @@ func Generate(ctx context.Context, e Env, dir, filename string, stderr io.Writer
}
}

for _, sql := range pairs {
combo := config.Combine(*conf, sql.SQL)
if sql.Plugin != nil {
combo.Codegen = *sql.Plugin
}
var m sync.Mutex
grp, gctx := errgroup.WithContext(ctx)
grp.SetLimit(runtime.GOMAXPROCS(0))

// TODO: This feels like a hack that will bite us later
joined := make([]string, 0, len(sql.Schema))
for _, s := range sql.Schema {
joined = append(joined, filepath.Join(dir, s))
}
sql.Schema = joined
stderrs := make([]bytes.Buffer, len(pairs))

joined = make([]string, 0, len(sql.Queries))
for _, q := range sql.Queries {
joined = append(joined, filepath.Join(dir, q))
}
sql.Queries = joined
for i, pair := range pairs {
sql := pair
errout := &stderrs[i]

var name, lang string
parseOpts := opts.Parser{
Debug: debug.Debug,
}
grp.Go(func() error {
combo := config.Combine(*conf, sql.SQL)
if sql.Plugin != nil {
combo.Codegen = *sql.Plugin
}

switch {
case sql.Gen.Go != nil:
name = combo.Go.Package
lang = "golang"
// TODO: This feels like a hack that will bite us later
joined := make([]string, 0, len(sql.Schema))
for _, s := range sql.Schema {
joined = append(joined, filepath.Join(dir, s))
}
sql.Schema = joined

case sql.Plugin != nil:
lang = fmt.Sprintf("process:%s", sql.Plugin.Plugin)
name = sql.Plugin.Plugin
}
joined = make([]string, 0, len(sql.Queries))
for _, q := range sql.Queries {
joined = append(joined, filepath.Join(dir, q))
}
sql.Queries = joined

packageRegion := trace.StartRegion(ctx, "package")
trace.Logf(ctx, "", "name=%s dir=%s plugin=%s", name, dir, lang)
var name, lang string
parseOpts := opts.Parser{
Debug: debug.Debug,
}

result, failed := parse(ctx, name, dir, sql.SQL, combo, parseOpts, stderr)
if failed {
packageRegion.End()
errored = true
break
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If parsing failed, we'd return early. Instead, return errors from every package.

}
switch {
case sql.Gen.Go != nil:
name = combo.Go.Package
lang = "golang"

out, resp, err := codegen(ctx, combo, sql, result)
if err != nil {
fmt.Fprintf(stderr, "# package %s\n", name)
fmt.Fprintf(stderr, "error generating code: %s\n", err)
errored = true
packageRegion.End()
continue
}
case sql.Plugin != nil:
lang = fmt.Sprintf("process:%s", sql.Plugin.Plugin)
name = sql.Plugin.Plugin
}

files := map[string]string{}
for _, file := range resp.Files {
files[file.Name] = string(file.Contents)
}
for n, source := range files {
filename := filepath.Join(dir, out, n)
output[filename] = source
}
packageRegion.End()
}
packageRegion := trace.StartRegion(gctx, "package")
trace.Logf(gctx, "", "name=%s dir=%s plugin=%s", name, dir, lang)

result, failed := parse(gctx, name, dir, sql.SQL, combo, parseOpts, errout)
if failed {
packageRegion.End()
errored = true
return nil
}

out, resp, err := codegen(gctx, combo, sql, result)
if err != nil {
fmt.Fprintf(errout, "# package %s\n", name)
fmt.Fprintf(errout, "error generating code: %s\n", err)
errored = true
packageRegion.End()
return nil
}

files := map[string]string{}
for _, file := range resp.Files {
files[file.Name] = string(file.Contents)
}

m.Lock()
for n, source := range files {
filename := filepath.Join(dir, out, n)
output[filename] = source
}
m.Unlock()
Comment on lines +234 to +239
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Writing to a map isn't threadsafe, so put this in mutex. In the future might just use sync.Map instead.


packageRegion.End()
return nil
})
}
if err := grp.Wait(); err != nil {
return nil, err
}
if errored {
for i, _ := range stderrs {
if _, err := io.Copy(stderr, &stderrs[i]); err != nil {
return nil, err
}
}
return nil, fmt.Errorf("errored")
}
return output, nil
Expand Down
47 changes: 33 additions & 14 deletions internal/ext/wasm/wasm.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"strings"

wasmtime "github.com/bytecodealliance/wasmtime-go/v3"
"golang.org/x/sync/singleflight"

"github.com/kyleconroy/sqlc/internal/info"
"github.com/kyleconroy/sqlc/internal/plugin"
Expand Down Expand Up @@ -49,6 +50,8 @@ type Runner struct {
SHA256 string
}

var flight singleflight.Group

// Verify the provided sha256 is valid.
func (r *Runner) parseChecksum() (string, error) {
if r.SHA256 == "" {
Expand All @@ -58,6 +61,24 @@ func (r *Runner) parseChecksum() (string, error) {
}

func (r *Runner) loadModule(ctx context.Context, engine *wasmtime.Engine) (*wasmtime.Module, error) {
expected, err := r.parseChecksum()
if err != nil {
return nil, err
}
value, err, _ := flight.Do(expected, func() (interface{}, error) {
return r.loadSerializedModule(ctx, engine)
})
if err != nil {
return nil, err
}
data, ok := value.([]byte)
if !ok {
return nil, fmt.Errorf("returned value was not a byte slice")
}
return wasmtime.NewModuleDeserialize(engine, data)
}

func (r *Runner) loadSerializedModule(ctx context.Context, engine *wasmtime.Engine) ([]byte, error) {
expected, err := r.parseChecksum()
if err != nil {
return nil, err
Expand All @@ -80,7 +101,7 @@ func (r *Runner) loadModule(ctx context.Context, engine *wasmtime.Engine) (*wasm
if err != nil {
return nil, err
}
return wasmtime.NewModuleDeserialize(engine, data)
return data, nil
}

wmod, err := r.loadWASM(ctx, cache, expected)
Expand All @@ -95,21 +116,19 @@ func (r *Runner) loadModule(ctx context.Context, engine *wasmtime.Engine) (*wasm
return nil, fmt.Errorf("define wasi: %w", err)
}

if staterr != nil {
err := os.Mkdir(pluginDir, 0755)
if err != nil && !os.IsExist(err) {
return nil, fmt.Errorf("mkdirall: %w", err)
}
out, err := module.Serialize()
if err != nil {
return nil, fmt.Errorf("serialize: %w", err)
}
if err := os.WriteFile(modPath, out, 0444); err != nil {
return nil, fmt.Errorf("cache wasm: %w", err)
}
err = os.Mkdir(pluginDir, 0755)
if err != nil && !os.IsExist(err) {
return nil, fmt.Errorf("mkdirall: %w", err)
}
out, err := module.Serialize()
if err != nil {
return nil, fmt.Errorf("serialize: %w", err)
}
if err := os.WriteFile(modPath, out, 0444); err != nil {
return nil, fmt.Errorf("cache wasm: %w", err)
}

return module, nil
return out, nil
}

func (r *Runner) loadWASM(ctx context.Context, cache string, expected string) ([]byte, error) {
Expand Down