Skip to content

Commit a83dba2

Browse files
authored
fix(compiler): Use common params struct field for same named params (#1296)
1 parent c34ad5e commit a83dba2

File tree

32 files changed

+640
-86
lines changed

32 files changed

+640
-86
lines changed

internal/codegen/golang/gen.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,10 @@ func (t *tmplCtx) OutputQuery(sourceName string) bool {
4646
func Generate(r *compiler.Result, settings config.CombinedSettings) (map[string]string, error) {
4747
enums := buildEnums(r, settings)
4848
structs := buildStructs(r, settings)
49-
queries := buildQueries(r, settings, structs)
49+
queries, err := buildQueries(r, settings, structs)
50+
if err != nil {
51+
return nil, err
52+
}
5053
return generate(settings, enums, structs, queries)
5154
}
5255

internal/codegen/golang/query.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,21 @@ func (v *QueryValue) ReturnName() string {
6363
return v.Name
6464
}
6565

66+
func (v QueryValue) UniqueFields() []Field {
67+
seen := map[string]struct{}{}
68+
fields := make([]Field, 0, len(v.Struct.Fields))
69+
70+
for _, field := range v.Struct.Fields {
71+
if _, found := seen[field.Name]; found {
72+
continue
73+
}
74+
seen[field.Name] = struct{}{}
75+
fields = append(fields, field)
76+
}
77+
78+
return fields
79+
}
80+
6681
func (v QueryValue) Params() string {
6782
if v.isEmpty() {
6883
return ""

internal/codegen/golang/result.go

Lines changed: 59 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ func argName(name string) string {
135135
return out
136136
}
137137

138-
func buildQueries(r *compiler.Result, settings config.CombinedSettings, structs []Struct) []Query {
138+
func buildQueries(r *compiler.Result, settings config.CombinedSettings, structs []Struct) ([]Query, error) {
139139
qs := make([]Query, 0, len(r.Queries))
140140
for _, query := range r.Queries {
141141
if query.Name == "" {
@@ -178,11 +178,15 @@ func buildQueries(r *compiler.Result, settings config.CombinedSettings, structs
178178
Column: p.Column,
179179
})
180180
}
181+
s, err := columnsToStruct(r, gq.MethodName+"Params", cols, settings, false)
182+
if err != nil {
183+
return nil, err
184+
}
181185
gq.Arg = QueryValue{
182-
Emit: true,
183-
Name: "arg",
184-
Struct: columnsToStruct(r, gq.MethodName+"Params", cols, settings, false),
185-
SQLPackage: sqlpkg,
186+
Emit: true,
187+
Name: "arg",
188+
Struct: s,
189+
SQLPackage: sqlpkg,
186190
EmitPointer: settings.Go.EmitParamsStructPointers,
187191
}
188192
}
@@ -226,7 +230,11 @@ func buildQueries(r *compiler.Result, settings config.CombinedSettings, structs
226230
Column: c,
227231
})
228232
}
229-
gs = columnsToStruct(r, gq.MethodName+"Row", columns, settings, true)
233+
var err error
234+
gs, err = columnsToStruct(r, gq.MethodName+"Row", columns, settings, true)
235+
if err != nil {
236+
return nil, err
237+
}
230238
emit = true
231239
}
232240
gq.Ret = QueryValue{
@@ -241,7 +249,7 @@ func buildQueries(r *compiler.Result, settings config.CombinedSettings, structs
241249
qs = append(qs, gq)
242250
}
243251
sort.Slice(qs, func(i, j int) bool { return qs[i].MethodName < qs[j].MethodName })
244-
return qs
252+
return qs, nil
245253
}
246254

247255
// It's possible that this method will generate duplicate JSON tag values
@@ -251,11 +259,11 @@ func buildQueries(r *compiler.Result, settings config.CombinedSettings, structs
251259
// JSON tags: count, count_2, count_2
252260
//
253261
// This is unlikely to happen, so don't fix it yet
254-
func columnsToStruct(r *compiler.Result, name string, columns []goColumn, settings config.CombinedSettings, useID bool) *Struct {
262+
func columnsToStruct(r *compiler.Result, name string, columns []goColumn, settings config.CombinedSettings, useID bool) (*Struct, error) {
255263
gs := Struct{
256264
Name: name,
257265
}
258-
seen := map[string]int{}
266+
seen := map[string][]int{}
259267
suffixes := map[int]int{}
260268
for i, c := range columns {
261269
colName := columnName(c.Column, i)
@@ -267,7 +275,7 @@ func columnsToStruct(r *compiler.Result, name string, columns []goColumn, settin
267275
suffix := 0
268276
if o, ok := suffixes[c.id]; ok && useID {
269277
suffix = o
270-
} else if v := seen[fieldName]; v > 0 {
278+
} else if v := len(seen[fieldName]); v > 0 && !c.IsNamedParam {
271279
suffix = v + 1
272280
}
273281
suffixes[c.id] = suffix
@@ -287,8 +295,47 @@ func columnsToStruct(r *compiler.Result, name string, columns []goColumn, settin
287295
Type: goType(r, c.Column, settings),
288296
Tags: tags,
289297
})
290-
seen[baseFieldName]++
298+
if _, found := seen[baseFieldName]; !found {
299+
seen[baseFieldName] = []int{i}
300+
} else {
301+
seen[baseFieldName] = append(seen[baseFieldName], i)
302+
}
291303
}
292304

293-
return &gs
305+
// If a field does not have a known type, but another
306+
// field with the same name has a known type, assign
307+
// the known type to the field without a known type
308+
for i, field := range gs.Fields {
309+
if len(seen[field.Name]) > 1 && field.Type == "interface{}" {
310+
for _, j := range seen[field.Name] {
311+
if i == j {
312+
continue
313+
}
314+
otherField := gs.Fields[j]
315+
if otherField.Type != field.Type {
316+
field.Type = otherField.Type
317+
}
318+
gs.Fields[i] = field
319+
}
320+
}
321+
}
322+
323+
err := checkIncompatibleFieldTypes(gs.Fields)
324+
if err != nil {
325+
return nil, err
326+
}
327+
328+
return &gs, nil
329+
}
330+
331+
func checkIncompatibleFieldTypes(fields []Field) error {
332+
fieldTypes := map[string]string{}
333+
for _, field := range fields {
334+
if fieldType, found := fieldTypes[field.Name]; !found {
335+
fieldTypes[field.Name] = field.Type
336+
} else if field.Type != fieldType {
337+
return fmt.Errorf("named param %s has incompatible types: %s, %s", field.Name, field.Type, fieldType)
338+
}
339+
}
340+
return nil
294341
}

internal/codegen/golang/templates/stdlib/queryCode.tmpl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ const {{.ConstantName}} = {{$.Q}}-- name: {{.MethodName}} {{.Cmd}}
66
{{$.Q}}
77

88
{{if .Arg.EmitStruct}}
9-
type {{.Arg.Type}} struct { {{- range .Arg.Struct.Fields}}
9+
type {{.Arg.Type}} struct { {{- range .Arg.UniqueFields}}
1010
{{.Name}} {{.Type}} {{if or ($.EmitJSONTags) ($.EmitDBTags)}}{{$.Q}}{{.Tag}}{{$.Q}}{{end}}
1111
{{- end}}
1212
}

internal/compiler/query.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,13 @@ type Table struct {
1515
}
1616

1717
type Column struct {
18-
Name string
19-
DataType string
20-
NotNull bool
21-
IsArray bool
22-
Comment string
23-
Length *int
18+
Name string
19+
DataType string
20+
NotNull bool
21+
IsArray bool
22+
Comment string
23+
Length *int
24+
IsNamedParam bool
2425

2526
// XXX: Figure out what PostgreSQL calls `foo.id`
2627
Scope string

internal/compiler/resolve.go

Lines changed: 49 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,11 @@ func resolveCatalogRefs(c *catalog.Catalog, qc *QueryCatalog, rvs []*ast.RangeVa
3131
return defaultName
3232
}
3333

34+
isNamedParam := func(n int) bool {
35+
_, ok := names[n]
36+
return ok
37+
}
38+
3439
typeMap := map[string]map[string]map[string]*catalog.Column{}
3540
indexTable := func(table catalog.Table) error {
3641
tables = append(tables, table.Rel)
@@ -88,19 +93,21 @@ func resolveCatalogRefs(c *catalog.Catalog, qc *QueryCatalog, rvs []*ast.RangeVa
8893
a = append(a, Parameter{
8994
Number: ref.ref.Number,
9095
Column: &Column{
91-
Name: parameterName(ref.ref.Number, "offset"),
92-
DataType: "integer",
93-
NotNull: true,
96+
Name: parameterName(ref.ref.Number, "offset"),
97+
DataType: "integer",
98+
NotNull: true,
99+
IsNamedParam: isNamedParam(ref.ref.Number),
94100
},
95101
})
96102

97103
case *limitCount:
98104
a = append(a, Parameter{
99105
Number: ref.ref.Number,
100106
Column: &Column{
101-
Name: parameterName(ref.ref.Number, "limit"),
102-
DataType: "integer",
103-
NotNull: true,
107+
Name: parameterName(ref.ref.Number, "limit"),
108+
DataType: "integer",
109+
NotNull: true,
110+
IsNamedParam: isNamedParam(ref.ref.Number),
104111
},
105112
})
106113

@@ -121,8 +128,9 @@ func resolveCatalogRefs(c *catalog.Catalog, qc *QueryCatalog, rvs []*ast.RangeVa
121128
a = append(a, Parameter{
122129
Number: ref.ref.Number,
123130
Column: &Column{
124-
Name: parameterName(ref.ref.Number, ""),
125-
DataType: dataType,
131+
Name: parameterName(ref.ref.Number, ""),
132+
DataType: dataType,
133+
IsNamedParam: isNamedParam(ref.ref.Number),
126134
},
127135
})
128136
continue
@@ -178,12 +186,13 @@ func resolveCatalogRefs(c *catalog.Catalog, qc *QueryCatalog, rvs []*ast.RangeVa
178186
a = append(a, Parameter{
179187
Number: ref.ref.Number,
180188
Column: &Column{
181-
Name: parameterName(ref.ref.Number, key),
182-
DataType: dataType(&c.Type),
183-
NotNull: c.IsNotNull,
184-
IsArray: c.IsArray,
185-
Length: c.Length,
186-
Table: table,
189+
Name: parameterName(ref.ref.Number, key),
190+
DataType: dataType(&c.Type),
191+
NotNull: c.IsNotNull,
192+
IsArray: c.IsArray,
193+
Length: c.Length,
194+
Table: table,
195+
IsNamedParam: isNamedParam(ref.ref.Number),
187196
},
188197
})
189198
}
@@ -234,11 +243,12 @@ func resolveCatalogRefs(c *catalog.Catalog, qc *QueryCatalog, rvs []*ast.RangeVa
234243
a = append(a, Parameter{
235244
Number: number,
236245
Column: &Column{
237-
Name: parameterName(ref.ref.Number, key),
238-
DataType: dataType(&c.Type),
239-
NotNull: c.IsNotNull,
240-
IsArray: c.IsArray,
241-
Table: table,
246+
Name: parameterName(ref.ref.Number, key),
247+
DataType: dataType(&c.Type),
248+
NotNull: c.IsNotNull,
249+
IsArray: c.IsArray,
250+
Table: table,
251+
IsNamedParam: isNamedParam(ref.ref.Number),
242252
},
243253
})
244254
}
@@ -300,8 +310,9 @@ func resolveCatalogRefs(c *catalog.Catalog, qc *QueryCatalog, rvs []*ast.RangeVa
300310
a = append(a, Parameter{
301311
Number: ref.ref.Number,
302312
Column: &Column{
303-
Name: parameterName(ref.ref.Number, defaultName),
304-
DataType: "any",
313+
Name: parameterName(ref.ref.Number, defaultName),
314+
DataType: "any",
315+
IsNamedParam: isNamedParam(ref.ref.Number),
305316
},
306317
})
307318
continue
@@ -330,9 +341,10 @@ func resolveCatalogRefs(c *catalog.Catalog, qc *QueryCatalog, rvs []*ast.RangeVa
330341
a = append(a, Parameter{
331342
Number: ref.ref.Number,
332343
Column: &Column{
333-
Name: parameterName(ref.ref.Number, paramName),
334-
DataType: dataType(paramType),
335-
NotNull: true,
344+
Name: parameterName(ref.ref.Number, paramName),
345+
DataType: dataType(paramType),
346+
NotNull: true,
347+
IsNamedParam: isNamedParam(ref.ref.Number),
336348
},
337349
})
338350
}
@@ -388,12 +400,13 @@ func resolveCatalogRefs(c *catalog.Catalog, qc *QueryCatalog, rvs []*ast.RangeVa
388400
a = append(a, Parameter{
389401
Number: ref.ref.Number,
390402
Column: &Column{
391-
Name: parameterName(ref.ref.Number, key),
392-
DataType: dataType(&c.Type),
393-
NotNull: c.IsNotNull,
394-
IsArray: c.IsArray,
395-
Table: &ast.TableName{Schema: schema, Name: rel},
396-
Length: c.Length,
403+
Name: parameterName(ref.ref.Number, key),
404+
DataType: dataType(&c.Type),
405+
NotNull: c.IsNotNull,
406+
IsArray: c.IsArray,
407+
Table: &ast.TableName{Schema: schema, Name: rel},
408+
Length: c.Length,
409+
IsNamedParam: isNamedParam(ref.ref.Number),
397410
},
398411
})
399412
} else {
@@ -488,11 +501,12 @@ func resolveCatalogRefs(c *catalog.Catalog, qc *QueryCatalog, rvs []*ast.RangeVa
488501
a = append(a, Parameter{
489502
Number: number,
490503
Column: &Column{
491-
Name: parameterName(ref.ref.Number, key),
492-
DataType: dataType(&c.Type),
493-
NotNull: c.IsNotNull,
494-
IsArray: c.IsArray,
495-
Table: table,
504+
Name: parameterName(ref.ref.Number, key),
505+
DataType: dataType(&c.Type),
506+
NotNull: c.IsNotNull,
507+
IsArray: c.IsArray,
508+
Table: table,
509+
IsNamedParam: isNamedParam(ref.ref.Number),
496510
},
497511
})
498512
}

internal/endtoend/testdata/case_named_params/mysql/go/db.go

Lines changed: 29 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

internal/endtoend/testdata/case_named_params/mysql/go/models.go

Lines changed: 15 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)