Skip to content

gen: Add option to emit single file for Go #366

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 27, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ type SQLGo struct {
EmitInterface bool `json:"emit_interface" yaml:"emit_interface"`
EmitJSONTags bool `json:"emit_json_tags" yaml:"emit_json_tags"`
EmitPreparedQueries bool `json:"emit_prepared_queries" yaml:"emit_prepared_queries":`
EmitSingleFile bool `json:"emit_single_file" yaml:"emit_single_file":`
Package string `json:"package" yaml:"package"`
Out string `json:"out" yaml:"out"`
Overrides []Override `json:"overrides,omitempty" yaml:"overrides"`
Expand Down
2 changes: 2 additions & 0 deletions internal/config/v_one.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ type v1PackageSettings struct {
EmitInterface bool `json:"emit_interface" yaml:"emit_interface"`
EmitJSONTags bool `json:"emit_json_tags" yaml:"emit_json_tags"`
EmitPreparedQueries bool `json:"emit_prepared_queries" yaml:"emit_prepared_queries"`
EmitSingleFile bool `json:"emit_single_file" yaml:"emit_single_file"`
Overrides []Override `json:"overrides" yaml:"overrides"`
}

Expand Down Expand Up @@ -103,6 +104,7 @@ func (c *V1GenerateSettings) Translate() Config {
EmitInterface: pkg.EmitInterface,
EmitJSONTags: pkg.EmitJSONTags,
EmitPreparedQueries: pkg.EmitPreparedQueries,
EmitSingleFile: pkg.EmitSingleFile,
Package: pkg.Name,
Out: pkg.Path,
Overrides: pkg.Overrides,
Expand Down
160 changes: 127 additions & 33 deletions internal/dinosql/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,29 +188,74 @@ func UsesArrays(r Generateable, settings config.CombinedSettings) bool {
return false
}

type fileImports struct {
Std []string
Dep []string
}

func mergeImports(imps ...fileImports) [][]string {
if len(imps) == 1 {
return [][]string{imps[0].Std, imps[0].Dep}
}

var stds, pkgs []string
seenStd := map[string]struct{}{}
seenPkg := map[string]struct{}{}
for i := range imps {
for _, std := range imps[i].Std {
if _, ok := seenStd[std]; ok {
continue
}
stds = append(stds, std)
seenStd[std] = struct{}{}
}
for _, pkg := range imps[i].Dep {
if _, ok := seenPkg[pkg]; ok {
continue
}
pkgs = append(pkgs, pkg)
seenPkg[pkg] = struct{}{}
}
}
return [][]string{stds, pkgs}
}

func Imports(r Generateable, settings config.CombinedSettings) func(string) [][]string {
return func(filename string) [][]string {
if filename == "all.go" {
var imps []fileImports
imps = append(imps, dbImports(r, settings))
imps = append(imps, modelImports(r, settings))
imps = append(imps, interfaceImports(r, settings))
imps = append(imps, queryImports(r, settings, filename))
return mergeImports(imps...)
}

if filename == "db.go" {
imps := []string{"context", "database/sql"}
if settings.Go.EmitPreparedQueries {
imps = append(imps, "fmt")
}
return [][]string{imps}
return mergeImports(dbImports(r, settings))
}

if filename == "models.go" {
return ModelImports(r, settings)
return mergeImports(modelImports(r, settings))
}

if filename == "querier.go" {
return InterfaceImports(r, settings)
return mergeImports(interfaceImports(r, settings))
}

return QueryImports(r, settings, filename)
return mergeImports(queryImports(r, settings, filename))
}
}

func InterfaceImports(r Generateable, settings config.CombinedSettings) [][]string {
func dbImports(r Generateable, settings config.CombinedSettings) fileImports {
std := []string{"context", "database/sql"}
if settings.Go.EmitPreparedQueries {
std = append(std, "fmt")
}
return fileImports{Std: std}
}

func interfaceImports(r Generateable, settings config.CombinedSettings) fileImports {
gq := r.GoQueries(settings)
uses := func(name string) bool {
for _, q := range gq {
Expand Down Expand Up @@ -284,10 +329,10 @@ func InterfaceImports(r Generateable, settings config.CombinedSettings) [][]stri

sort.Strings(stds)
sort.Strings(pkgs)
return [][]string{stds, pkgs}
return fileImports{stds, pkgs}
}

func ModelImports(r Generateable, settings config.CombinedSettings) [][]string {
func modelImports(r Generateable, settings config.CombinedSettings) fileImports {
std := make(map[string]struct{})
if UsesType(r, "sql.Null", settings) {
std["database/sql"] = struct{}{}
Expand Down Expand Up @@ -343,10 +388,10 @@ func ModelImports(r Generateable, settings config.CombinedSettings) [][]string {

sort.Strings(stds)
sort.Strings(pkgs)
return [][]string{stds, pkgs}
return fileImports{stds, pkgs}
}

func QueryImports(r Generateable, settings config.CombinedSettings, filename string) [][]string {
func queryImports(r Generateable, settings config.CombinedSettings, filename string) fileImports {
// for _, strct := range r.Structs() {
// for _, f := range strct.Fields {
// if strings.HasPrefix(f.Type, "[]") {
Expand All @@ -356,7 +401,7 @@ func QueryImports(r Generateable, settings config.CombinedSettings, filename str
// }
var gq []GoQuery
for _, query := range r.GoQueries(settings) {
if query.SourceName == filename {
if query.SourceName == filename || settings.Go.EmitSingleFile {
gq = append(gq, query)
}
}
Expand Down Expand Up @@ -481,7 +526,7 @@ func QueryImports(r Generateable, settings config.CombinedSettings, filename str

sort.Strings(stds)
sort.Strings(pkgs)
return [][]string{stds, pkgs}
return fileImports{stds, pkgs}
}

func enumValueName(value string) string {
Expand Down Expand Up @@ -924,7 +969,8 @@ func (r Result) GoQueries(settings config.CombinedSettings) []GoQuery {
return qs
}

var dbTmpl = `// Code generated by sqlc. DO NOT EDIT.
var templateSet = `
{{define "dbFile"}}// Code generated by sqlc. DO NOT EDIT.

package {{.Package}}

Expand All @@ -935,6 +981,10 @@ import (
{{end}}
)

{{template "dbCode" . }}
{{end}}

{{define "dbCode"}}
type DBTX interface {
ExecContext(context.Context, string, ...interface{}) (sql.Result, error)
PrepareContext(context.Context, string) (*sql.Stmt, error)
Expand Down Expand Up @@ -1029,9 +1079,9 @@ func (q *Queries) WithTx(tx *sql.Tx) *Queries {
{{- end}}
}
}
`
{{end}}

var ifaceTmpl = `// Code generated by sqlc. DO NOT EDIT.
{{define "interfaceFile"}}// Code generated by sqlc. DO NOT EDIT.

package {{.Package}}

Expand All @@ -1042,6 +1092,10 @@ import (
{{end}}
)

{{template "interfaceCode" . }}
{{end}}

{{define "interfaceCode"}}
type Querier interface {
{{- range .GoQueries}}
{{- if eq .Cmd ":one"}}
Expand All @@ -1060,9 +1114,9 @@ type Querier interface {
}

var _ Querier = (*Queries)(nil)
`
{{end}}

var modelsTmpl = `// Code generated by sqlc. DO NOT EDIT.
{{define "modelsFile"}}// Code generated by sqlc. DO NOT EDIT.

package {{.Package}}

Expand All @@ -1073,6 +1127,10 @@ import (
{{end}}
)

{{template "modelsCode" . }}
{{end}}

{{define "modelsCode"}}
{{range .Enums}}
{{if .Comment}}{{comment .Comment}}{{end}}
type {{.Name}} string
Expand All @@ -1099,9 +1157,9 @@ type {{.Name}} struct { {{- range .Fields}}
{{- end}}
}
{{end}}
`
{{end}}

var sqlTmpl = `// Code generated by sqlc. DO NOT EDIT.
{{define "queryFile"}}// Code generated by sqlc. DO NOT EDIT.
// source: {{.SourceName}}

package {{.Package}}
Expand All @@ -1113,8 +1171,12 @@ import (
{{end}}
)

{{template "queryCode" . }}
{{end}}

{{define "queryCode"}}
{{range .GoQueries}}
{{if eq .SourceName $.SourceName}}
{{if $.OutputQuery .SourceName}}
const {{.ConstantName}} = {{$.Q}}-- name: {{.MethodName}} {{.Cmd}}
{{.SQL}}
{{$.Q}}
Expand Down Expand Up @@ -1209,6 +1271,27 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (int64, er
{{end}}
{{end}}
{{end}}
{{end}}

{{define "singleFile"}}// Code generated by sqlc. DO NOT EDIT.

package {{.Package}}

import (
{{range imports "all.go"}}
{{range .}}"{{.}}"
{{end}}
{{end}}
)

{{template "modelsCode" . }}

{{template "queryCode" . }}

{{template "dbCode" . }}

{{template "interfaceCode" . }}
{{end}}
`

type tmplCtx struct {
Expand All @@ -1225,6 +1308,11 @@ type tmplCtx struct {
EmitJSONTags bool
EmitPreparedQueries bool
EmitInterface bool
EmitSingleFile bool
}

func (t *tmplCtx) OutputQuery(sourceName string) bool {
return t.SourceName == sourceName || t.EmitSingleFile
}

func LowerTitle(s string) string {
Expand All @@ -1244,17 +1332,15 @@ func Generate(r Generateable, settings config.CombinedSettings) (map[string]stri
"imports": Imports(r, settings),
}

dbFile := template.Must(template.New("table").Funcs(funcMap).Parse(dbTmpl))
modelsFile := template.Must(template.New("table").Funcs(funcMap).Parse(modelsTmpl))
sqlFile := template.Must(template.New("table").Funcs(funcMap).Parse(sqlTmpl))
ifaceFile := template.Must(template.New("table").Funcs(funcMap).Parse(ifaceTmpl))
tmpl := template.Must(template.New("table").Funcs(funcMap).Parse(templateSet))

golang := settings.Go
tctx := tmplCtx{
Settings: settings.Global,
EmitInterface: golang.EmitInterface,
EmitJSONTags: golang.EmitJSONTags,
EmitPreparedQueries: golang.EmitPreparedQueries,
EmitSingleFile: golang.EmitSingleFile,
Q: "`",
Package: golang.Package,
GoQueries: r.GoQueries(settings),
Expand All @@ -1264,11 +1350,11 @@ func Generate(r Generateable, settings config.CombinedSettings) (map[string]stri

output := map[string]string{}

execute := func(name string, t *template.Template) error {
execute := func(name, templateName string) error {
var b bytes.Buffer
w := bufio.NewWriter(&b)
tctx.SourceName = name
err := t.Execute(w, tctx)
err := tmpl.ExecuteTemplate(w, templateName, &tctx)
w.Flush()
if err != nil {
return err
Expand All @@ -1285,14 +1371,22 @@ func Generate(r Generateable, settings config.CombinedSettings) (map[string]stri
return nil
}

if err := execute("db.go", dbFile); err != nil {
// Output a single file with all code
if golang.EmitSingleFile {
if err := execute("db.go", "singleFile"); err != nil {
return nil, err
}
return output, nil
}

if err := execute("db.go", "dbFile"); err != nil {
return nil, err
}
if err := execute("models.go", modelsFile); err != nil {
if err := execute("models.go", "modelsFile"); err != nil {
return nil, err
}
if golang.EmitInterface {
if err := execute("querier.go", ifaceFile); err != nil {
if err := execute("querier.go", "interfaceFile"); err != nil {
return nil, err
}
}
Expand All @@ -1303,7 +1397,7 @@ func Generate(r Generateable, settings config.CombinedSettings) (map[string]stri
}

for source := range files {
if err := execute(source, sqlFile); err != nil {
if err := execute(source, "queryFile"); err != nil {
return nil, err
}
}
Expand Down
Loading