@@ -188,29 +188,74 @@ func UsesArrays(r Generateable, settings config.CombinedSettings) bool {
188
188
return false
189
189
}
190
190
191
+ type fileImports struct {
192
+ Std []string
193
+ Dep []string
194
+ }
195
+
196
+ func mergeImports (imps ... fileImports ) [][]string {
197
+ if len (imps ) == 1 {
198
+ return [][]string {imps [0 ].Std , imps [0 ].Dep }
199
+ }
200
+
201
+ var stds , pkgs []string
202
+ seenStd := map [string ]struct {}{}
203
+ seenPkg := map [string ]struct {}{}
204
+ for i := range imps {
205
+ for _ , std := range imps [i ].Std {
206
+ if _ , ok := seenStd [std ]; ok {
207
+ continue
208
+ }
209
+ stds = append (stds , std )
210
+ seenStd [std ] = struct {}{}
211
+ }
212
+ for _ , pkg := range imps [i ].Dep {
213
+ if _ , ok := seenPkg [pkg ]; ok {
214
+ continue
215
+ }
216
+ pkgs = append (pkgs , pkg )
217
+ seenPkg [pkg ] = struct {}{}
218
+ }
219
+ }
220
+ return [][]string {stds , pkgs }
221
+ }
222
+
191
223
func Imports (r Generateable , settings config.CombinedSettings ) func (string ) [][]string {
192
224
return func (filename string ) [][]string {
225
+ if filename == "all.go" {
226
+ var imps []fileImports
227
+ imps = append (imps , dbImports (r , settings ))
228
+ imps = append (imps , modelImports (r , settings ))
229
+ imps = append (imps , interfaceImports (r , settings ))
230
+ imps = append (imps , queryImports (r , settings , filename ))
231
+ return mergeImports (imps ... )
232
+ }
233
+
193
234
if filename == "db.go" {
194
- imps := []string {"context" , "database/sql" }
195
- if settings .Go .EmitPreparedQueries {
196
- imps = append (imps , "fmt" )
197
- }
198
- return [][]string {imps }
235
+ return mergeImports (dbImports (r , settings ))
199
236
}
200
237
201
238
if filename == "models.go" {
202
- return ModelImports ( r , settings )
239
+ return mergeImports ( modelImports ( r , settings ) )
203
240
}
204
241
205
242
if filename == "querier.go" {
206
- return InterfaceImports ( r , settings )
243
+ return mergeImports ( interfaceImports ( r , settings ) )
207
244
}
208
245
209
- return QueryImports ( r , settings , filename )
246
+ return mergeImports ( queryImports ( r , settings , filename ) )
210
247
}
211
248
}
212
249
213
- func InterfaceImports (r Generateable , settings config.CombinedSettings ) [][]string {
250
+ func dbImports (r Generateable , settings config.CombinedSettings ) fileImports {
251
+ std := []string {"context" , "database/sql" }
252
+ if settings .Go .EmitPreparedQueries {
253
+ std = append (std , "fmt" )
254
+ }
255
+ return fileImports {Std : std }
256
+ }
257
+
258
+ func interfaceImports (r Generateable , settings config.CombinedSettings ) fileImports {
214
259
gq := r .GoQueries (settings )
215
260
uses := func (name string ) bool {
216
261
for _ , q := range gq {
@@ -284,10 +329,10 @@ func InterfaceImports(r Generateable, settings config.CombinedSettings) [][]stri
284
329
285
330
sort .Strings (stds )
286
331
sort .Strings (pkgs )
287
- return [][] string {stds , pkgs }
332
+ return fileImports {stds , pkgs }
288
333
}
289
334
290
- func ModelImports (r Generateable , settings config.CombinedSettings ) [][] string {
335
+ func modelImports (r Generateable , settings config.CombinedSettings ) fileImports {
291
336
std := make (map [string ]struct {})
292
337
if UsesType (r , "sql.Null" , settings ) {
293
338
std ["database/sql" ] = struct {}{}
@@ -343,10 +388,10 @@ func ModelImports(r Generateable, settings config.CombinedSettings) [][]string {
343
388
344
389
sort .Strings (stds )
345
390
sort .Strings (pkgs )
346
- return [][] string {stds , pkgs }
391
+ return fileImports {stds , pkgs }
347
392
}
348
393
349
- func QueryImports (r Generateable , settings config.CombinedSettings , filename string ) [][] string {
394
+ func queryImports (r Generateable , settings config.CombinedSettings , filename string ) fileImports {
350
395
// for _, strct := range r.Structs() {
351
396
// for _, f := range strct.Fields {
352
397
// if strings.HasPrefix(f.Type, "[]") {
@@ -356,7 +401,7 @@ func QueryImports(r Generateable, settings config.CombinedSettings, filename str
356
401
// }
357
402
var gq []GoQuery
358
403
for _ , query := range r .GoQueries (settings ) {
359
- if query .SourceName == filename {
404
+ if query .SourceName == filename || settings . Go . EmitSingleFile {
360
405
gq = append (gq , query )
361
406
}
362
407
}
@@ -481,7 +526,7 @@ func QueryImports(r Generateable, settings config.CombinedSettings, filename str
481
526
482
527
sort .Strings (stds )
483
528
sort .Strings (pkgs )
484
- return [][] string {stds , pkgs }
529
+ return fileImports {stds , pkgs }
485
530
}
486
531
487
532
func enumValueName (value string ) string {
@@ -924,7 +969,8 @@ func (r Result) GoQueries(settings config.CombinedSettings) []GoQuery {
924
969
return qs
925
970
}
926
971
927
- var dbTmpl = `// Code generated by sqlc. DO NOT EDIT.
972
+ var templateSet = `
973
+ {{define "dbFile"}}// Code generated by sqlc. DO NOT EDIT.
928
974
929
975
package {{.Package}}
930
976
@@ -935,6 +981,10 @@ import (
935
981
{{end}}
936
982
)
937
983
984
+ {{template "dbCode" . }}
985
+ {{end}}
986
+
987
+ {{define "dbCode"}}
938
988
type DBTX interface {
939
989
ExecContext(context.Context, string, ...interface{}) (sql.Result, error)
940
990
PrepareContext(context.Context, string) (*sql.Stmt, error)
@@ -1029,9 +1079,9 @@ func (q *Queries) WithTx(tx *sql.Tx) *Queries {
1029
1079
{{- end}}
1030
1080
}
1031
1081
}
1032
- `
1082
+ {{end}}
1033
1083
1034
- var ifaceTmpl = ` // Code generated by sqlc. DO NOT EDIT.
1084
+ {{define "interfaceFile"}} // Code generated by sqlc. DO NOT EDIT.
1035
1085
1036
1086
package {{.Package}}
1037
1087
@@ -1042,6 +1092,10 @@ import (
1042
1092
{{end}}
1043
1093
)
1044
1094
1095
+ {{template "interfaceCode" . }}
1096
+ {{end}}
1097
+
1098
+ {{define "interfaceCode"}}
1045
1099
type Querier interface {
1046
1100
{{- range .GoQueries}}
1047
1101
{{- if eq .Cmd ":one"}}
@@ -1060,9 +1114,9 @@ type Querier interface {
1060
1114
}
1061
1115
1062
1116
var _ Querier = (*Queries)(nil)
1063
- `
1117
+ {{end}}
1064
1118
1065
- var modelsTmpl = ` // Code generated by sqlc. DO NOT EDIT.
1119
+ {{define "modelsFile"}} // Code generated by sqlc. DO NOT EDIT.
1066
1120
1067
1121
package {{.Package}}
1068
1122
@@ -1073,6 +1127,10 @@ import (
1073
1127
{{end}}
1074
1128
)
1075
1129
1130
+ {{template "modelsCode" . }}
1131
+ {{end}}
1132
+
1133
+ {{define "modelsCode"}}
1076
1134
{{range .Enums}}
1077
1135
{{if .Comment}}{{comment .Comment}}{{end}}
1078
1136
type {{.Name}} string
@@ -1099,9 +1157,9 @@ type {{.Name}} struct { {{- range .Fields}}
1099
1157
{{- end}}
1100
1158
}
1101
1159
{{end}}
1102
- `
1160
+ {{end}}
1103
1161
1104
- var sqlTmpl = ` // Code generated by sqlc. DO NOT EDIT.
1162
+ {{define "queryFile"}} // Code generated by sqlc. DO NOT EDIT.
1105
1163
// source: {{.SourceName}}
1106
1164
1107
1165
package {{.Package}}
@@ -1113,8 +1171,12 @@ import (
1113
1171
{{end}}
1114
1172
)
1115
1173
1174
+ {{template "queryCode" . }}
1175
+ {{end}}
1176
+
1177
+ {{define "queryCode"}}
1116
1178
{{range .GoQueries}}
1117
- {{if eq .SourceName $ .SourceName}}
1179
+ {{if $.OutputQuery .SourceName}}
1118
1180
const {{.ConstantName}} = {{$.Q}}-- name: {{.MethodName}} {{.Cmd}}
1119
1181
{{.SQL}}
1120
1182
{{$.Q}}
@@ -1209,6 +1271,27 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (int64, er
1209
1271
{{end}}
1210
1272
{{end}}
1211
1273
{{end}}
1274
+ {{end}}
1275
+
1276
+ {{define "singleFile"}}// Code generated by sqlc. DO NOT EDIT.
1277
+
1278
+ package {{.Package}}
1279
+
1280
+ import (
1281
+ {{range imports "all.go"}}
1282
+ {{range .}}"{{.}}"
1283
+ {{end}}
1284
+ {{end}}
1285
+ )
1286
+
1287
+ {{template "modelsCode" . }}
1288
+
1289
+ {{template "queryCode" . }}
1290
+
1291
+ {{template "dbCode" . }}
1292
+
1293
+ {{template "interfaceCode" . }}
1294
+ {{end}}
1212
1295
`
1213
1296
1214
1297
type tmplCtx struct {
@@ -1225,6 +1308,11 @@ type tmplCtx struct {
1225
1308
EmitJSONTags bool
1226
1309
EmitPreparedQueries bool
1227
1310
EmitInterface bool
1311
+ EmitSingleFile bool
1312
+ }
1313
+
1314
+ func (t * tmplCtx ) OutputQuery (sourceName string ) bool {
1315
+ return t .SourceName == sourceName || t .EmitSingleFile
1228
1316
}
1229
1317
1230
1318
func LowerTitle (s string ) string {
@@ -1244,17 +1332,15 @@ func Generate(r Generateable, settings config.CombinedSettings) (map[string]stri
1244
1332
"imports" : Imports (r , settings ),
1245
1333
}
1246
1334
1247
- dbFile := template .Must (template .New ("table" ).Funcs (funcMap ).Parse (dbTmpl ))
1248
- modelsFile := template .Must (template .New ("table" ).Funcs (funcMap ).Parse (modelsTmpl ))
1249
- sqlFile := template .Must (template .New ("table" ).Funcs (funcMap ).Parse (sqlTmpl ))
1250
- ifaceFile := template .Must (template .New ("table" ).Funcs (funcMap ).Parse (ifaceTmpl ))
1335
+ tmpl := template .Must (template .New ("table" ).Funcs (funcMap ).Parse (templateSet ))
1251
1336
1252
1337
golang := settings .Go
1253
1338
tctx := tmplCtx {
1254
1339
Settings : settings .Global ,
1255
1340
EmitInterface : golang .EmitInterface ,
1256
1341
EmitJSONTags : golang .EmitJSONTags ,
1257
1342
EmitPreparedQueries : golang .EmitPreparedQueries ,
1343
+ EmitSingleFile : golang .EmitSingleFile ,
1258
1344
Q : "`" ,
1259
1345
Package : golang .Package ,
1260
1346
GoQueries : r .GoQueries (settings ),
@@ -1264,11 +1350,11 @@ func Generate(r Generateable, settings config.CombinedSettings) (map[string]stri
1264
1350
1265
1351
output := map [string ]string {}
1266
1352
1267
- execute := func (name string , t * template. Template ) error {
1353
+ execute := func (name , templateName string ) error {
1268
1354
var b bytes.Buffer
1269
1355
w := bufio .NewWriter (& b )
1270
1356
tctx .SourceName = name
1271
- err := t . Execute (w , tctx )
1357
+ err := tmpl . ExecuteTemplate (w , templateName , & tctx )
1272
1358
w .Flush ()
1273
1359
if err != nil {
1274
1360
return err
@@ -1285,14 +1371,22 @@ func Generate(r Generateable, settings config.CombinedSettings) (map[string]stri
1285
1371
return nil
1286
1372
}
1287
1373
1288
- if err := execute ("db.go" , dbFile ); err != nil {
1374
+ // Output a single file with all code
1375
+ if golang .EmitSingleFile {
1376
+ if err := execute ("db.go" , "singleFile" ); err != nil {
1377
+ return nil , err
1378
+ }
1379
+ return output , nil
1380
+ }
1381
+
1382
+ if err := execute ("db.go" , "dbFile" ); err != nil {
1289
1383
return nil , err
1290
1384
}
1291
- if err := execute ("models.go" , modelsFile ); err != nil {
1385
+ if err := execute ("models.go" , " modelsFile" ); err != nil {
1292
1386
return nil , err
1293
1387
}
1294
1388
if golang .EmitInterface {
1295
- if err := execute ("querier.go" , ifaceFile ); err != nil {
1389
+ if err := execute ("querier.go" , "interfaceFile" ); err != nil {
1296
1390
return nil , err
1297
1391
}
1298
1392
}
@@ -1303,7 +1397,7 @@ func Generate(r Generateable, settings config.CombinedSettings) (map[string]stri
1303
1397
}
1304
1398
1305
1399
for source := range files {
1306
- if err := execute (source , sqlFile ); err != nil {
1400
+ if err := execute (source , "queryFile" ); err != nil {
1307
1401
return nil , err
1308
1402
}
1309
1403
}
0 commit comments