diff --git a/internal/codegen/golang/imports.go b/internal/codegen/golang/imports.go index 63900ba763..8ff51eb2c6 100644 --- a/internal/codegen/golang/imports.go +++ b/internal/codegen/golang/imports.go @@ -111,6 +111,13 @@ func (i *importer) dbImports() fileImports { return fileImports{Std: std} } +var stdlibTypes = map[string]string{ + "json.RawMessage": "encoding/json", + "time.Time": "time", + "net.IP": "net", + "net.HardwareAddr": "net", +} + func (i *importer) interfaceImports() fileImports { uses := func(name string) bool { for _, q := range i.Queries { @@ -139,17 +146,10 @@ func (i *importer) interfaceImports() fileImports { std["database/sql"] = struct{}{} } } - if uses("json.RawMessage") { - std["encoding/json"] = struct{}{} - } - if uses("time.Time") { - std["time"] = struct{}{} - } - if uses("net.IP") { - std["net"] = struct{}{} - } - if uses("net.HardwareAddr") { - std["net"] = struct{}{} + for typeName, pkg := range stdlibTypes { + if uses(typeName) { + std[pkg] = struct{}{} + } } pkg := make(map[ImportSpec]struct{}) @@ -202,17 +202,10 @@ func (i *importer) modelImports() fileImports { if i.usesType("sql.Null") { std["database/sql"] = struct{}{} } - if i.usesType("json.RawMessage") { - std["encoding/json"] = struct{}{} - } - if i.usesType("time.Time") { - std["time"] = struct{}{} - } - if i.usesType("net.IP") { - std["net"] = struct{}{} - } - if i.usesType("net.HardwareAddr") { - std["net"] = struct{}{} + for typeName, pkg := range stdlibTypes { + if i.usesType(typeName) { + std[pkg] = struct{}{} + } } if len(i.Enums) > 0 { std["fmt"] = struct{}{} @@ -347,14 +340,10 @@ func (i *importer) queryImports(filename string) fileImports { std["database/sql"] = struct{}{} } } - if uses("json.RawMessage") { - std["encoding/json"] = struct{}{} - } - if uses("time.Time") { - std["time"] = struct{}{} - } - if uses("net.IP") { - std["net"] = struct{}{} + for typeName, pkg := range stdlibTypes { + if uses(typeName) { + std[pkg] = struct{}{} + } } pkg := make(map[ImportSpec]struct{}) diff --git a/internal/endtoend/testdata/macaddr/go/query.sql.go b/internal/endtoend/testdata/macaddr/go/query.sql.go index a2adf75b26..ce14bdd56d 100644 --- a/internal/endtoend/testdata/macaddr/go/query.sql.go +++ b/internal/endtoend/testdata/macaddr/go/query.sql.go @@ -5,6 +5,7 @@ package querytest import ( "context" + "net" ) const get = `-- name: Get :many @@ -33,3 +34,30 @@ func (q *Queries) Get(ctx context.Context, limit int32) ([]Foo, error) { } return items, nil } + +const getAddr = `-- name: GetAddr :many +SELECT addr FROM foo LIMIT $1 +` + +func (q *Queries) GetAddr(ctx context.Context, limit int32) ([]net.HardwareAddr, error) { + rows, err := q.db.QueryContext(ctx, getAddr, limit) + if err != nil { + return nil, err + } + defer rows.Close() + var items []net.HardwareAddr + for rows.Next() { + var addr net.HardwareAddr + if err := rows.Scan(&addr); err != nil { + return nil, err + } + items = append(items, addr) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/macaddr/query.sql b/internal/endtoend/testdata/macaddr/query.sql index f9edc95598..4b0fa2203e 100644 --- a/internal/endtoend/testdata/macaddr/query.sql +++ b/internal/endtoend/testdata/macaddr/query.sql @@ -2,3 +2,6 @@ CREATE TABLE foo (bar bool not null, addr macaddr not null); -- name: Get :many SELECT bar, addr FROM foo LIMIT $1; + +-- name: GetAddr :many +SELECT addr FROM foo LIMIT $1;