From 5ef48f72c19c4cd641600d5ef48aac4bab3c6b64 Mon Sep 17 00:00:00 2001 From: jordan ryan reuter Date: Fri, 25 Nov 2022 13:45:30 -0500 Subject: [PATCH 1/4] internal/cmd: fix an exit code buglet It seems the previous code intended to interpret the error from rootCmd.ExecuteContext() as an *exec.ExitError. That's not the way it worked in practice, though. Instead, the typecheck on err.(*exec.ExitError) used the error from the call to tracer.Start(). This is due to the scoping rules of if statements. With this commit, cmd.Do() now attempts to interpret the command execution error as an *exec.ExitError. It also prints the error message to stderr, since the caller in ./cmd/sqlc.main doesn't do that. See: https://go.dev/ref/spec#Blocks --- internal/cmd/cmd.go | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/internal/cmd/cmd.go b/internal/cmd/cmd.go index baa4b3d842..01c0c6ec00 100644 --- a/internal/cmd/cmd.go +++ b/internal/cmd/cmd.go @@ -52,13 +52,15 @@ func Do(args []string, stdin io.Reader, stdout io.Writer, stderr io.Writer) int } defer cleanup() - if err := rootCmd.ExecuteContext(ctx); err == nil { - return 0 - } - if exitError, ok := err.(*exec.ExitError); ok { - return exitError.ExitCode() + if err := rootCmd.ExecuteContext(ctx); err != nil { + fmt.Fprintf(stderr, "%v\n", err) + if exitError, ok := err.(*exec.ExitError); ok { + return exitError.ExitCode() + } else { + return 1 + } } - return 1 + return 0 } var version string From 45a61f57db130c414ae9f2d9cbcc347e95ea723d Mon Sep 17 00:00:00 2001 From: jordan ryan reuter Date: Fri, 25 Nov 2022 14:23:55 -0500 Subject: [PATCH 2/4] internal/debug: remove Traced global Previously, some callers of trace.StartRegion() guarded on the global debug.Traced variable. This guard is unneeded because runtime/trace.StartRegion() will return a no-op region if tracing has not been started. That means we can centralize the decision of whether we want tracing to cmd.Do(). Since we then only have one reference to debug.Traced, we can just delete it, since it seems it was a shorthand for (debug.Debug.Trace != ""). See: https://cs.opensource.google/go/go/+/refs/tags/go1.19.3:src/runtime/trace/annotation.go;l=153-155 --- internal/cmd/cmd.go | 38 +++++++++++++++----------------------- internal/cmd/diff.go | 5 +---- internal/cmd/generate.go | 31 +++++++------------------------ internal/debug/dump.go | 2 -- internal/tracer/trace.go | 13 ++++++------- 5 files changed, 29 insertions(+), 60 deletions(-) diff --git a/internal/cmd/cmd.go b/internal/cmd/cmd.go index 01c0c6ec00..6489110360 100644 --- a/internal/cmd/cmd.go +++ b/internal/cmd/cmd.go @@ -45,12 +45,16 @@ func Do(args []string, stdin io.Reader, stdout io.Writer, stderr io.Writer) int rootCmd.SetOut(stdout) rootCmd.SetErr(stderr) - ctx, cleanup, err := tracer.Start(context.Background()) - if err != nil { - fmt.Printf("failed to start trace: %v\n", err) - return 1 + ctx := context.Background() + if debug.Debug.Trace != "" { + tracectx, cleanup, err := tracer.Start(ctx) + if err != nil { + fmt.Printf("failed to start trace: %v\n", err) + return 1 + } + ctx = tracectx + defer cleanup() } - defer cleanup() if err := rootCmd.ExecuteContext(ctx); err != nil { fmt.Fprintf(stderr, "%v\n", err) @@ -69,9 +73,7 @@ var versionCmd = &cobra.Command{ Use: "version", Short: "Print the sqlc version number", Run: func(cmd *cobra.Command, args []string) { - if debug.Traced { - defer trace.StartRegion(cmd.Context(), "version").End() - } + defer trace.StartRegion(cmd.Context(), "version").End() if version == "" { fmt.Printf("%s\n", info.Version) } else { @@ -84,9 +86,7 @@ var initCmd = &cobra.Command{ Use: "init", Short: "Create an empty sqlc.yaml settings file", RunE: func(cmd *cobra.Command, args []string) error { - if debug.Traced { - defer trace.StartRegion(cmd.Context(), "init").End() - } + defer trace.StartRegion(cmd.Context(), "init").End() file := "sqlc.yaml" if f := cmd.Flag("file"); f != nil && f.Changed { file = f.Value.String() @@ -156,18 +156,14 @@ var genCmd = &cobra.Command{ Use: "generate", Short: "Generate Go code from SQL", Run: func(cmd *cobra.Command, args []string) { - if debug.Traced { - defer trace.StartRegion(cmd.Context(), "generate").End() - } + defer trace.StartRegion(cmd.Context(), "generate").End() stderr := cmd.ErrOrStderr() dir, name := getConfigPath(stderr, cmd.Flag("file")) output, err := Generate(cmd.Context(), ParseEnv(cmd), dir, name, stderr) if err != nil { os.Exit(1) } - if debug.Traced { - defer trace.StartRegion(cmd.Context(), "writefiles").End() - } + defer trace.StartRegion(cmd.Context(), "writefiles").End() for filename, source := range output { os.MkdirAll(filepath.Dir(filename), 0755) if err := os.WriteFile(filename, []byte(source), 0644); err != nil { @@ -196,9 +192,7 @@ var checkCmd = &cobra.Command{ Use: "compile", Short: "Statically check SQL for syntax and type errors", RunE: func(cmd *cobra.Command, args []string) error { - if debug.Traced { - defer trace.StartRegion(cmd.Context(), "compile").End() - } + defer trace.StartRegion(cmd.Context(), "compile").End() stderr := cmd.ErrOrStderr() dir, name := getConfigPath(stderr, cmd.Flag("file")) if _, err := Generate(cmd.Context(), ParseEnv(cmd), dir, name, stderr); err != nil { @@ -241,9 +235,7 @@ var diffCmd = &cobra.Command{ Use: "diff", Short: "Compare the generated files to the existing files", RunE: func(cmd *cobra.Command, args []string) error { - if debug.Traced { - defer trace.StartRegion(cmd.Context(), "diff").End() - } + defer trace.StartRegion(cmd.Context(), "diff").End() stderr := cmd.ErrOrStderr() dir, name := getConfigPath(stderr, cmd.Flag("file")) if err := Diff(cmd.Context(), ParseEnv(cmd), dir, name, stderr); err != nil { diff --git a/internal/cmd/diff.go b/internal/cmd/diff.go index dbab4c6ed8..aa1ddc8788 100644 --- a/internal/cmd/diff.go +++ b/internal/cmd/diff.go @@ -11,7 +11,6 @@ import ( "strings" "github.com/cubicdaiya/gonp" - "github.com/kyleconroy/sqlc/internal/debug" ) func Diff(ctx context.Context, e Env, dir, name string, stderr io.Writer) error { @@ -19,9 +18,7 @@ func Diff(ctx context.Context, e Env, dir, name string, stderr io.Writer) error if err != nil { return err } - if debug.Traced { - defer trace.StartRegion(ctx, "checkfiles").End() - } + defer trace.StartRegion(ctx, "checkfiles").End() var errored bool keys := make([]string, 0, len(output)) diff --git a/internal/cmd/generate.go b/internal/cmd/generate.go index 602974b9cc..450cb9a60e 100644 --- a/internal/cmd/generate.go +++ b/internal/cmd/generate.go @@ -193,17 +193,12 @@ func Generate(ctx context.Context, e Env, dir, filename string, stderr io.Writer name = sql.Plugin.Plugin } - var packageRegion *trace.Region - if debug.Traced { - packageRegion = trace.StartRegion(ctx, "package") - trace.Logf(ctx, "", "name=%s dir=%s plugin=%s", name, dir, lang) - } + packageRegion := trace.StartRegion(ctx, "package") + trace.Logf(ctx, "", "name=%s dir=%s plugin=%s", name, dir, lang) result, failed := parse(ctx, name, dir, sql.SQL, combo, parseOpts, stderr) if failed { - if packageRegion != nil { - packageRegion.End() - } + packageRegion.End() errored = true break } @@ -213,9 +208,7 @@ func Generate(ctx context.Context, e Env, dir, filename string, stderr io.Writer fmt.Fprintf(stderr, "# package %s\n", name) fmt.Fprintf(stderr, "error generating code: %s\n", err) errored = true - if packageRegion != nil { - packageRegion.End() - } + packageRegion.End() continue } @@ -227,9 +220,7 @@ func Generate(ctx context.Context, e Env, dir, filename string, stderr io.Writer filename := filepath.Join(dir, out, n) output[filename] = source } - if packageRegion != nil { - packageRegion.End() - } + packageRegion.End() } if errored { @@ -239,9 +230,7 @@ func Generate(ctx context.Context, e Env, dir, filename string, stderr io.Writer } func parse(ctx context.Context, name, dir string, sql config.SQL, combo config.CombinedSettings, parserOpts opts.Parser, stderr io.Writer) (*compiler.Result, bool) { - if debug.Traced { - defer trace.StartRegion(ctx, "parse").End() - } + defer trace.StartRegion(ctx, "parse").End() c := compiler.NewCompiler(sql, combo) if err := c.ParseCatalog(sql.Schema); err != nil { fmt.Fprintf(stderr, "# package %s\n", name) @@ -272,10 +261,7 @@ func parse(ctx context.Context, name, dir string, sql config.SQL, combo config.C } func codegen(ctx context.Context, combo config.CombinedSettings, sql outPair, result *compiler.Result) (string, *plugin.CodeGenResponse, error) { - var region *trace.Region - if debug.Traced { - region = trace.StartRegion(ctx, "codegen") - } + defer trace.StartRegion(ctx, "codegen").End() req := codeGenRequest(result, combo) var handler ext.Handler var out string @@ -319,8 +305,5 @@ func codegen(ctx context.Context, combo config.CombinedSettings, sql outPair, re return "", nil, fmt.Errorf("missing language backend") } resp, err := handler.Generate(ctx, req) - if region != nil { - region.End() - } return out, resp, err } diff --git a/internal/debug/dump.go b/internal/debug/dump.go index 5f95a7fe43..d2e25bab7c 100644 --- a/internal/debug/dump.go +++ b/internal/debug/dump.go @@ -9,14 +9,12 @@ import ( ) var Active bool -var Traced bool var Debug opts.Debug func init() { Active = os.Getenv("SQLCDEBUG") != "" if Active { Debug = opts.DebugFromEnv() - Traced = Debug.Trace != "" } } diff --git a/internal/tracer/trace.go b/internal/tracer/trace.go index 38bf6ace9e..d0c265f1c7 100644 --- a/internal/tracer/trace.go +++ b/internal/tracer/trace.go @@ -9,18 +9,17 @@ import ( "github.com/kyleconroy/sqlc/internal/debug" ) -func Start(base context.Context) (context.Context, func(), error) { - if !debug.Traced { - return base, func() {}, nil - } - +// Start starts Go's runtime tracing facility. +// Traces will be written to the file named by [debug.Debug.Trace]. +// It also starts a new [*trace.Task] that will be stopped when the cleanup is called. +func Start(base context.Context) (_ context.Context, cleanup func(), _ error) { f, err := os.Create(debug.Debug.Trace) if err != nil { - return base, func() {}, fmt.Errorf("failed to create trace output file: %v", err) + return base, cleanup, fmt.Errorf("failed to create trace output file: %v", err) } if err := trace.Start(f); err != nil { - return base, func() {}, fmt.Errorf("failed to start trace: %v", err) + return base, cleanup, fmt.Errorf("failed to start trace: %v", err) } ctx, task := trace.NewTask(base, "sqlc") From c9a94814e58839098993e439ff393120f0544840 Mon Sep 17 00:00:00 2001 From: jordan ryan reuter Date: Fri, 25 Nov 2022 14:48:34 -0500 Subject: [PATCH 3/4] internal/codegen: don't copy protobuf locks This patch silences the output of "go vet ./...", which used to say this: # github.com/kyleconroy/sqlc/internal/codegen/json internal/codegen/json/gen.go:24:12: return copies lock value: github.com/kyleconroy/sqlc/internal/plugin.JSONCode contains google.golang.org/protobuf/internal/impl.MessageState contains sync.Mutex internal/codegen/json/gen.go:26:11: return copies lock value: github.com/kyleconroy/sqlc/internal/plugin.JSONCode contains google.golang.org/protobuf/internal/impl.MessageState contains sync.Mutex internal/codegen/json/gen.go:30:10: return copies lock value: github.com/kyleconroy/sqlc/internal/plugin.JSONCode contains google.golang.org/protobuf/internal/impl.MessageState contains sync.Mutex # github.com/kyleconroy/sqlc/internal/codegen/golang internal/codegen/golang/imports.go:68:9: range var strct copies lock: github.com/kyleconroy/sqlc/internal/codegen/golang.Struct contains github.com/kyleconroy/sqlc/internal/plugin.Identifier contains google.golang.org/protobuf/internal/impl.MessageState contains sync.Mutex internal/codegen/golang/result.go:94:30: call of append copies lock value: github.com/kyleconroy/sqlc/internal/codegen/golang.Struct contains github.com/kyleconroy/sqlc/internal/plugin.Identifier contains google.golang.org/protobuf/internal/impl.MessageState contains sync.Mutex internal/codegen/golang/result.go:208:11: range var s copies lock: github.com/kyleconroy/sqlc/internal/codegen/golang.Struct contains github.com/kyleconroy/sqlc/internal/plugin.Identifier contains google.golang.org/protobuf/internal/impl.MessageState contains sync.Mutex Also, add go vet to CI. --- Makefile | 5 ++++- internal/codegen/golang/result.go | 4 ++-- internal/codegen/golang/struct.go | 2 +- internal/codegen/json/gen.go | 10 +++++----- 4 files changed, 12 insertions(+), 9 deletions(-) diff --git a/Makefile b/Makefile index d6111e1c6c..d393ad90b8 100644 --- a/Makefile +++ b/Makefile @@ -9,13 +9,16 @@ install: test: go test ./... +vet: + go vet ./... + test-examples: go test --tags=examples ./... build-endtoend: cd ./internal/endtoend/testdata && go build ./... -test-ci: test-examples build-endtoend +test-ci: test-examples build-endtoend vet regen: sqlc-dev sqlc-gen-json go run ./scripts/regenerate/ diff --git a/internal/codegen/golang/result.go b/internal/codegen/golang/result.go index 640132ca92..efba759adb 100644 --- a/internal/codegen/golang/result.go +++ b/internal/codegen/golang/result.go @@ -71,7 +71,7 @@ func buildStructs(req *plugin.CodeGenRequest) []Struct { }) } s := Struct{ - Table: plugin.Identifier{Schema: schema.Name, Name: table.Rel.Name}, + Table: &plugin.Identifier{Schema: schema.Name, Name: table.Rel.Name}, Name: StructName(structName, req.Settings), Comment: table.Comment, } @@ -214,7 +214,7 @@ func buildQueries(req *plugin.CodeGenRequest, structs []Struct) ([]Query, error) c := query.Columns[i] sameName := f.Name == StructName(columnName(c, i), req.Settings) sameType := f.Type == goType(req, c) - sameTable := sdk.SameTableName(c.Table, &s.Table, req.Catalog.DefaultSchema) + sameTable := sdk.SameTableName(c.Table, s.Table, req.Catalog.DefaultSchema) if !sameName || !sameType || !sameTable { same = false } diff --git a/internal/codegen/golang/struct.go b/internal/codegen/golang/struct.go index f72a228ae3..c1dfd5663d 100644 --- a/internal/codegen/golang/struct.go +++ b/internal/codegen/golang/struct.go @@ -9,7 +9,7 @@ import ( ) type Struct struct { - Table plugin.Identifier + Table *plugin.Identifier Name string Fields []Field Comment string diff --git a/internal/codegen/json/gen.go b/internal/codegen/json/gen.go index f481d009c6..75ab3941cf 100644 --- a/internal/codegen/json/gen.go +++ b/internal/codegen/json/gen.go @@ -11,13 +11,13 @@ import ( "github.com/kyleconroy/sqlc/internal/plugin" ) -func parseOptions(req *plugin.CodeGenRequest) (plugin.JSONCode, error) { +func parseOptions(req *plugin.CodeGenRequest) (*plugin.JSONCode, error) { if req.Settings == nil { - return plugin.JSONCode{}, nil + return new(plugin.JSONCode), nil } if req.Settings.Codegen != nil { if len(req.Settings.Codegen.Options) != 0 { - var options plugin.JSONCode + var options *plugin.JSONCode dec := ejson.NewDecoder(bytes.NewReader(req.Settings.Codegen.Options)) dec.DisallowUnknownFields() if err := dec.Decode(&options); err != nil { @@ -27,9 +27,9 @@ func parseOptions(req *plugin.CodeGenRequest) (plugin.JSONCode, error) { } } if req.Settings.Json != nil { - return *req.Settings.Json, nil + return req.Settings.Json, nil } - return plugin.JSONCode{}, nil + return new(plugin.JSONCode), nil } func Generate(ctx context.Context, req *plugin.CodeGenRequest) (*plugin.CodeGenResponse, error) { From e2044d5206116f83458e73a77b5ff7f1213c43be Mon Sep 17 00:00:00 2001 From: jordan ryan reuter Date: Fri, 25 Nov 2022 15:04:01 -0500 Subject: [PATCH 4/4] internal/cmd: use *cobra.Command.RunE everywhere Errors are handled by the root command's execution, so we don't need to os.Exit() in any child commands. --- internal/cmd/cmd.go | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/internal/cmd/cmd.go b/internal/cmd/cmd.go index 6489110360..204da3212d 100644 --- a/internal/cmd/cmd.go +++ b/internal/cmd/cmd.go @@ -72,13 +72,14 @@ var version string var versionCmd = &cobra.Command{ Use: "version", Short: "Print the sqlc version number", - Run: func(cmd *cobra.Command, args []string) { + RunE: func(cmd *cobra.Command, args []string) error { defer trace.StartRegion(cmd.Context(), "version").End() if version == "" { fmt.Printf("%s\n", info.Version) } else { fmt.Printf("%s\n", version) } + return nil }, } @@ -155,22 +156,23 @@ func getConfigPath(stderr io.Writer, f *pflag.Flag) (string, string) { var genCmd = &cobra.Command{ Use: "generate", Short: "Generate Go code from SQL", - Run: func(cmd *cobra.Command, args []string) { + RunE: func(cmd *cobra.Command, args []string) error { defer trace.StartRegion(cmd.Context(), "generate").End() stderr := cmd.ErrOrStderr() dir, name := getConfigPath(stderr, cmd.Flag("file")) output, err := Generate(cmd.Context(), ParseEnv(cmd), dir, name, stderr) if err != nil { - os.Exit(1) + return err } defer trace.StartRegion(cmd.Context(), "writefiles").End() for filename, source := range output { os.MkdirAll(filepath.Dir(filename), 0755) if err := os.WriteFile(filename, []byte(source), 0644); err != nil { fmt.Fprintf(stderr, "%s: %s\n", filename, err) - os.Exit(1) + return err } } + return nil }, }