diff --git a/internal/config/config.go b/internal/config/config.go index 8aaae365dc..1c81ff7d8b 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -81,6 +81,7 @@ type SQLGo struct { EmitInterface bool `json:"emit_interface" yaml:"emit_interface"` EmitJSONTags bool `json:"emit_json_tags" yaml:"emit_json_tags"` EmitPreparedQueries bool `json:"emit_prepared_queries" yaml:"emit_prepared_queries":` + EmitSingleFile bool `json:"emit_single_file" yaml:"emit_single_file":` Package string `json:"package" yaml:"package"` Out string `json:"out" yaml:"out"` Overrides []Override `json:"overrides,omitempty" yaml:"overrides"` diff --git a/internal/config/v_one.go b/internal/config/v_one.go index 0826914399..03b7f0cc13 100644 --- a/internal/config/v_one.go +++ b/internal/config/v_one.go @@ -24,6 +24,7 @@ type v1PackageSettings struct { EmitInterface bool `json:"emit_interface" yaml:"emit_interface"` EmitJSONTags bool `json:"emit_json_tags" yaml:"emit_json_tags"` EmitPreparedQueries bool `json:"emit_prepared_queries" yaml:"emit_prepared_queries"` + EmitSingleFile bool `json:"emit_single_file" yaml:"emit_single_file"` Overrides []Override `json:"overrides" yaml:"overrides"` } @@ -103,6 +104,7 @@ func (c *V1GenerateSettings) Translate() Config { EmitInterface: pkg.EmitInterface, EmitJSONTags: pkg.EmitJSONTags, EmitPreparedQueries: pkg.EmitPreparedQueries, + EmitSingleFile: pkg.EmitSingleFile, Package: pkg.Name, Out: pkg.Path, Overrides: pkg.Overrides, diff --git a/internal/dinosql/gen.go b/internal/dinosql/gen.go index bf50c58cf1..b6251d64bc 100644 --- a/internal/dinosql/gen.go +++ b/internal/dinosql/gen.go @@ -188,29 +188,74 @@ func UsesArrays(r Generateable, settings config.CombinedSettings) bool { return false } +type fileImports struct { + Std []string + Dep []string +} + +func mergeImports(imps ...fileImports) [][]string { + if len(imps) == 1 { + return [][]string{imps[0].Std, imps[0].Dep} + } + + var stds, pkgs []string + seenStd := map[string]struct{}{} + seenPkg := map[string]struct{}{} + for i := range imps { + for _, std := range imps[i].Std { + if _, ok := seenStd[std]; ok { + continue + } + stds = append(stds, std) + seenStd[std] = struct{}{} + } + for _, pkg := range imps[i].Dep { + if _, ok := seenPkg[pkg]; ok { + continue + } + pkgs = append(pkgs, pkg) + seenPkg[pkg] = struct{}{} + } + } + return [][]string{stds, pkgs} +} + func Imports(r Generateable, settings config.CombinedSettings) func(string) [][]string { return func(filename string) [][]string { + if filename == "all.go" { + var imps []fileImports + imps = append(imps, dbImports(r, settings)) + imps = append(imps, modelImports(r, settings)) + imps = append(imps, interfaceImports(r, settings)) + imps = append(imps, queryImports(r, settings, filename)) + return mergeImports(imps...) + } + if filename == "db.go" { - imps := []string{"context", "database/sql"} - if settings.Go.EmitPreparedQueries { - imps = append(imps, "fmt") - } - return [][]string{imps} + return mergeImports(dbImports(r, settings)) } if filename == "models.go" { - return ModelImports(r, settings) + return mergeImports(modelImports(r, settings)) } if filename == "querier.go" { - return InterfaceImports(r, settings) + return mergeImports(interfaceImports(r, settings)) } - return QueryImports(r, settings, filename) + return mergeImports(queryImports(r, settings, filename)) } } -func InterfaceImports(r Generateable, settings config.CombinedSettings) [][]string { +func dbImports(r Generateable, settings config.CombinedSettings) fileImports { + std := []string{"context", "database/sql"} + if settings.Go.EmitPreparedQueries { + std = append(std, "fmt") + } + return fileImports{Std: std} +} + +func interfaceImports(r Generateable, settings config.CombinedSettings) fileImports { gq := r.GoQueries(settings) uses := func(name string) bool { for _, q := range gq { @@ -284,10 +329,10 @@ func InterfaceImports(r Generateable, settings config.CombinedSettings) [][]stri sort.Strings(stds) sort.Strings(pkgs) - return [][]string{stds, pkgs} + return fileImports{stds, pkgs} } -func ModelImports(r Generateable, settings config.CombinedSettings) [][]string { +func modelImports(r Generateable, settings config.CombinedSettings) fileImports { std := make(map[string]struct{}) if UsesType(r, "sql.Null", settings) { std["database/sql"] = struct{}{} @@ -343,10 +388,10 @@ func ModelImports(r Generateable, settings config.CombinedSettings) [][]string { sort.Strings(stds) sort.Strings(pkgs) - return [][]string{stds, pkgs} + return fileImports{stds, pkgs} } -func QueryImports(r Generateable, settings config.CombinedSettings, filename string) [][]string { +func queryImports(r Generateable, settings config.CombinedSettings, filename string) fileImports { // for _, strct := range r.Structs() { // for _, f := range strct.Fields { // if strings.HasPrefix(f.Type, "[]") { @@ -356,7 +401,7 @@ func QueryImports(r Generateable, settings config.CombinedSettings, filename str // } var gq []GoQuery for _, query := range r.GoQueries(settings) { - if query.SourceName == filename { + if query.SourceName == filename || settings.Go.EmitSingleFile { gq = append(gq, query) } } @@ -481,7 +526,7 @@ func QueryImports(r Generateable, settings config.CombinedSettings, filename str sort.Strings(stds) sort.Strings(pkgs) - return [][]string{stds, pkgs} + return fileImports{stds, pkgs} } func enumValueName(value string) string { @@ -924,7 +969,8 @@ func (r Result) GoQueries(settings config.CombinedSettings) []GoQuery { return qs } -var dbTmpl = `// Code generated by sqlc. DO NOT EDIT. +var templateSet = ` +{{define "dbFile"}}// Code generated by sqlc. DO NOT EDIT. package {{.Package}} @@ -935,6 +981,10 @@ import ( {{end}} ) +{{template "dbCode" . }} +{{end}} + +{{define "dbCode"}} type DBTX interface { ExecContext(context.Context, string, ...interface{}) (sql.Result, error) PrepareContext(context.Context, string) (*sql.Stmt, error) @@ -1029,9 +1079,9 @@ func (q *Queries) WithTx(tx *sql.Tx) *Queries { {{- end}} } } -` +{{end}} -var ifaceTmpl = `// Code generated by sqlc. DO NOT EDIT. +{{define "interfaceFile"}}// Code generated by sqlc. DO NOT EDIT. package {{.Package}} @@ -1042,6 +1092,10 @@ import ( {{end}} ) +{{template "interfaceCode" . }} +{{end}} + +{{define "interfaceCode"}} type Querier interface { {{- range .GoQueries}} {{- if eq .Cmd ":one"}} @@ -1060,9 +1114,9 @@ type Querier interface { } var _ Querier = (*Queries)(nil) -` +{{end}} -var modelsTmpl = `// Code generated by sqlc. DO NOT EDIT. +{{define "modelsFile"}}// Code generated by sqlc. DO NOT EDIT. package {{.Package}} @@ -1073,6 +1127,10 @@ import ( {{end}} ) +{{template "modelsCode" . }} +{{end}} + +{{define "modelsCode"}} {{range .Enums}} {{if .Comment}}{{comment .Comment}}{{end}} type {{.Name}} string @@ -1099,9 +1157,9 @@ type {{.Name}} struct { {{- range .Fields}} {{- end}} } {{end}} -` +{{end}} -var sqlTmpl = `// Code generated by sqlc. DO NOT EDIT. +{{define "queryFile"}}// Code generated by sqlc. DO NOT EDIT. // source: {{.SourceName}} package {{.Package}} @@ -1113,8 +1171,12 @@ import ( {{end}} ) +{{template "queryCode" . }} +{{end}} + +{{define "queryCode"}} {{range .GoQueries}} -{{if eq .SourceName $.SourceName}} +{{if $.OutputQuery .SourceName}} const {{.ConstantName}} = {{$.Q}}-- name: {{.MethodName}} {{.Cmd}} {{.SQL}} {{$.Q}} @@ -1209,6 +1271,27 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (int64, er {{end}} {{end}} {{end}} +{{end}} + +{{define "singleFile"}}// Code generated by sqlc. DO NOT EDIT. + +package {{.Package}} + +import ( + {{range imports "all.go"}} + {{range .}}"{{.}}" + {{end}} + {{end}} +) + +{{template "modelsCode" . }} + +{{template "queryCode" . }} + +{{template "dbCode" . }} + +{{template "interfaceCode" . }} +{{end}} ` type tmplCtx struct { @@ -1225,6 +1308,11 @@ type tmplCtx struct { EmitJSONTags bool EmitPreparedQueries bool EmitInterface bool + EmitSingleFile bool +} + +func (t *tmplCtx) OutputQuery(sourceName string) bool { + return t.SourceName == sourceName || t.EmitSingleFile } func LowerTitle(s string) string { @@ -1244,10 +1332,7 @@ func Generate(r Generateable, settings config.CombinedSettings) (map[string]stri "imports": Imports(r, settings), } - dbFile := template.Must(template.New("table").Funcs(funcMap).Parse(dbTmpl)) - modelsFile := template.Must(template.New("table").Funcs(funcMap).Parse(modelsTmpl)) - sqlFile := template.Must(template.New("table").Funcs(funcMap).Parse(sqlTmpl)) - ifaceFile := template.Must(template.New("table").Funcs(funcMap).Parse(ifaceTmpl)) + tmpl := template.Must(template.New("table").Funcs(funcMap).Parse(templateSet)) golang := settings.Go tctx := tmplCtx{ @@ -1255,6 +1340,7 @@ func Generate(r Generateable, settings config.CombinedSettings) (map[string]stri EmitInterface: golang.EmitInterface, EmitJSONTags: golang.EmitJSONTags, EmitPreparedQueries: golang.EmitPreparedQueries, + EmitSingleFile: golang.EmitSingleFile, Q: "`", Package: golang.Package, GoQueries: r.GoQueries(settings), @@ -1264,11 +1350,11 @@ func Generate(r Generateable, settings config.CombinedSettings) (map[string]stri output := map[string]string{} - execute := func(name string, t *template.Template) error { + execute := func(name, templateName string) error { var b bytes.Buffer w := bufio.NewWriter(&b) tctx.SourceName = name - err := t.Execute(w, tctx) + err := tmpl.ExecuteTemplate(w, templateName, &tctx) w.Flush() if err != nil { return err @@ -1285,14 +1371,22 @@ func Generate(r Generateable, settings config.CombinedSettings) (map[string]stri return nil } - if err := execute("db.go", dbFile); err != nil { + // Output a single file with all code + if golang.EmitSingleFile { + if err := execute("db.go", "singleFile"); err != nil { + return nil, err + } + return output, nil + } + + if err := execute("db.go", "dbFile"); err != nil { return nil, err } - if err := execute("models.go", modelsFile); err != nil { + if err := execute("models.go", "modelsFile"); err != nil { return nil, err } if golang.EmitInterface { - if err := execute("querier.go", ifaceFile); err != nil { + if err := execute("querier.go", "interfaceFile"); err != nil { return nil, err } } @@ -1303,7 +1397,7 @@ func Generate(r Generateable, settings config.CombinedSettings) (map[string]stri } for source := range files { - if err := execute(source, sqlFile); err != nil { + if err := execute(source, "queryFile"); err != nil { return nil, err } } diff --git a/internal/endtoend/testdata/ondeck_single_file/db.go b/internal/endtoend/testdata/ondeck_single_file/db.go new file mode 100644 index 0000000000..dda71733e4 --- /dev/null +++ b/internal/endtoend/testdata/ondeck_single_file/db.go @@ -0,0 +1,316 @@ +// Code generated by sqlc. DO NOT EDIT. + +package ondeck + +import ( + "context" + "database/sql" + "time" +) + +type Status string + +const ( + StatusOpen Status = "open" + StatusClosed Status = "closed" +) + +func (e *Status) Scan(src interface{}) error { + *e = Status(src.([]byte)) + return nil +} + +type City struct { + Slug string + Name string +} + +type Venue struct { + ID int32 + CreateAt time.Time + Status Status + Slug string + Name string + City string + SpotifyPlaylist string + SongkickID sql.NullString +} + +const createCity = `-- name: CreateCity :one +INSERT INTO city ( + name, + slug +) VALUES ( + $1, + $2 +) RETURNING slug, name +` + +type CreateCityParams struct { + Name string + Slug string +} + +func (q *Queries) CreateCity(ctx context.Context, arg CreateCityParams) (City, error) { + row := q.db.QueryRowContext(ctx, createCity, arg.Name, arg.Slug) + var i City + err := row.Scan(&i.Slug, &i.Name) + return i, err +} + +const createVenue = `-- name: CreateVenue :one +INSERT INTO venue ( + slug, + name, + city, + created_at, + spotify_playlist, + status +) VALUES ( + $1, + $2, + $3, + NOW(), + $4, + $5 +) RETURNING id +` + +type CreateVenueParams struct { + Slug string + Name string + City string + SpotifyPlaylist string + Status Status +} + +func (q *Queries) CreateVenue(ctx context.Context, arg CreateVenueParams) (int32, error) { + row := q.db.QueryRowContext(ctx, createVenue, + arg.Slug, + arg.Name, + arg.City, + arg.SpotifyPlaylist, + arg.Status, + ) + var id int32 + err := row.Scan(&id) + return id, err +} + +const deleteVenue = `-- name: DeleteVenue :exec +DELETE FROM venue +WHERE slug = $1 AND slug = $1 +` + +func (q *Queries) DeleteVenue(ctx context.Context, slug string) error { + _, err := q.db.ExecContext(ctx, deleteVenue, slug) + return err +} + +const getCity = `-- name: GetCity :one +SELECT slug, name FROM city WHERE slug = $1 +` + +func (q *Queries) GetCity(ctx context.Context, slug string) (City, error) { + row := q.db.QueryRowContext(ctx, getCity, slug) + var i City + err := row.Scan(&i.Slug, &i.Name) + return i, err +} + +const getVenue = `-- name: GetVenue :one +SELECT id, create_at, status, slug, name, city, spotify_playlist, songkick_id +FROM venue +WHERE slug = $1 AND city = $2 +` + +type GetVenueParams struct { + Slug string + City string +} + +func (q *Queries) GetVenue(ctx context.Context, arg GetVenueParams) (Venue, error) { + row := q.db.QueryRowContext(ctx, getVenue, arg.Slug, arg.City) + var i Venue + err := row.Scan( + &i.ID, + &i.CreateAt, + &i.Status, + &i.Slug, + &i.Name, + &i.City, + &i.SpotifyPlaylist, + &i.SongkickID, + ) + return i, err +} + +const listCityByName = `-- name: ListCityByName :many +SELECT slug, name FROM city ORDER BY name +` + +func (q *Queries) ListCityByName(ctx context.Context) ([]City, error) { + rows, err := q.db.QueryContext(ctx, listCityByName) + if err != nil { + return nil, err + } + defer rows.Close() + var items []City + for rows.Next() { + var i City + if err := rows.Scan(&i.Slug, &i.Name); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listVenues = `-- name: ListVenues :many +SELECT id, create_at, status, slug, name, city, spotify_playlist, songkick_id +FROM venue +WHERE city = $1 +ORDER BY name +` + +func (q *Queries) ListVenues(ctx context.Context, city string) ([]Venue, error) { + rows, err := q.db.QueryContext(ctx, listVenues, city) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Venue + for rows.Next() { + var i Venue + if err := rows.Scan( + &i.ID, + &i.CreateAt, + &i.Status, + &i.Slug, + &i.Name, + &i.City, + &i.SpotifyPlaylist, + &i.SongkickID, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const updateCity = `-- name: UpdateCity :exec +UPDATE city SET name = $2 WHERE slug = $1 +` + +type UpdateCityParams struct { + Slug string + Name string +} + +func (q *Queries) UpdateCity(ctx context.Context, arg UpdateCityParams) error { + _, err := q.db.ExecContext(ctx, updateCity, arg.Slug, arg.Name) + return err +} + +const updateVenueName = `-- name: UpdateVenueName :one +UPDATE venue +SET name = $2 +WHERE slug = $1 +RETURNING id +` + +type UpdateVenueNameParams struct { + Slug string + Name string +} + +func (q *Queries) UpdateVenueName(ctx context.Context, arg UpdateVenueNameParams) (int32, error) { + row := q.db.QueryRowContext(ctx, updateVenueName, arg.Slug, arg.Name) + var id int32 + err := row.Scan(&id) + return id, err +} + +const venueCountByCity = `-- name: VenueCountByCity :many +SELECT city, count(*) +FROM venue +GROUP BY 1 +ORDER BY 1 +` + +type VenueCountByCityRow struct { + City string + Count int64 +} + +func (q *Queries) VenueCountByCity(ctx context.Context) ([]VenueCountByCityRow, error) { + rows, err := q.db.QueryContext(ctx, venueCountByCity) + if err != nil { + return nil, err + } + defer rows.Close() + var items []VenueCountByCityRow + for rows.Next() { + var i VenueCountByCityRow + if err := rows.Scan(&i.City, &i.Count); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} + +type Querier interface { + CreateCity(ctx context.Context, arg CreateCityParams) (City, error) + CreateVenue(ctx context.Context, arg CreateVenueParams) (int32, error) + DeleteVenue(ctx context.Context, slug string) error + GetCity(ctx context.Context, slug string) (City, error) + GetVenue(ctx context.Context, arg GetVenueParams) (Venue, error) + ListCityByName(ctx context.Context) ([]City, error) + ListVenues(ctx context.Context, city string) ([]Venue, error) + UpdateCity(ctx context.Context, arg UpdateCityParams) error + UpdateVenueName(ctx context.Context, arg UpdateVenueNameParams) (int32, error) + VenueCountByCity(ctx context.Context) ([]VenueCountByCityRow, error) +} + +var _ Querier = (*Queries)(nil) diff --git a/internal/endtoend/testdata/ondeck_single_file/query.sql b/internal/endtoend/testdata/ondeck_single_file/query.sql new file mode 100644 index 0000000000..86193a205c --- /dev/null +++ b/internal/endtoend/testdata/ondeck_single_file/query.sql @@ -0,0 +1,62 @@ +-- name: ListCityByName :many +SELECT * FROM city ORDER BY name; + +-- name: GetCity :one +SELECT * FROM city WHERE slug = $1; + +-- name: CreateCity :one +INSERT INTO city ( + name, + slug +) VALUES ( + $1, + $2 +) RETURNING *; + +-- name: UpdateCity :exec +UPDATE city SET name = $2 WHERE slug = $1; + +-- name: ListVenues :many +SELECT * +FROM venue +WHERE city = $1 +ORDER BY name; + +-- name: DeleteVenue :exec +DELETE FROM venue +WHERE slug = $1 AND slug = $1; + +-- name: GetVenue :one +SELECT * +FROM venue +WHERE slug = $1 AND city = $2; + +-- name: CreateVenue :one +INSERT INTO venue ( + slug, + name, + city, + created_at, + spotify_playlist, + status +) VALUES ( + $1, + $2, + $3, + NOW(), + $4, + $5 +) RETURNING id; + + +-- name: UpdateVenueName :one +UPDATE venue +SET name = $2 +WHERE slug = $1 +RETURNING id; + +-- name: VenueCountByCity :many +SELECT city, count(*) +FROM venue +GROUP BY 1 +ORDER BY 1; diff --git a/internal/endtoend/testdata/ondeck_single_file/schema.sql b/internal/endtoend/testdata/ondeck_single_file/schema.sql new file mode 100644 index 0000000000..62768b9e9a --- /dev/null +++ b/internal/endtoend/testdata/ondeck_single_file/schema.sql @@ -0,0 +1,17 @@ +CREATE TABLE city ( + slug text PRIMARY KEY, + name text NOT NULL +); + +CREATE TYPE status AS ENUM ('open', 'closed'); + +CREATE TABLE venue ( + id SERIAL primary key, + create_at timestamp not null, + status status not null, + slug text not null, + name varchar(255) not null, + city text not null references city(slug), + spotify_playlist varchar not null, + songkick_id text +); diff --git a/internal/endtoend/testdata/ondeck_single_file/sqlc.json b/internal/endtoend/testdata/ondeck_single_file/sqlc.json new file mode 100644 index 0000000000..4b1a4f8b44 --- /dev/null +++ b/internal/endtoend/testdata/ondeck_single_file/sqlc.json @@ -0,0 +1,12 @@ +{ + "version": "1", + "packages": [ + { + "path": ".", + "name": "ondeck", + "schema": "schema.sql", + "queries": "query.sql", + "emit_single_file": true + } + ] +}