diff --git a/internal/cmd/cmd.go b/internal/cmd/cmd.go index 8d7a64c8dd..9a6584b522 100644 --- a/internal/cmd/cmd.go +++ b/internal/cmd/cmd.go @@ -39,7 +39,6 @@ func Do(args []string, stdin io.Reader, stdout io.Writer, stderr io.Writer) int rootCmd.PersistentFlags().BoolP("experimental", "x", false, "DEPRECATED: enable experimental features (default: false)") rootCmd.PersistentFlags().Bool("no-remote", false, "disable remote execution (default: false)") rootCmd.PersistentFlags().Bool("remote", false, "enable remote execution (default: false)") - rootCmd.PersistentFlags().Bool("no-database", false, "disable database connections (default: false)") rootCmd.AddCommand(checkCmd) rootCmd.AddCommand(createDBCmd) @@ -137,24 +136,21 @@ var initCmd = &cobra.Command{ } type Env struct { - DryRun bool - Debug opts.Debug - Remote bool - NoRemote bool - NoDatabase bool + DryRun bool + Debug opts.Debug + Remote bool + NoRemote bool } func ParseEnv(c *cobra.Command) Env { dr := c.Flag("dry-run") r := c.Flag("remote") nr := c.Flag("no-remote") - nodb := c.Flag("no-database") return Env{ - DryRun: dr != nil && dr.Changed, - Debug: opts.DebugFromEnv(), - Remote: r != nil && nr.Value.String() == "true", - NoRemote: nr != nil && nr.Value.String() == "true", - NoDatabase: nodb != nil && nodb.Value.String() == "true", + DryRun: dr != nil && dr.Changed, + Debug: opts.DebugFromEnv(), + Remote: r != nil && nr.Value.String() == "true", + NoRemote: nr != nil && nr.Value.String() == "true", } } diff --git a/internal/cmd/vet.go b/internal/cmd/vet.go index bb56c7de33..2d3e1c24b3 100644 --- a/internal/cmd/vet.go +++ b/internal/cmd/vet.go @@ -144,13 +144,13 @@ func Vet(ctx context.Context, dir, filename string, opts *Options) error { } c := checker{ - Rules: rules, - Conf: conf, - Dir: dir, - Env: env, - Envmap: map[string]string{}, - Stderr: stderr, - NoDatabase: e.NoDatabase, + Rules: rules, + Conf: conf, + Dir: dir, + Env: env, + Stderr: stderr, + OnlyManagedDB: e.Debug.OnlyManagedDatabases, + Replacer: shfmt.NewReplacer(nil), } errored := false for _, sql := range conf.SQL { @@ -379,14 +379,14 @@ type rule struct { } type checker struct { - Rules map[string]rule - Conf *config.Config - Dir string - Env *cel.Env - Envmap map[string]string - Stderr io.Writer - NoDatabase bool - Client pb.QuickClient + Rules map[string]rule + Conf *config.Config + Dir string + Env *cel.Env + Stderr io.Writer + OnlyManagedDB bool + Client pb.QuickClient + Replacer *shfmt.Replacer } func (c *checker) fetchDatabaseUri(ctx context.Context, s config.SQL) (string, func() error, error) { @@ -448,14 +448,7 @@ func (c *checker) fetchDatabaseUri(ctx context.Context, s config.SQL) (string, f } func (c *checker) DSN(dsn string) (string, error) { - // Populate the environment variable map if it is empty - if len(c.Envmap) == 0 { - for _, e := range os.Environ() { - k, v, _ := strings.Cut(e, "=") - c.Envmap[k] = v - } - } - return shfmt.Replace(dsn, c.Envmap), nil + return c.Replacer.Replace(dsn), nil } func (c *checker) checkSQL(ctx context.Context, s config.SQL) error { @@ -488,8 +481,8 @@ func (c *checker) checkSQL(ctx context.Context, s config.SQL) error { var prep preparer var expl explainer if s.Database != nil { // TODO only set up a database connection if a rule evaluation requires it - if c.NoDatabase { - return fmt.Errorf("database: connections disabled via command line flag") + if s.Database.URI != "" && c.OnlyManagedDB { + return fmt.Errorf("database: connections disabled via SQLCDEBUG=databases=managed") } dburl, cleanup, err := c.fetchDatabaseUri(ctx, s) if err != nil { diff --git a/internal/engine/postgresql/analyzer/analyze.go b/internal/engine/postgresql/analyzer/analyze.go index 9c1e77f655..7a3a53892c 100644 --- a/internal/engine/postgresql/analyzer/analyze.go +++ b/internal/engine/postgresql/analyzer/analyze.go @@ -13,26 +13,31 @@ import ( core "github.com/sqlc-dev/sqlc/internal/analysis" "github.com/sqlc-dev/sqlc/internal/config" + "github.com/sqlc-dev/sqlc/internal/opts" pb "github.com/sqlc-dev/sqlc/internal/quickdb/v1" + "github.com/sqlc-dev/sqlc/internal/shfmt" "github.com/sqlc-dev/sqlc/internal/sql/ast" "github.com/sqlc-dev/sqlc/internal/sql/named" "github.com/sqlc-dev/sqlc/internal/sql/sqlerr" ) type Analyzer struct { - db config.Database - client pb.QuickClient - pool *pgxpool.Pool - - formats sync.Map - columns sync.Map - tables sync.Map + db config.Database + client pb.QuickClient + pool *pgxpool.Pool + dbg opts.Debug + replacer *shfmt.Replacer + formats sync.Map + columns sync.Map + tables sync.Map } func New(client pb.QuickClient, db config.Database) *Analyzer { return &Analyzer{ - db: db, - client: client, + db: db, + dbg: opts.DebugFromEnv(), + client: client, + replacer: shfmt.NewReplacer(nil), } } @@ -204,8 +209,10 @@ func (a *Analyzer) Analyze(ctx context.Context, n ast.Node, query string, migrat return nil, err } uri = edb.Uri + } else if a.dbg.OnlyManagedDatabases { + return nil, fmt.Errorf("database: connections disabled via SQLCDEBUG=databases=managed") } else { - uri = a.db.URI + uri = a.replacer.Replace(a.db.URI) } conf, err := pgxpool.ParseConfig(uri) if err != nil { diff --git a/internal/opts/debug.go b/internal/opts/debug.go index 94cda87863..b92cbd4ae8 100644 --- a/internal/opts/debug.go +++ b/internal/opts/debug.go @@ -12,18 +12,20 @@ import ( // dumpcatalog: setting dumpcatalog=1 will print the parsed database schema // trace: setting trace= will output a trace // processplugins: setting processplugins=0 will disable process-based plugins +// databases: setting databases=managed will disable connections to databases via URI // dumpvetenv: setting dumpvetenv=1 will print the variables available to // a vet rule during evaluation // dumpexplain: setting dumpexplain=1 will print the JSON-formatted output // from executing EXPLAIN ... on a query during vet rule evaluation type Debug struct { - DumpAST bool - DumpCatalog bool - Trace string - ProcessPlugins bool - DumpVetEnv bool - DumpExplain bool + DumpAST bool + DumpCatalog bool + Trace string + ProcessPlugins bool + OnlyManagedDatabases bool + DumpVetEnv bool + DumpExplain bool } func DebugFromEnv() Debug { @@ -53,6 +55,8 @@ func DebugFromString(val string) Debug { } case pair == "processplugins=0": d.ProcessPlugins = false + case pair == "databases=managed": + d.OnlyManagedDatabases = true case pair == "dumpvetenv=1": d.DumpVetEnv = true case pair == "dumpexplain=1": diff --git a/internal/shfmt/shfmt.go b/internal/shfmt/shfmt.go index a3f1c5bbff..88f3074b71 100644 --- a/internal/shfmt/shfmt.go +++ b/internal/shfmt/shfmt.go @@ -1,16 +1,38 @@ package shfmt import ( + "os" "regexp" "strings" ) var pat = regexp.MustCompile(`\$\{[A-Z_]+\}`) -func Replace(f string, vars map[string]string) string { +type Replacer struct { + envmap map[string]string +} + +func (r *Replacer) Replace(f string) string { return pat.ReplaceAllStringFunc(f, func(s string) string { s = strings.TrimPrefix(s, "${") s = strings.TrimSuffix(s, "}") - return vars[s] + return r.envmap[s] }) } + +func NewReplacer(env []string) *Replacer { + r := Replacer{ + envmap: map[string]string{}, + } + if env == nil { + env = os.Environ() + } + for _, e := range env { + k, v, _ := strings.Cut(e, "=") + if k == "SQLC_AUTH_TOKEN" { + continue + } + r.envmap[k] = v + } + return &r +} diff --git a/internal/shfmt/shfmt_test.go b/internal/shfmt/shfmt_test.go index ce5c29ea5a..cf99ed220b 100644 --- a/internal/shfmt/shfmt_test.go +++ b/internal/shfmt/shfmt_test.go @@ -4,14 +4,14 @@ import "testing" func TestReplace(t *testing.T) { s := "POSTGRES_SQL://${PG_USER}:${PG_PASSWORD}@${PG_HOST}:${PG_PORT}/AUTHORS" - env := map[string]string{ - "PG_USER": "user", - "PG_PASSWORD": "password", - "PG_HOST": "host", - "PG_PORT": "port", - } + r := NewReplacer([]string{ + "PG_USER=user", + "PG_PASSWORD=password", + "PG_HOST=host", + "PG_PORT=port", + }) e := "POSTGRES_SQL://user:password@host:port/AUTHORS" - if v := Replace(s, env); v != e { + if v := r.Replace(s); v != e { t.Errorf("%s != %s", v, e) } }