Skip to content

Commit a7192da

Browse files
committed
refactor: Remove Overrides from codegen.proto, pass in opts as JSON
The `Overrides` configuration only applies to Go, so we remove it from the codegen plugin proto and instead pass the configuration as JSON to be unmarshaled by the codegen/golang package like other options. Also pull codegen/golang opts into its own package since its about to get a lot more complex. Eventually we will need to push all of the validation of overrides from internal/config into the codegen/golang package, but that's a fair bit of work so I didn't push it here. A lot of the changes in this diff are just from pushing the `opts` type within codegen/golang into a package, which is necessary eventually but pretty distracting right now so sorry about that. Aside from the proto change, which is obviously important, the other meaningful change is to generate.go on line 421, and the function implementation in shim.go. The types in opts/override.go are just recreations of what the plugin proto types were.
1 parent 8e06874 commit a7192da

File tree

16 files changed

+512
-1989
lines changed

16 files changed

+512
-1989
lines changed

internal/cmd/generate.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -418,7 +418,7 @@ func codegen(ctx context.Context, combo config.CombinedSettings, sql outPair, re
418418
case sql.Gen.Go != nil:
419419
out = combo.Go.Out
420420
handler = ext.HandleFunc(golang.Generate)
421-
opts, err := json.Marshal(sql.Gen.Go)
421+
opts, err := json.Marshal(pluginGoOpts(sql.Gen.Go, combo, result))
422422
if err != nil {
423423
return "", nil, fmt.Errorf("opts marshal failed: %w", err)
424424
}

internal/cmd/shim.go

Lines changed: 51 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package cmd
33
import (
44
"strings"
55

6+
goopts "github.com/sqlc-dev/sqlc/internal/codegen/golang/opts"
67
"github.com/sqlc-dev/sqlc/internal/compiler"
78
"github.com/sqlc-dev/sqlc/internal/config"
89
"github.com/sqlc-dev/sqlc/internal/config/convert"
@@ -11,7 +12,7 @@ import (
1112
"github.com/sqlc-dev/sqlc/internal/sql/catalog"
1213
)
1314

14-
func pluginOverride(r *compiler.Result, o config.Override) *plugin.Override {
15+
func pluginOverride(r *compiler.Result, o config.Override) goopts.Override {
1516
var column string
1617
var table plugin.Identifier
1718

@@ -33,7 +34,7 @@ func pluginOverride(r *compiler.Result, o config.Override) *plugin.Override {
3334
column = colParts[3]
3435
}
3536
}
36-
return &plugin.Override{
37+
return goopts.Override{
3738
CodeType: "", // FIXME
3839
DbType: o.DBType,
3940
Nullable: o.Nullable,
@@ -46,18 +47,13 @@ func pluginOverride(r *compiler.Result, o config.Override) *plugin.Override {
4647
}
4748

4849
func pluginSettings(r *compiler.Result, cs config.CombinedSettings) *plugin.Settings {
49-
var over []*plugin.Override
50-
for _, o := range cs.Overrides {
51-
over = append(over, pluginOverride(r, o))
52-
}
5350
return &plugin.Settings{
54-
Version: cs.Global.Version,
55-
Engine: string(cs.Package.Engine),
56-
Schema: []string(cs.Package.Schema),
57-
Queries: []string(cs.Package.Queries),
58-
Overrides: over,
59-
Rename: cs.Rename,
60-
Codegen: pluginCodegen(cs, cs.Codegen),
51+
Version: cs.Global.Version,
52+
Engine: string(cs.Package.Engine),
53+
Schema: []string(cs.Package.Schema),
54+
Queries: []string(cs.Package.Queries),
55+
Rename: cs.Rename,
56+
Codegen: pluginCodegen(cs, cs.Codegen),
6157
}
6258
}
6359

@@ -101,12 +97,12 @@ func pluginWASM(p config.Plugin) *plugin.Codegen_WASM {
10197
return nil
10298
}
10399

104-
func pluginGoType(o config.Override) *plugin.ParsedGoType {
100+
func pluginGoType(o config.Override) *goopts.ParsedGoType {
105101
// Note that there is a slight mismatch between this and the
106102
// proto api. The GoType on the override is the unparsed type,
107103
// which could be a qualified path or an object, as per
108104
// https://docs.sqlc.dev/en/v1.18.0/reference/config.html#type-overriding
109-
return &plugin.ParsedGoType{
105+
return &goopts.ParsedGoType{
110106
ImportPath: o.GoImportPath,
111107
Package: o.GoPackage,
112108
TypeName: o.GoTypeName,
@@ -115,6 +111,46 @@ func pluginGoType(o config.Override) *plugin.ParsedGoType {
115111
}
116112
}
117113

114+
func pluginGoOpts(sqlGo *config.SQLGo, cs config.CombinedSettings, r *compiler.Result) *goopts.Options {
115+
var overrides []goopts.Override
116+
for _, o := range cs.Overrides {
117+
overrides = append(overrides, pluginOverride(r, o))
118+
}
119+
return &goopts.Options{
120+
EmitInterface: sqlGo.EmitInterface,
121+
EmitJsonTags: sqlGo.EmitJSONTags,
122+
JsonTagsIdUppercase: sqlGo.JsonTagsIDUppercase,
123+
EmitDbTags: sqlGo.EmitDBTags,
124+
EmitPreparedQueries: sqlGo.EmitPreparedQueries,
125+
EmitExactTableNames: sqlGo.EmitExactTableNames,
126+
EmitEmptySlices: sqlGo.EmitEmptySlices,
127+
EmitExportedQueries: sqlGo.EmitExportedQueries,
128+
EmitResultStructPointers: sqlGo.EmitResultStructPointers,
129+
EmitParamsStructPointers: sqlGo.EmitParamsStructPointers,
130+
EmitMethodsWithDbArgument: sqlGo.EmitMethodsWithDBArgument,
131+
EmitPointersForNullTypes: sqlGo.EmitPointersForNullTypes,
132+
EmitEnumValidMethod: sqlGo.EmitEnumValidMethod,
133+
EmitAllEnumValues: sqlGo.EmitAllEnumValues,
134+
JsonTagsCaseStyle: sqlGo.JSONTagsCaseStyle,
135+
Package: sqlGo.Package,
136+
Out: sqlGo.Out,
137+
Overrides: overrides,
138+
// Rename intentionally omitted
139+
SqlPackage: sqlGo.SQLPackage,
140+
SqlDriver: sqlGo.SQLDriver,
141+
OutputBatchFileName: sqlGo.OutputBatchFileName,
142+
OutputDbFileName: sqlGo.OutputDBFileName,
143+
OutputModelsFileName: sqlGo.OutputModelsFileName,
144+
OutputQuerierFileName: sqlGo.OutputQuerierFileName,
145+
OutputCopyfromFileName: sqlGo.OutputCopyFromFileName,
146+
OutputFilesSuffix: sqlGo.OutputFilesSuffix,
147+
InflectionExcludeTableNames: sqlGo.InflectionExcludeTableNames,
148+
QueryParameterLimit: sqlGo.QueryParameterLimit,
149+
OmitUnusedStructs: sqlGo.OmitUnusedStructs,
150+
BuildTags: sqlGo.BuildTags,
151+
}
152+
}
153+
118154
func pluginCatalog(c *catalog.Catalog) *plugin.Catalog {
119155
var schemas []*plugin.Schema
120156
for _, s := range c.Schemas {

internal/codegen/golang/field.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"sort"
77
"strings"
88

9+
"github.com/sqlc-dev/sqlc/internal/codegen/golang/opts"
910
"github.com/sqlc-dev/sqlc/internal/plugin"
1011
)
1112

@@ -40,7 +41,7 @@ func TagsToString(tags map[string]string) string {
4041
return strings.Join(tagParts, " ")
4142
}
4243

43-
func JSONTagName(name string, options *opts) string {
44+
func JSONTagName(name string, options *opts.Options) string {
4445
style := options.JsonTagsCaseStyle
4546
idUppercase := options.JsonTagsIdUppercase
4647
if style == "" || style == "none" {

internal/codegen/golang/gen.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"strings"
1111
"text/template"
1212

13+
"github.com/sqlc-dev/sqlc/internal/codegen/golang/opts"
1314
"github.com/sqlc-dev/sqlc/internal/codegen/sdk"
1415
"github.com/sqlc-dev/sqlc/internal/metadata"
1516
"github.com/sqlc-dev/sqlc/internal/plugin"
@@ -103,12 +104,12 @@ func (t *tmplCtx) codegenQueryRetval(q Query) (string, error) {
103104
}
104105

105106
func Generate(ctx context.Context, req *plugin.CodeGenRequest) (*plugin.CodeGenResponse, error) {
106-
options, err := parseOpts(req)
107+
options, err := opts.ParseOpts(req)
107108
if err != nil {
108109
return nil, err
109110
}
110111

111-
if err := validateOpts(options); err != nil {
112+
if err := opts.ValidateOpts(options); err != nil {
112113
return nil, err
113114
}
114115

@@ -126,7 +127,7 @@ func Generate(ctx context.Context, req *plugin.CodeGenRequest) (*plugin.CodeGenR
126127
return generate(req, options, enums, structs, queries)
127128
}
128129

129-
func generate(req *plugin.CodeGenRequest, options *opts, enums []Enum, structs []Struct, queries []Query) (*plugin.CodeGenResponse, error) {
130+
func generate(req *plugin.CodeGenRequest, options *opts.Options, enums []Enum, structs []Struct, queries []Query) (*plugin.CodeGenResponse, error) {
130131
i := &importer{
131132
Settings: req.Settings,
132133
Options: options,

internal/codegen/golang/go_type.go

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,17 @@ package golang
33
import (
44
"strings"
55

6+
"github.com/sqlc-dev/sqlc/internal/codegen/golang/opts"
67
"github.com/sqlc-dev/sqlc/internal/codegen/sdk"
78
"github.com/sqlc-dev/sqlc/internal/plugin"
89
)
910

10-
func addExtraGoStructTags(tags map[string]string, req *plugin.CodeGenRequest, col *plugin.Column) {
11-
for _, oride := range req.Settings.Overrides {
11+
func addExtraGoStructTags(tags map[string]string, req *plugin.CodeGenRequest, options *opts.Options, col *plugin.Column) {
12+
for _, oride := range options.Overrides {
1213
if oride.GoType.StructTags == nil {
1314
continue
1415
}
15-
if !sdk.Matches(oride, col.Table, req.Catalog.DefaultSchema) {
16+
if !oride.Matches(col.Table, req.Catalog.DefaultSchema) {
1617
// Different table.
1718
continue
1819
}
@@ -31,17 +32,17 @@ func addExtraGoStructTags(tags map[string]string, req *plugin.CodeGenRequest, co
3132
}
3233
}
3334

34-
func goType(req *plugin.CodeGenRequest, options *opts, col *plugin.Column) string {
35+
func goType(req *plugin.CodeGenRequest, options *opts.Options, col *plugin.Column) string {
3536
// Check if the column's type has been overridden
36-
for _, oride := range req.Settings.Overrides {
37+
for _, oride := range options.Overrides {
3738
if oride.GoType.TypeName == "" {
3839
continue
3940
}
4041
cname := col.Name
4142
if col.OriginalName != "" {
4243
cname = col.OriginalName
4344
}
44-
sameTable := sdk.Matches(oride, col.Table, req.Catalog.DefaultSchema)
45+
sameTable := oride.Matches(col.Table, req.Catalog.DefaultSchema)
4546
if oride.Column != "" && sdk.MatchString(oride.ColumnName, cname) && sameTable {
4647
if col.IsSqlcSlice {
4748
return "[]" + oride.GoType.TypeName
@@ -59,12 +60,12 @@ func goType(req *plugin.CodeGenRequest, options *opts, col *plugin.Column) strin
5960
return typ
6061
}
6162

62-
func goInnerType(req *plugin.CodeGenRequest, options *opts, col *plugin.Column) string {
63+
func goInnerType(req *plugin.CodeGenRequest, options *opts.Options, col *plugin.Column) string {
6364
columnType := sdk.DataType(col.Type)
6465
notNull := col.NotNull || col.IsArray
6566

6667
// package overrides have a higher precedence
67-
for _, oride := range req.Settings.Overrides {
68+
for _, oride := range options.Overrides {
6869
if oride.GoType.TypeName == "" {
6970
continue
7071
}

internal/codegen/golang/imports.go

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"sort"
66
"strings"
77

8+
"github.com/sqlc-dev/sqlc/internal/codegen/golang/opts"
89
"github.com/sqlc-dev/sqlc/internal/metadata"
910
"github.com/sqlc-dev/sqlc/internal/plugin"
1011
)
@@ -59,7 +60,7 @@ func mergeImports(imps ...fileImports) [][]ImportSpec {
5960

6061
type importer struct {
6162
Settings *plugin.Settings
62-
Options *opts
63+
Options *opts.Options
6364
Queries []Query
6465
Enums []Enum
6566
Structs []Struct
@@ -156,7 +157,7 @@ var pqtypeTypes = map[string]struct{}{
156157
"pqtype.NullRawMessage": {},
157158
}
158159

159-
func buildImports(settings *plugin.Settings, options *opts, queries []Query, uses func(string) bool) (map[string]struct{}, map[ImportSpec]struct{}) {
160+
func buildImports(options *opts.Options, queries []Query, uses func(string) bool) (map[string]struct{}, map[ImportSpec]struct{}) {
160161
pkg := make(map[ImportSpec]struct{})
161162
std := make(map[string]struct{})
162163

@@ -200,7 +201,7 @@ func buildImports(settings *plugin.Settings, options *opts, queries []Query, use
200201
}
201202

202203
overrideTypes := map[string]string{}
203-
for _, o := range settings.Overrides {
204+
for _, o := range options.Overrides {
204205
if o.GoType.BasicType || o.GoType.TypeName == "" {
205206
continue
206207
}
@@ -225,7 +226,7 @@ func buildImports(settings *plugin.Settings, options *opts, queries []Query, use
225226
}
226227

227228
// Custom imports
228-
for _, o := range settings.Overrides {
229+
for _, o := range options.Overrides {
229230
if o.GoType.BasicType || o.GoType.TypeName == "" {
230231
continue
231232
}
@@ -240,7 +241,7 @@ func buildImports(settings *plugin.Settings, options *opts, queries []Query, use
240241
}
241242

242243
func (i *importer) interfaceImports() fileImports {
243-
std, pkg := buildImports(i.Settings, i.Options, i.Queries, func(name string) bool {
244+
std, pkg := buildImports(i.Options, i.Queries, func(name string) bool {
244245
for _, q := range i.Queries {
245246
if q.hasRetType() {
246247
if usesBatch([]Query{q}) {
@@ -265,7 +266,7 @@ func (i *importer) interfaceImports() fileImports {
265266
}
266267

267268
func (i *importer) modelImports() fileImports {
268-
std, pkg := buildImports(i.Settings, i.Options, nil, i.usesType)
269+
std, pkg := buildImports(i.Options, nil, i.usesType)
269270

270271
if len(i.Enums) > 0 {
271272
std["fmt"] = struct{}{}
@@ -304,7 +305,7 @@ func (i *importer) queryImports(filename string) fileImports {
304305
}
305306
}
306307

307-
std, pkg := buildImports(i.Settings, i.Options, gq, func(name string) bool {
308+
std, pkg := buildImports(i.Options, gq, func(name string) bool {
308309
for _, q := range gq {
309310
if q.hasRetType() {
310311
if q.Ret.EmitStruct() {
@@ -405,7 +406,7 @@ func (i *importer) copyfromImports() fileImports {
405406
copyFromQueries = append(copyFromQueries, q)
406407
}
407408
}
408-
std, pkg := buildImports(i.Settings, i.Options, copyFromQueries, func(name string) bool {
409+
std, pkg := buildImports(i.Options, copyFromQueries, func(name string) bool {
409410
for _, q := range copyFromQueries {
410411
if q.hasRetType() {
411412
if strings.HasPrefix(q.Ret.Type(), name) {
@@ -440,7 +441,7 @@ func (i *importer) batchImports() fileImports {
440441
batchQueries = append(batchQueries, q)
441442
}
442443
}
443-
std, pkg := buildImports(i.Settings, i.Options, batchQueries, func(name string) bool {
444+
std, pkg := buildImports(i.Options, batchQueries, func(name string) bool {
444445
for _, q := range batchQueries {
445446
if q.hasRetType() {
446447
if q.Ret.EmitStruct() {

internal/codegen/golang/opts.go

Lines changed: 0 additions & 69 deletions
This file was deleted.

0 commit comments

Comments
 (0)