Skip to content

fix(compiler): Use common params struct field for same named params #1296

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Nov 21, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion internal/codegen/golang/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,10 @@ func (t *tmplCtx) OutputQuery(sourceName string) bool {
func Generate(r *compiler.Result, settings config.CombinedSettings) (map[string]string, error) {
enums := buildEnums(r, settings)
structs := buildStructs(r, settings)
queries := buildQueries(r, settings, structs)
queries, err := buildQueries(r, settings, structs)
if err != nil {
return nil, err
}
return generate(settings, enums, structs, queries)
}

Expand Down
15 changes: 15 additions & 0 deletions internal/codegen/golang/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,21 @@ func (v *QueryValue) ReturnName() string {
return v.Name
}

func (v QueryValue) UniqueFields() []Field {
seen := map[string]struct{}{}
fields := make([]Field, 0, len(v.Struct.Fields))

for _, field := range v.Struct.Fields {
if _, found := seen[field.Name]; found {
continue
}
seen[field.Name] = struct{}{}
fields = append(fields, field)
}

return fields
}

func (v QueryValue) Params() string {
if v.isEmpty() {
return ""
Expand Down
71 changes: 59 additions & 12 deletions internal/codegen/golang/result.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ func argName(name string) string {
return out
}

func buildQueries(r *compiler.Result, settings config.CombinedSettings, structs []Struct) []Query {
func buildQueries(r *compiler.Result, settings config.CombinedSettings, structs []Struct) ([]Query, error) {
qs := make([]Query, 0, len(r.Queries))
for _, query := range r.Queries {
if query.Name == "" {
Expand Down Expand Up @@ -178,11 +178,15 @@ func buildQueries(r *compiler.Result, settings config.CombinedSettings, structs
Column: p.Column,
})
}
s, err := columnsToStruct(r, gq.MethodName+"Params", cols, settings, false)
if err != nil {
return nil, err
}
gq.Arg = QueryValue{
Emit: true,
Name: "arg",
Struct: columnsToStruct(r, gq.MethodName+"Params", cols, settings, false),
SQLPackage: sqlpkg,
Emit: true,
Name: "arg",
Struct: s,
SQLPackage: sqlpkg,
EmitPointer: settings.Go.EmitParamsStructPointers,
}
}
Expand Down Expand Up @@ -226,7 +230,11 @@ func buildQueries(r *compiler.Result, settings config.CombinedSettings, structs
Column: c,
})
}
gs = columnsToStruct(r, gq.MethodName+"Row", columns, settings, true)
var err error
gs, err = columnsToStruct(r, gq.MethodName+"Row", columns, settings, true)
if err != nil {
return nil, err
}
emit = true
}
gq.Ret = QueryValue{
Expand All @@ -241,7 +249,7 @@ func buildQueries(r *compiler.Result, settings config.CombinedSettings, structs
qs = append(qs, gq)
}
sort.Slice(qs, func(i, j int) bool { return qs[i].MethodName < qs[j].MethodName })
return qs
return qs, nil
}

// It's possible that this method will generate duplicate JSON tag values
Expand All @@ -251,11 +259,11 @@ func buildQueries(r *compiler.Result, settings config.CombinedSettings, structs
// JSON tags: count, count_2, count_2
//
// This is unlikely to happen, so don't fix it yet
func columnsToStruct(r *compiler.Result, name string, columns []goColumn, settings config.CombinedSettings, useID bool) *Struct {
func columnsToStruct(r *compiler.Result, name string, columns []goColumn, settings config.CombinedSettings, useID bool) (*Struct, error) {
gs := Struct{
Name: name,
}
seen := map[string]int{}
seen := map[string][]int{}
suffixes := map[int]int{}
for i, c := range columns {
colName := columnName(c.Column, i)
Expand All @@ -267,7 +275,7 @@ func columnsToStruct(r *compiler.Result, name string, columns []goColumn, settin
suffix := 0
if o, ok := suffixes[c.id]; ok && useID {
suffix = o
} else if v := seen[fieldName]; v > 0 {
} else if v := len(seen[fieldName]); v > 0 && !c.IsNamedParam {
suffix = v + 1
}
suffixes[c.id] = suffix
Expand All @@ -287,8 +295,47 @@ func columnsToStruct(r *compiler.Result, name string, columns []goColumn, settin
Type: goType(r, c.Column, settings),
Tags: tags,
})
seen[baseFieldName]++
if _, found := seen[baseFieldName]; !found {
seen[baseFieldName] = []int{i}
} else {
seen[baseFieldName] = append(seen[baseFieldName], i)
}
}

return &gs
// If a field does not have a known type, but another
// field with the same name has a known type, assign
// the known type to the field without a known type
for i, field := range gs.Fields {
if len(seen[field.Name]) > 1 && field.Type == "interface{}" {
for _, j := range seen[field.Name] {
if i == j {
continue
}
otherField := gs.Fields[j]
if otherField.Type != field.Type {
field.Type = otherField.Type
}
gs.Fields[i] = field
}
}
}

err := checkIncompatibleFieldTypes(gs.Fields)
if err != nil {
return nil, err
}

return &gs, nil
}

func checkIncompatibleFieldTypes(fields []Field) error {
fieldTypes := map[string]string{}
for _, field := range fields {
if fieldType, found := fieldTypes[field.Name]; !found {
fieldTypes[field.Name] = field.Type
} else if field.Type != fieldType {
return fmt.Errorf("named param %s has incompatible types: %s, %s", field.Name, field.Type, fieldType)
}
}
return nil
}
2 changes: 1 addition & 1 deletion internal/codegen/golang/templates/stdlib/queryCode.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ const {{.ConstantName}} = {{$.Q}}-- name: {{.MethodName}} {{.Cmd}}
{{$.Q}}

{{if .Arg.EmitStruct}}
type {{.Arg.Type}} struct { {{- range .Arg.Struct.Fields}}
type {{.Arg.Type}} struct { {{- range .Arg.UniqueFields}}
{{.Name}} {{.Type}} {{if or ($.EmitJSONTags) ($.EmitDBTags)}}{{$.Q}}{{.Tag}}{{$.Q}}{{end}}
{{- end}}
}
Expand Down
13 changes: 7 additions & 6 deletions internal/compiler/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@ type Table struct {
}

type Column struct {
Name string
DataType string
NotNull bool
IsArray bool
Comment string
Length *int
Name string
DataType string
NotNull bool
IsArray bool
Comment string
Length *int
IsNamedParam bool

// XXX: Figure out what PostgreSQL calls `foo.id`
Scope string
Expand Down
84 changes: 49 additions & 35 deletions internal/compiler/resolve.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ func resolveCatalogRefs(c *catalog.Catalog, qc *QueryCatalog, rvs []*ast.RangeVa
return defaultName
}

isNamedParam := func(n int) bool {
_, ok := names[n]
return ok
}

typeMap := map[string]map[string]map[string]*catalog.Column{}
indexTable := func(table catalog.Table) error {
tables = append(tables, table.Rel)
Expand Down Expand Up @@ -88,19 +93,21 @@ func resolveCatalogRefs(c *catalog.Catalog, qc *QueryCatalog, rvs []*ast.RangeVa
a = append(a, Parameter{
Number: ref.ref.Number,
Column: &Column{
Name: parameterName(ref.ref.Number, "offset"),
DataType: "integer",
NotNull: true,
Name: parameterName(ref.ref.Number, "offset"),
DataType: "integer",
NotNull: true,
IsNamedParam: isNamedParam(ref.ref.Number),
},
})

case *limitCount:
a = append(a, Parameter{
Number: ref.ref.Number,
Column: &Column{
Name: parameterName(ref.ref.Number, "limit"),
DataType: "integer",
NotNull: true,
Name: parameterName(ref.ref.Number, "limit"),
DataType: "integer",
NotNull: true,
IsNamedParam: isNamedParam(ref.ref.Number),
},
})

Expand All @@ -121,8 +128,9 @@ func resolveCatalogRefs(c *catalog.Catalog, qc *QueryCatalog, rvs []*ast.RangeVa
a = append(a, Parameter{
Number: ref.ref.Number,
Column: &Column{
Name: parameterName(ref.ref.Number, ""),
DataType: dataType,
Name: parameterName(ref.ref.Number, ""),
DataType: dataType,
IsNamedParam: isNamedParam(ref.ref.Number),
},
})
continue
Expand Down Expand Up @@ -178,12 +186,13 @@ func resolveCatalogRefs(c *catalog.Catalog, qc *QueryCatalog, rvs []*ast.RangeVa
a = append(a, Parameter{
Number: ref.ref.Number,
Column: &Column{
Name: parameterName(ref.ref.Number, key),
DataType: dataType(&c.Type),
NotNull: c.IsNotNull,
IsArray: c.IsArray,
Length: c.Length,
Table: table,
Name: parameterName(ref.ref.Number, key),
DataType: dataType(&c.Type),
NotNull: c.IsNotNull,
IsArray: c.IsArray,
Length: c.Length,
Table: table,
IsNamedParam: isNamedParam(ref.ref.Number),
},
})
}
Expand Down Expand Up @@ -234,11 +243,12 @@ func resolveCatalogRefs(c *catalog.Catalog, qc *QueryCatalog, rvs []*ast.RangeVa
a = append(a, Parameter{
Number: number,
Column: &Column{
Name: parameterName(ref.ref.Number, key),
DataType: dataType(&c.Type),
NotNull: c.IsNotNull,
IsArray: c.IsArray,
Table: table,
Name: parameterName(ref.ref.Number, key),
DataType: dataType(&c.Type),
NotNull: c.IsNotNull,
IsArray: c.IsArray,
Table: table,
IsNamedParam: isNamedParam(ref.ref.Number),
},
})
}
Expand Down Expand Up @@ -300,8 +310,9 @@ func resolveCatalogRefs(c *catalog.Catalog, qc *QueryCatalog, rvs []*ast.RangeVa
a = append(a, Parameter{
Number: ref.ref.Number,
Column: &Column{
Name: parameterName(ref.ref.Number, defaultName),
DataType: "any",
Name: parameterName(ref.ref.Number, defaultName),
DataType: "any",
IsNamedParam: isNamedParam(ref.ref.Number),
},
})
continue
Expand Down Expand Up @@ -330,9 +341,10 @@ func resolveCatalogRefs(c *catalog.Catalog, qc *QueryCatalog, rvs []*ast.RangeVa
a = append(a, Parameter{
Number: ref.ref.Number,
Column: &Column{
Name: parameterName(ref.ref.Number, paramName),
DataType: dataType(paramType),
NotNull: true,
Name: parameterName(ref.ref.Number, paramName),
DataType: dataType(paramType),
NotNull: true,
IsNamedParam: isNamedParam(ref.ref.Number),
},
})
}
Expand Down Expand Up @@ -388,12 +400,13 @@ func resolveCatalogRefs(c *catalog.Catalog, qc *QueryCatalog, rvs []*ast.RangeVa
a = append(a, Parameter{
Number: ref.ref.Number,
Column: &Column{
Name: parameterName(ref.ref.Number, key),
DataType: dataType(&c.Type),
NotNull: c.IsNotNull,
IsArray: c.IsArray,
Table: &ast.TableName{Schema: schema, Name: rel},
Length: c.Length,
Name: parameterName(ref.ref.Number, key),
DataType: dataType(&c.Type),
NotNull: c.IsNotNull,
IsArray: c.IsArray,
Table: &ast.TableName{Schema: schema, Name: rel},
Length: c.Length,
IsNamedParam: isNamedParam(ref.ref.Number),
},
})
} else {
Expand Down Expand Up @@ -488,11 +501,12 @@ func resolveCatalogRefs(c *catalog.Catalog, qc *QueryCatalog, rvs []*ast.RangeVa
a = append(a, Parameter{
Number: number,
Column: &Column{
Name: parameterName(ref.ref.Number, key),
DataType: dataType(&c.Type),
NotNull: c.IsNotNull,
IsArray: c.IsArray,
Table: table,
Name: parameterName(ref.ref.Number, key),
DataType: dataType(&c.Type),
NotNull: c.IsNotNull,
IsArray: c.IsArray,
Table: table,
IsNamedParam: isNamedParam(ref.ref.Number),
},
})
}
Expand Down
29 changes: 29 additions & 0 deletions internal/endtoend/testdata/case_named_params/mysql/go/db.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 15 additions & 0 deletions internal/endtoend/testdata/case_named_params/mysql/go/models.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading