Skip to content

Commit b83f8da

Browse files
authored
fix(codegen/golang): Refactor imports code to match templates (#2709)
1 parent 3856ef3 commit b83f8da

File tree

6 files changed

+79
-31
lines changed

6 files changed

+79
-31
lines changed

internal/codegen/golang/gen.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,7 @@ func usesBatch(queries []Query) bool {
303303

304304
func checkNoTimesForMySQLCopyFrom(queries []Query) error {
305305
for _, q := range queries {
306-
for _, f := range q.Arg.Fields() {
306+
for _, f := range q.Arg.CopyFromMySQLFields() {
307307
if f.Type == "time.Time" {
308308
return fmt.Errorf("values with a timezone are not yet supported")
309309
}

internal/codegen/golang/imports.go

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -242,11 +242,9 @@ func (i *importer) interfaceImports() fileImports {
242242
return true
243243
}
244244
}
245-
if !q.Arg.isEmpty() {
246-
for _, f := range q.Arg.Fields() {
247-
if hasPrefixIgnoringSliceAndPointerPrefix(f.Type, name) {
248-
return true
249-
}
245+
for _, f := range q.Arg.Pairs() {
246+
if hasPrefixIgnoringSliceAndPointerPrefix(f.Type, name) {
247+
return true
250248
}
251249
}
252250
}
@@ -312,13 +310,20 @@ func (i *importer) queryImports(filename string) fileImports {
312310
return true
313311
}
314312
}
315-
if !q.Arg.isEmpty() {
316-
for _, f := range q.Arg.Fields() {
313+
// Check the fields of the argument struct if it's emitted
314+
if q.Arg.EmitStruct() {
315+
for _, f := range q.Arg.Struct.Fields {
317316
if hasPrefixIgnoringSliceAndPointerPrefix(f.Type, name) {
318317
return true
319318
}
320319
}
321320
}
321+
// Check the argument pairs inside the method definition
322+
for _, f := range q.Arg.Pairs() {
323+
if hasPrefixIgnoringSliceAndPointerPrefix(f.Type, name) {
324+
return true
325+
}
326+
}
322327
}
323328
return false
324329
})
@@ -441,15 +446,15 @@ func (i *importer) batchImports() fileImports {
441446
return true
442447
}
443448
}
444-
if !q.Arg.isEmpty() {
445-
if q.Arg.EmitStruct() {
446-
for _, f := range q.Arg.Struct.Fields {
447-
if hasPrefixIgnoringSliceAndPointerPrefix(f.Type, name) {
448-
return true
449-
}
449+
if q.Arg.EmitStruct() {
450+
for _, f := range q.Arg.Struct.Fields {
451+
if hasPrefixIgnoringSliceAndPointerPrefix(f.Type, name) {
452+
return true
450453
}
451454
}
452-
if hasPrefixIgnoringSliceAndPointerPrefix(q.Arg.Type(), name) {
455+
}
456+
for _, f := range q.Arg.Pairs() {
457+
if hasPrefixIgnoringSliceAndPointerPrefix(f.Type, name) {
453458
return true
454459
}
455460
}

internal/codegen/golang/query.go

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -38,21 +38,41 @@ func (v QueryValue) isEmpty() bool {
3838
return v.Typ == "" && v.Name == "" && v.Struct == nil
3939
}
4040

41+
type Argument struct {
42+
Name string
43+
Type string
44+
}
45+
4146
func (v QueryValue) Pair() string {
42-
if v.isEmpty() {
43-
return ""
47+
var out []string
48+
for _, arg := range v.Pairs() {
49+
out = append(out, arg.Name+" "+arg.Type)
4450
}
51+
return strings.Join(out, ",")
52+
}
4553

46-
var out []string
54+
// Return the argument name and type for query methods. Should only be used in
55+
// the context of method arguments.
56+
func (v QueryValue) Pairs() []Argument {
57+
if v.isEmpty() {
58+
return nil
59+
}
4760
if !v.EmitStruct() && v.IsStruct() {
61+
var out []Argument
4862
for _, f := range v.Struct.Fields {
49-
out = append(out, toLowerCase(f.Name)+" "+f.Type)
63+
out = append(out, Argument{
64+
Name: toLowerCase(f.Name),
65+
Type: f.Type,
66+
})
5067
}
51-
52-
return strings.Join(out, ",")
68+
return out
69+
}
70+
return []Argument{
71+
{
72+
Name: v.Name,
73+
Type: v.DefineType(),
74+
},
5375
}
54-
55-
return v.Name + " " + v.DefineType()
5676
}
5777

5878
func (v QueryValue) SlicePair() string {
@@ -202,7 +222,11 @@ func (v QueryValue) Scan() string {
202222
return "\n" + strings.Join(out, ",\n")
203223
}
204224

205-
func (v QueryValue) Fields() []Field {
225+
// Deprecated: This method does not respect the Emit field set on the
226+
// QueryValue. It's used by the go-sql-driver-mysql/copyfromCopy.tmpl and should
227+
// not be used other places.
228+
func (v QueryValue) CopyFromMySQLFields() []Field {
229+
// fmt.Printf("%#v\n", v)
206230
if v.Struct != nil {
207231
return v.Struct.Fields
208232
}

internal/codegen/golang/templates/go-sql-driver-mysql/copyfromCopy.tmpl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,16 @@
44
var readerHandlerSequenceFor{{.MethodName}} uint32 = 1
55

66
func convertRowsFor{{.MethodName}}(w *io.PipeWriter, {{.Arg.SlicePair}}) {
7-
e := mysqltsv.NewEncoder(w, {{ len .Arg.Fields }}, nil)
7+
e := mysqltsv.NewEncoder(w, {{ len .Arg.CopyFromMySQLFields }}, nil)
88
for _, row := range {{.Arg.Name}} {
99
{{- with $arg := .Arg }}
10-
{{- range $arg.Fields}}
10+
{{- range $arg.CopyFromMySQLFields}}
1111
{{- if eq .Type "string"}}
12-
e.AppendString({{if eq (len $arg.Fields) 1}}row{{else}}row.{{.Name}}{{end}})
12+
e.AppendString({{if eq (len $arg.CopyFromMySQLFields) 1}}row{{else}}row.{{.Name}}{{end}})
1313
{{- else if eq .Type "[]byte"}}
14-
e.AppendBytes({{if eq (len $arg.Fields) 1}}row{{else}}row.{{.Name}}{{end}})
14+
e.AppendBytes({{if eq (len $arg.CopyFromMySQLFields) 1}}row{{else}}row.{{.Name}}{{end}})
1515
{{- else}}
16-
e.AppendValue({{if eq (len $arg.Fields) 1}}row{{else}}row.{{.Name}}{{end}})
16+
e.AppendValue({{if eq (len $arg.CopyFromMySQLFields) 1}}row{{else}}row.{{.Name}}{{end}})
1717
{{- end}}
1818
{{- end}}
1919
{{- end}}

internal/endtoend/testdata/query_parameter_limit_to_zero/postgresql/go/querier.go

Lines changed: 18 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

internal/endtoend/testdata/query_parameter_limit_to_zero/postgresql/sqlc.json

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@
77
"name": "querytest",
88
"schema": "query.sql",
99
"queries": "query.sql",
10-
"query_parameter_limit": 0
10+
"query_parameter_limit": 0,
11+
"emit_interface": true
1112
}
1213
]
1314
}
14-
15+

0 commit comments

Comments
 (0)