Skip to content

Commit 8ae6b48

Browse files
authored
codegen/golang: Make sure to import net package (#858)
Add a map of known types to always check
1 parent cbfad2a commit 8ae6b48

File tree

3 files changed

+50
-30
lines changed

3 files changed

+50
-30
lines changed

internal/codegen/golang/imports.go

Lines changed: 19 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,13 @@ func (i *importer) dbImports() fileImports {
111111
return fileImports{Std: std}
112112
}
113113

114+
var stdlibTypes = map[string]string{
115+
"json.RawMessage": "encoding/json",
116+
"time.Time": "time",
117+
"net.IP": "net",
118+
"net.HardwareAddr": "net",
119+
}
120+
114121
func (i *importer) interfaceImports() fileImports {
115122
uses := func(name string) bool {
116123
for _, q := range i.Queries {
@@ -139,17 +146,10 @@ func (i *importer) interfaceImports() fileImports {
139146
std["database/sql"] = struct{}{}
140147
}
141148
}
142-
if uses("json.RawMessage") {
143-
std["encoding/json"] = struct{}{}
144-
}
145-
if uses("time.Time") {
146-
std["time"] = struct{}{}
147-
}
148-
if uses("net.IP") {
149-
std["net"] = struct{}{}
150-
}
151-
if uses("net.HardwareAddr") {
152-
std["net"] = struct{}{}
149+
for typeName, pkg := range stdlibTypes {
150+
if uses(typeName) {
151+
std[pkg] = struct{}{}
152+
}
153153
}
154154

155155
pkg := make(map[ImportSpec]struct{})
@@ -202,17 +202,10 @@ func (i *importer) modelImports() fileImports {
202202
if i.usesType("sql.Null") {
203203
std["database/sql"] = struct{}{}
204204
}
205-
if i.usesType("json.RawMessage") {
206-
std["encoding/json"] = struct{}{}
207-
}
208-
if i.usesType("time.Time") {
209-
std["time"] = struct{}{}
210-
}
211-
if i.usesType("net.IP") {
212-
std["net"] = struct{}{}
213-
}
214-
if i.usesType("net.HardwareAddr") {
215-
std["net"] = struct{}{}
205+
for typeName, pkg := range stdlibTypes {
206+
if i.usesType(typeName) {
207+
std[pkg] = struct{}{}
208+
}
216209
}
217210
if len(i.Enums) > 0 {
218211
std["fmt"] = struct{}{}
@@ -347,14 +340,10 @@ func (i *importer) queryImports(filename string) fileImports {
347340
std["database/sql"] = struct{}{}
348341
}
349342
}
350-
if uses("json.RawMessage") {
351-
std["encoding/json"] = struct{}{}
352-
}
353-
if uses("time.Time") {
354-
std["time"] = struct{}{}
355-
}
356-
if uses("net.IP") {
357-
std["net"] = struct{}{}
343+
for typeName, pkg := range stdlibTypes {
344+
if uses(typeName) {
345+
std[pkg] = struct{}{}
346+
}
358347
}
359348

360349
pkg := make(map[ImportSpec]struct{})

internal/endtoend/testdata/macaddr/go/query.sql.go

Lines changed: 28 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

internal/endtoend/testdata/macaddr/query.sql

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,6 @@ CREATE TABLE foo (bar bool not null, addr macaddr not null);
22

33
-- name: Get :many
44
SELECT bar, addr FROM foo LIMIT $1;
5+
6+
-- name: GetAddr :many
7+
SELECT addr FROM foo LIMIT $1;

0 commit comments

Comments
 (0)