diff --git a/internal/cmd/cmd.go b/internal/cmd/cmd.go index 1193348342..fc7d5467fd 100644 --- a/internal/cmd/cmd.go +++ b/internal/cmd/cmd.go @@ -1,17 +1,21 @@ package cmd import ( + "context" "fmt" "io" "os" "os/exec" "path/filepath" + "runtime/trace" "github.com/spf13/cobra" "github.com/spf13/pflag" yaml "gopkg.in/yaml.v3" "github.com/kyleconroy/sqlc/internal/config" + "github.com/kyleconroy/sqlc/internal/debug" + "github.com/kyleconroy/sqlc/internal/tracer" ) // Do runs the command logic. @@ -30,8 +34,14 @@ func Do(args []string, stdin io.Reader, stdout io.Writer, stderr io.Writer) int rootCmd.SetOut(stdout) rootCmd.SetErr(stderr) - err := rootCmd.Execute() - if err == nil { + ctx, cleanup, err := tracer.Start(context.Background()) + if err != nil { + fmt.Printf("failed to start trace: %v\n", err) + return 1 + } + defer cleanup() + + if err := rootCmd.ExecuteContext(ctx); err == nil { return 0 } if exitError, ok := err.(*exec.ExitError); ok { @@ -46,6 +56,9 @@ var versionCmd = &cobra.Command{ Use: "version", Short: "Print the sqlc version number", Run: func(cmd *cobra.Command, args []string) { + if debug.Traced { + defer trace.StartRegion(cmd.Context(), "version").End() + } if version == "" { // When no version is set, return the next bug fix version // after the most recent tag @@ -60,6 +73,9 @@ var initCmd = &cobra.Command{ Use: "init", Short: "Create an empty sqlc.yaml settings file", RunE: func(cmd *cobra.Command, args []string) error { + if debug.Traced { + defer trace.StartRegion(cmd.Context(), "init").End() + } file := "sqlc.yaml" if f := cmd.Flag("file"); f != nil && f.Changed { file = f.Value.String() @@ -114,12 +130,18 @@ var genCmd = &cobra.Command{ Use: "generate", Short: "Generate Go code from SQL", Run: func(cmd *cobra.Command, args []string) { + if debug.Traced { + defer trace.StartRegion(cmd.Context(), "generate").End() + } stderr := cmd.ErrOrStderr() dir, name := getConfigPath(stderr, cmd.Flag("file")) - output, err := Generate(ParseEnv(cmd), dir, name, stderr) + output, err := Generate(cmd.Context(), ParseEnv(cmd), dir, name, stderr) if err != nil { os.Exit(1) } + if debug.Traced { + defer trace.StartRegion(cmd.Context(), "writefiles").End() + } for filename, source := range output { os.MkdirAll(filepath.Dir(filename), 0755) if err := os.WriteFile(filename, []byte(source), 0644); err != nil { @@ -134,9 +156,12 @@ var checkCmd = &cobra.Command{ Use: "compile", Short: "Statically check SQL for syntax and type errors", RunE: func(cmd *cobra.Command, args []string) error { + if debug.Traced { + defer trace.StartRegion(cmd.Context(), "compile").End() + } stderr := cmd.ErrOrStderr() dir, name := getConfigPath(stderr, cmd.Flag("file")) - if _, err := Generate(ParseEnv(cmd), dir, name, stderr); err != nil { + if _, err := Generate(cmd.Context(), ParseEnv(cmd), dir, name, stderr); err != nil { os.Exit(1) } return nil diff --git a/internal/cmd/generate.go b/internal/cmd/generate.go index 1cf5e95927..0ed7a4745c 100644 --- a/internal/cmd/generate.go +++ b/internal/cmd/generate.go @@ -2,11 +2,13 @@ package cmd import ( "bytes" + "context" "errors" "fmt" "io" "os" "path/filepath" + "runtime/trace" "strings" "github.com/kyleconroy/sqlc/internal/codegen/golang" @@ -44,7 +46,7 @@ type outPair struct { config.SQL } -func Generate(e Env, dir, filename string, stderr io.Writer) (map[string]string, error) { +func Generate(ctx context.Context, e Env, dir, filename string, stderr io.Writer) (map[string]string, error) { configPath := "" if filename != "" { configPath = filepath.Join(dir, filename) @@ -97,12 +99,6 @@ func Generate(e Env, dir, filename string, stderr io.Writer) (map[string]string, return nil, err } - debug, err := opts.DebugFromEnv() - if err != nil { - fmt.Fprintf(stderr, "error parsing SQLCDEBUG: %s\n", err) - return nil, err - } - output := map[string]string{} errored := false @@ -148,27 +144,43 @@ func Generate(e Env, dir, filename string, stderr io.Writer) (map[string]string, } sql.Queries = joined - var name string + var name, lang string parseOpts := opts.Parser{ - Debug: debug, + Debug: debug.Debug, } if sql.Gen.Go != nil { name = combo.Go.Package + lang = "golang" } else if sql.Gen.Kotlin != nil { if sql.Engine == config.EnginePostgreSQL { parseOpts.UsePositionalParameters = true } + lang = "kotlin" name = combo.Kotlin.Package } else if sql.Gen.Python != nil { + lang = "python" name = combo.Python.Package } - result, failed := parse(e, name, dir, sql.SQL, combo, parseOpts, stderr) + var packageRegion *trace.Region + if debug.Traced { + packageRegion = trace.StartRegion(ctx, "package") + trace.Logf(ctx, "", "name=%s dir=%s language=%s", name, dir, lang) + } + + result, failed := parse(ctx, e, name, dir, sql.SQL, combo, parseOpts, stderr) if failed { + if packageRegion != nil { + packageRegion.End() + } errored = true break } + var region *trace.Region + if debug.Traced { + region = trace.StartRegion(ctx, "codegen") + } var files map[string]string var out string switch { @@ -184,17 +196,26 @@ func Generate(e Env, dir, filename string, stderr io.Writer) (map[string]string, default: panic("missing language backend") } + if region != nil { + region.End() + } if err != nil { fmt.Fprintf(stderr, "# package %s\n", name) fmt.Fprintf(stderr, "error generating code: %s\n", err) errored = true + if packageRegion != nil { + packageRegion.End() + } continue } for n, source := range files { filename := filepath.Join(dir, out, n) output[filename] = source } + if packageRegion != nil { + packageRegion.End() + } } if errored { @@ -203,7 +224,10 @@ func Generate(e Env, dir, filename string, stderr io.Writer) (map[string]string, return output, nil } -func parse(e Env, name, dir string, sql config.SQL, combo config.CombinedSettings, parserOpts opts.Parser, stderr io.Writer) (*compiler.Result, bool) { +func parse(ctx context.Context, e Env, name, dir string, sql config.SQL, combo config.CombinedSettings, parserOpts opts.Parser, stderr io.Writer) (*compiler.Result, bool) { + if debug.Traced { + defer trace.StartRegion(ctx, "parse").End() + } c := compiler.NewCompiler(sql, combo) if err := c.ParseCatalog(sql.Schema); err != nil { fmt.Fprintf(stderr, "# package %s\n", name) diff --git a/internal/debug/dump.go b/internal/debug/dump.go index 8eb4b83a31..5f95a7fe43 100644 --- a/internal/debug/dump.go +++ b/internal/debug/dump.go @@ -4,12 +4,20 @@ import ( "os" "github.com/davecgh/go-spew/spew" + + "github.com/kyleconroy/sqlc/internal/opts" ) var Active bool +var Traced bool +var Debug opts.Debug func init() { Active = os.Getenv("SQLCDEBUG") != "" + if Active { + Debug = opts.DebugFromEnv() + Traced = Debug.Trace != "" + } } func Dump(n ...interface{}) { diff --git a/internal/endtoend/endtoend_test.go b/internal/endtoend/endtoend_test.go index 8d3dad0058..e3756b22f3 100644 --- a/internal/endtoend/endtoend_test.go +++ b/internal/endtoend/endtoend_test.go @@ -2,6 +2,7 @@ package main import ( "bytes" + "context" "os" "path/filepath" "strings" @@ -15,6 +16,8 @@ import ( func TestExamples(t *testing.T) { t.Parallel() + ctx := context.Background() + examples, err := filepath.Abs(filepath.Join("..", "..", "examples")) if err != nil { t.Fatal(err) @@ -34,7 +37,7 @@ func TestExamples(t *testing.T) { t.Parallel() path := filepath.Join(examples, tc) var stderr bytes.Buffer - output, err := cmd.Generate(cmd.Env{ExperimentalFeatures: true}, path, "", &stderr) + output, err := cmd.Generate(ctx, cmd.Env{ExperimentalFeatures: true}, path, "", &stderr) if err != nil { t.Fatalf("sqlc generate failed: %s", stderr.String()) } @@ -44,6 +47,7 @@ func TestExamples(t *testing.T) { } func BenchmarkExamples(b *testing.B) { + ctx := context.Background() examples, err := filepath.Abs(filepath.Join("..", "..", "examples")) if err != nil { b.Fatal(err) @@ -61,7 +65,7 @@ func BenchmarkExamples(b *testing.B) { path := filepath.Join(examples, tc) for i := 0; i < b.N; i++ { var stderr bytes.Buffer - cmd.Generate(cmd.Env{ExperimentalFeatures: true}, path, "", &stderr) + cmd.Generate(ctx, cmd.Env{ExperimentalFeatures: true}, path, "", &stderr) } }) } @@ -69,6 +73,7 @@ func BenchmarkExamples(b *testing.B) { func TestReplay(t *testing.T) { t.Parallel() + ctx := context.Background() var dirs []string err := filepath.Walk("testdata", func(path string, info os.FileInfo, err error) error { if err != nil { @@ -90,7 +95,7 @@ func TestReplay(t *testing.T) { path, _ := filepath.Abs(tc) var stderr bytes.Buffer expected := expectedStderr(t, path) - output, err := cmd.Generate(cmd.Env{ExperimentalFeatures: true}, path, "", &stderr) + output, err := cmd.Generate(ctx, cmd.Env{ExperimentalFeatures: true}, path, "", &stderr) if len(expected) == 0 && err != nil { t.Fatalf("sqlc generate failed: %s", stderr.String()) } @@ -167,6 +172,7 @@ func expectedStderr(t *testing.T, dir string) string { } func BenchmarkReplay(b *testing.B) { + ctx := context.Background() var dirs []string err := filepath.Walk("testdata", func(path string, info os.FileInfo, err error) error { if err != nil { @@ -187,7 +193,7 @@ func BenchmarkReplay(b *testing.B) { path, _ := filepath.Abs(tc) for i := 0; i < b.N; i++ { var stderr bytes.Buffer - cmd.Generate(cmd.Env{ExperimentalFeatures: true}, path, "", &stderr) + cmd.Generate(ctx, cmd.Env{ExperimentalFeatures: true}, path, "", &stderr) } }) } diff --git a/internal/opts/debug.go b/internal/opts/debug.go index ac0dcc4783..7acfddd161 100644 --- a/internal/opts/debug.go +++ b/internal/opts/debug.go @@ -10,25 +10,35 @@ import ( // // dumpast: setting dumpast=1 will print the AST of every SQL statement // dumpcatalog: setting dumpcatalog=1 will print the parsed database schema +// trace: setting trace= will output a trace type Debug struct { DumpAST bool DumpCatalog bool + Trace string } -func DebugFromEnv() (Debug, error) { +func DebugFromEnv() Debug { d := Debug{} val := os.Getenv("SQLCDEBUG") if val == "" { - return d, nil + return d } for _, pair := range strings.Split(val, ",") { - switch strings.TrimSpace(pair) { - case "dumpast=1": + pair = strings.TrimSpace(pair) + switch { + case pair == "dumpast=1": d.DumpAST = true - case "dumpcatalog=1": + case pair == "dumpcatalog=1": d.DumpCatalog = true + case strings.HasPrefix(pair, "trace="): + traceName := strings.TrimPrefix(pair, "trace=") + if traceName == "1" { + d.Trace = "trace.out" + } else { + d.Trace = traceName + } } } - return d, nil + return d } diff --git a/internal/tracer/trace.go b/internal/tracer/trace.go new file mode 100644 index 0000000000..38bf6ace9e --- /dev/null +++ b/internal/tracer/trace.go @@ -0,0 +1,33 @@ +package tracer + +import ( + "context" + "fmt" + "os" + "runtime/trace" + + "github.com/kyleconroy/sqlc/internal/debug" +) + +func Start(base context.Context) (context.Context, func(), error) { + if !debug.Traced { + return base, func() {}, nil + } + + f, err := os.Create(debug.Debug.Trace) + if err != nil { + return base, func() {}, fmt.Errorf("failed to create trace output file: %v", err) + } + + if err := trace.Start(f); err != nil { + return base, func() {}, fmt.Errorf("failed to start trace: %v", err) + } + + ctx, task := trace.NewTask(base, "sqlc") + + return ctx, func() { + defer f.Close() + defer trace.Stop() + defer task.End() + }, nil +}