Skip to content

Commit 3978046

Browse files
authored
gen: Add option to emit single file for Go (#366)
* gen: Add option to emit single file for Go This option will be used for the interactive playground. * Use EmitSignleFile
1 parent 582a575 commit 3978046

File tree

7 files changed

+537
-33
lines changed

7 files changed

+537
-33
lines changed

internal/config/config.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ type SQLGo struct {
8181
EmitInterface bool `json:"emit_interface" yaml:"emit_interface"`
8282
EmitJSONTags bool `json:"emit_json_tags" yaml:"emit_json_tags"`
8383
EmitPreparedQueries bool `json:"emit_prepared_queries" yaml:"emit_prepared_queries":`
84+
EmitSingleFile bool `json:"emit_single_file" yaml:"emit_single_file":`
8485
Package string `json:"package" yaml:"package"`
8586
Out string `json:"out" yaml:"out"`
8687
Overrides []Override `json:"overrides,omitempty" yaml:"overrides"`

internal/config/v_one.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ type v1PackageSettings struct {
2424
EmitInterface bool `json:"emit_interface" yaml:"emit_interface"`
2525
EmitJSONTags bool `json:"emit_json_tags" yaml:"emit_json_tags"`
2626
EmitPreparedQueries bool `json:"emit_prepared_queries" yaml:"emit_prepared_queries"`
27+
EmitSingleFile bool `json:"emit_single_file" yaml:"emit_single_file"`
2728
Overrides []Override `json:"overrides" yaml:"overrides"`
2829
}
2930

@@ -103,6 +104,7 @@ func (c *V1GenerateSettings) Translate() Config {
103104
EmitInterface: pkg.EmitInterface,
104105
EmitJSONTags: pkg.EmitJSONTags,
105106
EmitPreparedQueries: pkg.EmitPreparedQueries,
107+
EmitSingleFile: pkg.EmitSingleFile,
106108
Package: pkg.Name,
107109
Out: pkg.Path,
108110
Overrides: pkg.Overrides,

internal/dinosql/gen.go

Lines changed: 127 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -188,29 +188,74 @@ func UsesArrays(r Generateable, settings config.CombinedSettings) bool {
188188
return false
189189
}
190190

191+
type fileImports struct {
192+
Std []string
193+
Dep []string
194+
}
195+
196+
func mergeImports(imps ...fileImports) [][]string {
197+
if len(imps) == 1 {
198+
return [][]string{imps[0].Std, imps[0].Dep}
199+
}
200+
201+
var stds, pkgs []string
202+
seenStd := map[string]struct{}{}
203+
seenPkg := map[string]struct{}{}
204+
for i := range imps {
205+
for _, std := range imps[i].Std {
206+
if _, ok := seenStd[std]; ok {
207+
continue
208+
}
209+
stds = append(stds, std)
210+
seenStd[std] = struct{}{}
211+
}
212+
for _, pkg := range imps[i].Dep {
213+
if _, ok := seenPkg[pkg]; ok {
214+
continue
215+
}
216+
pkgs = append(pkgs, pkg)
217+
seenPkg[pkg] = struct{}{}
218+
}
219+
}
220+
return [][]string{stds, pkgs}
221+
}
222+
191223
func Imports(r Generateable, settings config.CombinedSettings) func(string) [][]string {
192224
return func(filename string) [][]string {
225+
if filename == "all.go" {
226+
var imps []fileImports
227+
imps = append(imps, dbImports(r, settings))
228+
imps = append(imps, modelImports(r, settings))
229+
imps = append(imps, interfaceImports(r, settings))
230+
imps = append(imps, queryImports(r, settings, filename))
231+
return mergeImports(imps...)
232+
}
233+
193234
if filename == "db.go" {
194-
imps := []string{"context", "database/sql"}
195-
if settings.Go.EmitPreparedQueries {
196-
imps = append(imps, "fmt")
197-
}
198-
return [][]string{imps}
235+
return mergeImports(dbImports(r, settings))
199236
}
200237

201238
if filename == "models.go" {
202-
return ModelImports(r, settings)
239+
return mergeImports(modelImports(r, settings))
203240
}
204241

205242
if filename == "querier.go" {
206-
return InterfaceImports(r, settings)
243+
return mergeImports(interfaceImports(r, settings))
207244
}
208245

209-
return QueryImports(r, settings, filename)
246+
return mergeImports(queryImports(r, settings, filename))
210247
}
211248
}
212249

213-
func InterfaceImports(r Generateable, settings config.CombinedSettings) [][]string {
250+
func dbImports(r Generateable, settings config.CombinedSettings) fileImports {
251+
std := []string{"context", "database/sql"}
252+
if settings.Go.EmitPreparedQueries {
253+
std = append(std, "fmt")
254+
}
255+
return fileImports{Std: std}
256+
}
257+
258+
func interfaceImports(r Generateable, settings config.CombinedSettings) fileImports {
214259
gq := r.GoQueries(settings)
215260
uses := func(name string) bool {
216261
for _, q := range gq {
@@ -284,10 +329,10 @@ func InterfaceImports(r Generateable, settings config.CombinedSettings) [][]stri
284329

285330
sort.Strings(stds)
286331
sort.Strings(pkgs)
287-
return [][]string{stds, pkgs}
332+
return fileImports{stds, pkgs}
288333
}
289334

290-
func ModelImports(r Generateable, settings config.CombinedSettings) [][]string {
335+
func modelImports(r Generateable, settings config.CombinedSettings) fileImports {
291336
std := make(map[string]struct{})
292337
if UsesType(r, "sql.Null", settings) {
293338
std["database/sql"] = struct{}{}
@@ -343,10 +388,10 @@ func ModelImports(r Generateable, settings config.CombinedSettings) [][]string {
343388

344389
sort.Strings(stds)
345390
sort.Strings(pkgs)
346-
return [][]string{stds, pkgs}
391+
return fileImports{stds, pkgs}
347392
}
348393

349-
func QueryImports(r Generateable, settings config.CombinedSettings, filename string) [][]string {
394+
func queryImports(r Generateable, settings config.CombinedSettings, filename string) fileImports {
350395
// for _, strct := range r.Structs() {
351396
// for _, f := range strct.Fields {
352397
// if strings.HasPrefix(f.Type, "[]") {
@@ -356,7 +401,7 @@ func QueryImports(r Generateable, settings config.CombinedSettings, filename str
356401
// }
357402
var gq []GoQuery
358403
for _, query := range r.GoQueries(settings) {
359-
if query.SourceName == filename {
404+
if query.SourceName == filename || settings.Go.EmitSingleFile {
360405
gq = append(gq, query)
361406
}
362407
}
@@ -481,7 +526,7 @@ func QueryImports(r Generateable, settings config.CombinedSettings, filename str
481526

482527
sort.Strings(stds)
483528
sort.Strings(pkgs)
484-
return [][]string{stds, pkgs}
529+
return fileImports{stds, pkgs}
485530
}
486531

487532
func enumValueName(value string) string {
@@ -924,7 +969,8 @@ func (r Result) GoQueries(settings config.CombinedSettings) []GoQuery {
924969
return qs
925970
}
926971

927-
var dbTmpl = `// Code generated by sqlc. DO NOT EDIT.
972+
var templateSet = `
973+
{{define "dbFile"}}// Code generated by sqlc. DO NOT EDIT.
928974
929975
package {{.Package}}
930976
@@ -935,6 +981,10 @@ import (
935981
{{end}}
936982
)
937983
984+
{{template "dbCode" . }}
985+
{{end}}
986+
987+
{{define "dbCode"}}
938988
type DBTX interface {
939989
ExecContext(context.Context, string, ...interface{}) (sql.Result, error)
940990
PrepareContext(context.Context, string) (*sql.Stmt, error)
@@ -1029,9 +1079,9 @@ func (q *Queries) WithTx(tx *sql.Tx) *Queries {
10291079
{{- end}}
10301080
}
10311081
}
1032-
`
1082+
{{end}}
10331083
1034-
var ifaceTmpl = `// Code generated by sqlc. DO NOT EDIT.
1084+
{{define "interfaceFile"}}// Code generated by sqlc. DO NOT EDIT.
10351085
10361086
package {{.Package}}
10371087
@@ -1042,6 +1092,10 @@ import (
10421092
{{end}}
10431093
)
10441094
1095+
{{template "interfaceCode" . }}
1096+
{{end}}
1097+
1098+
{{define "interfaceCode"}}
10451099
type Querier interface {
10461100
{{- range .GoQueries}}
10471101
{{- if eq .Cmd ":one"}}
@@ -1060,9 +1114,9 @@ type Querier interface {
10601114
}
10611115
10621116
var _ Querier = (*Queries)(nil)
1063-
`
1117+
{{end}}
10641118
1065-
var modelsTmpl = `// Code generated by sqlc. DO NOT EDIT.
1119+
{{define "modelsFile"}}// Code generated by sqlc. DO NOT EDIT.
10661120
10671121
package {{.Package}}
10681122
@@ -1073,6 +1127,10 @@ import (
10731127
{{end}}
10741128
)
10751129
1130+
{{template "modelsCode" . }}
1131+
{{end}}
1132+
1133+
{{define "modelsCode"}}
10761134
{{range .Enums}}
10771135
{{if .Comment}}{{comment .Comment}}{{end}}
10781136
type {{.Name}} string
@@ -1099,9 +1157,9 @@ type {{.Name}} struct { {{- range .Fields}}
10991157
{{- end}}
11001158
}
11011159
{{end}}
1102-
`
1160+
{{end}}
11031161
1104-
var sqlTmpl = `// Code generated by sqlc. DO NOT EDIT.
1162+
{{define "queryFile"}}// Code generated by sqlc. DO NOT EDIT.
11051163
// source: {{.SourceName}}
11061164
11071165
package {{.Package}}
@@ -1113,8 +1171,12 @@ import (
11131171
{{end}}
11141172
)
11151173
1174+
{{template "queryCode" . }}
1175+
{{end}}
1176+
1177+
{{define "queryCode"}}
11161178
{{range .GoQueries}}
1117-
{{if eq .SourceName $.SourceName}}
1179+
{{if $.OutputQuery .SourceName}}
11181180
const {{.ConstantName}} = {{$.Q}}-- name: {{.MethodName}} {{.Cmd}}
11191181
{{.SQL}}
11201182
{{$.Q}}
@@ -1209,6 +1271,27 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (int64, er
12091271
{{end}}
12101272
{{end}}
12111273
{{end}}
1274+
{{end}}
1275+
1276+
{{define "singleFile"}}// Code generated by sqlc. DO NOT EDIT.
1277+
1278+
package {{.Package}}
1279+
1280+
import (
1281+
{{range imports "all.go"}}
1282+
{{range .}}"{{.}}"
1283+
{{end}}
1284+
{{end}}
1285+
)
1286+
1287+
{{template "modelsCode" . }}
1288+
1289+
{{template "queryCode" . }}
1290+
1291+
{{template "dbCode" . }}
1292+
1293+
{{template "interfaceCode" . }}
1294+
{{end}}
12121295
`
12131296

12141297
type tmplCtx struct {
@@ -1225,6 +1308,11 @@ type tmplCtx struct {
12251308
EmitJSONTags bool
12261309
EmitPreparedQueries bool
12271310
EmitInterface bool
1311+
EmitSingleFile bool
1312+
}
1313+
1314+
func (t *tmplCtx) OutputQuery(sourceName string) bool {
1315+
return t.SourceName == sourceName || t.EmitSingleFile
12281316
}
12291317

12301318
func LowerTitle(s string) string {
@@ -1244,17 +1332,15 @@ func Generate(r Generateable, settings config.CombinedSettings) (map[string]stri
12441332
"imports": Imports(r, settings),
12451333
}
12461334

1247-
dbFile := template.Must(template.New("table").Funcs(funcMap).Parse(dbTmpl))
1248-
modelsFile := template.Must(template.New("table").Funcs(funcMap).Parse(modelsTmpl))
1249-
sqlFile := template.Must(template.New("table").Funcs(funcMap).Parse(sqlTmpl))
1250-
ifaceFile := template.Must(template.New("table").Funcs(funcMap).Parse(ifaceTmpl))
1335+
tmpl := template.Must(template.New("table").Funcs(funcMap).Parse(templateSet))
12511336

12521337
golang := settings.Go
12531338
tctx := tmplCtx{
12541339
Settings: settings.Global,
12551340
EmitInterface: golang.EmitInterface,
12561341
EmitJSONTags: golang.EmitJSONTags,
12571342
EmitPreparedQueries: golang.EmitPreparedQueries,
1343+
EmitSingleFile: golang.EmitSingleFile,
12581344
Q: "`",
12591345
Package: golang.Package,
12601346
GoQueries: r.GoQueries(settings),
@@ -1264,11 +1350,11 @@ func Generate(r Generateable, settings config.CombinedSettings) (map[string]stri
12641350

12651351
output := map[string]string{}
12661352

1267-
execute := func(name string, t *template.Template) error {
1353+
execute := func(name, templateName string) error {
12681354
var b bytes.Buffer
12691355
w := bufio.NewWriter(&b)
12701356
tctx.SourceName = name
1271-
err := t.Execute(w, tctx)
1357+
err := tmpl.ExecuteTemplate(w, templateName, &tctx)
12721358
w.Flush()
12731359
if err != nil {
12741360
return err
@@ -1285,14 +1371,22 @@ func Generate(r Generateable, settings config.CombinedSettings) (map[string]stri
12851371
return nil
12861372
}
12871373

1288-
if err := execute("db.go", dbFile); err != nil {
1374+
// Output a single file with all code
1375+
if golang.EmitSingleFile {
1376+
if err := execute("db.go", "singleFile"); err != nil {
1377+
return nil, err
1378+
}
1379+
return output, nil
1380+
}
1381+
1382+
if err := execute("db.go", "dbFile"); err != nil {
12891383
return nil, err
12901384
}
1291-
if err := execute("models.go", modelsFile); err != nil {
1385+
if err := execute("models.go", "modelsFile"); err != nil {
12921386
return nil, err
12931387
}
12941388
if golang.EmitInterface {
1295-
if err := execute("querier.go", ifaceFile); err != nil {
1389+
if err := execute("querier.go", "interfaceFile"); err != nil {
12961390
return nil, err
12971391
}
12981392
}
@@ -1303,7 +1397,7 @@ func Generate(r Generateable, settings config.CombinedSettings) (map[string]stri
13031397
}
13041398

13051399
for source := range files {
1306-
if err := execute(source, sqlFile); err != nil {
1400+
if err := execute(source, "queryFile"); err != nil {
13071401
return nil, err
13081402
}
13091403
}

0 commit comments

Comments
 (0)