Skip to content

Commit 9b12e81

Browse files
authored
Support complex import paths in overrides (#786)
* Refactor go_type parsing into function * Add tests for two types of overrides * Handle edge cases * remove println stmt * Fix imports for new-style overrides
1 parent f9fd0b6 commit 9b12e81

File tree

16 files changed

+516
-131
lines changed

16 files changed

+516
-131
lines changed

internal/codegen/golang/gen.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ package {{.Package}}
2626
2727
import (
2828
{{range imports .SourceName}}
29-
{{range .}}"{{.}}"
29+
{{range .}}{{.}}
3030
{{end}}
3131
{{end}}
3232
)
@@ -137,7 +137,7 @@ package {{.Package}}
137137
138138
import (
139139
{{range imports .SourceName}}
140-
{{range .}}"{{.}}"
140+
{{range .}}{{.}}
141141
{{end}}
142142
{{end}}
143143
)
@@ -175,7 +175,7 @@ package {{.Package}}
175175
176176
import (
177177
{{range imports .SourceName}}
178-
{{range .}}"{{.}}"
178+
{{range .}}{{.}}
179179
{{end}}
180180
{{end}}
181181
)
@@ -226,7 +226,7 @@ package {{.Package}}
226226
227227
import (
228228
{{range imports .SourceName}}
229-
{{range .}}"{{.}}"
229+
{{range .}}{{.}}
230230
{{end}}
231231
{{end}}
232232
)

internal/codegen/golang/imports.go

Lines changed: 98 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package golang
22

33
import (
4+
"fmt"
45
"sort"
56
"strings"
67

@@ -9,35 +10,51 @@ import (
910
)
1011

1112
type fileImports struct {
12-
Std []string
13-
Dep []string
13+
Std []ImportSpec
14+
Dep []ImportSpec
1415
}
1516

16-
func mergeImports(imps ...fileImports) [][]string {
17+
type ImportSpec struct {
18+
ID string
19+
Path string
20+
}
21+
22+
func (s ImportSpec) String() string {
23+
if s.ID != "" {
24+
return fmt.Sprintf("%s \"%s\"", s.ID, s.Path)
25+
} else {
26+
return fmt.Sprintf("\"%s\"", s.Path)
27+
}
28+
}
29+
30+
func mergeImports(imps ...fileImports) [][]ImportSpec {
1731
if len(imps) == 1 {
18-
return [][]string{imps[0].Std, imps[0].Dep}
32+
return [][]ImportSpec{
33+
imps[0].Std,
34+
imps[0].Dep,
35+
}
1936
}
2037

21-
var stds, pkgs []string
38+
var stds, pkgs []ImportSpec
2239
seenStd := map[string]struct{}{}
2340
seenPkg := map[string]struct{}{}
2441
for i := range imps {
25-
for _, std := range imps[i].Std {
26-
if _, ok := seenStd[std]; ok {
42+
for _, spec := range imps[i].Std {
43+
if _, ok := seenStd[spec.Path]; ok {
2744
continue
2845
}
29-
stds = append(stds, std)
30-
seenStd[std] = struct{}{}
46+
stds = append(stds, spec)
47+
seenStd[spec.Path] = struct{}{}
3148
}
32-
for _, pkg := range imps[i].Dep {
33-
if _, ok := seenPkg[pkg]; ok {
49+
for _, spec := range imps[i].Dep {
50+
if _, ok := seenPkg[spec.Path]; ok {
3451
continue
3552
}
36-
pkgs = append(pkgs, pkg)
37-
seenPkg[pkg] = struct{}{}
53+
pkgs = append(pkgs, spec)
54+
seenPkg[spec.Path] = struct{}{}
3855
}
3956
}
40-
return [][]string{stds, pkgs}
57+
return [][]ImportSpec{stds, pkgs}
4158
}
4259

4360
type importer struct {
@@ -70,7 +87,7 @@ func (i *importer) usesArrays() bool {
7087
return false
7188
}
7289

73-
func (i *importer) Imports(filename string) [][]string {
90+
func (i *importer) Imports(filename string) [][]ImportSpec {
7491
switch filename {
7592
case "db.go":
7693
return mergeImports(i.dbImports())
@@ -84,9 +101,12 @@ func (i *importer) Imports(filename string) [][]string {
84101
}
85102

86103
func (i *importer) dbImports() fileImports {
87-
std := []string{"context", "database/sql"}
104+
std := []ImportSpec{
105+
{Path: "context"},
106+
{Path: "database/sql"},
107+
}
88108
if i.Settings.Go.EmitPreparedQueries {
89-
std = append(std, "fmt")
109+
std = append(std, ImportSpec{Path: "fmt"})
90110
}
91111
return fileImports{Std: std}
92112
}
@@ -132,43 +152,48 @@ func (i *importer) interfaceImports() fileImports {
132152
std["net"] = struct{}{}
133153
}
134154

135-
pkg := make(map[string]struct{})
155+
pkg := make(map[ImportSpec]struct{})
136156
overrideTypes := map[string]string{}
137157
for _, o := range i.Settings.Overrides {
138158
if o.GoBasicType {
139159
continue
140160
}
141-
overrideTypes[o.GoTypeName] = o.GoPackage
161+
overrideTypes[o.GoTypeName] = o.GoImportPath
142162
}
143163

144164
_, overrideNullTime := overrideTypes["pq.NullTime"]
145165
if uses("pq.NullTime") && !overrideNullTime {
146-
pkg["github.com/lib/pq"] = struct{}{}
166+
pkg[ImportSpec{Path: "github.com/lib/pq"}] = struct{}{}
147167
}
148168
_, overrideUUID := overrideTypes["uuid.UUID"]
149169
if uses("uuid.UUID") && !overrideUUID {
150-
pkg["github.com/google/uuid"] = struct{}{}
170+
pkg[ImportSpec{Path: "github.com/google/uuid"}] = struct{}{}
151171
}
152172

153173
// Custom imports
154-
for goType, importPath := range overrideTypes {
155-
if _, ok := std[importPath]; !ok && uses(goType) {
156-
pkg[importPath] = struct{}{}
174+
for _, o := range i.Settings.Overrides {
175+
if o.GoBasicType {
176+
continue
177+
}
178+
_, alreadyImported := std[o.GoImportPath]
179+
hasPackageAlias := o.GoPackage != ""
180+
if (!alreadyImported || hasPackageAlias) && uses(o.GoTypeName) {
181+
pkg[ImportSpec{Path: o.GoImportPath, ID: o.GoPackage}] = struct{}{}
157182
}
158183
}
159184

160-
pkgs := make([]string, 0, len(pkg))
161-
for p, _ := range pkg {
162-
pkgs = append(pkgs, p)
185+
pkgs := make([]ImportSpec, 0, len(pkg))
186+
for spec, _ := range pkg {
187+
pkgs = append(pkgs, spec)
163188
}
164189

165-
stds := make([]string, 0, len(std))
166-
for s, _ := range std {
167-
stds = append(stds, s)
190+
stds := make([]ImportSpec, 0, len(std))
191+
for path, _ := range std {
192+
stds = append(stds, ImportSpec{Path: path})
168193
}
169194

170-
sort.Strings(stds)
171-
sort.Strings(pkgs)
195+
sort.Slice(stds, func(i, j int) bool { return stds[i].Path < stds[j].Path })
196+
sort.Slice(pkgs, func(i, j int) bool { return pkgs[i].Path < pkgs[j].Path })
172197
return fileImports{stds, pkgs}
173198
}
174199

@@ -194,43 +219,48 @@ func (i *importer) modelImports() fileImports {
194219
}
195220

196221
// Custom imports
197-
pkg := make(map[string]struct{})
222+
pkg := make(map[ImportSpec]struct{})
198223
overrideTypes := map[string]string{}
199224
for _, o := range i.Settings.Overrides {
200225
if o.GoBasicType {
201226
continue
202227
}
203-
overrideTypes[o.GoTypeName] = o.GoPackage
228+
overrideTypes[o.GoTypeName] = o.GoImportPath
204229
}
205230

206231
_, overrideNullTime := overrideTypes["pq.NullTime"]
207232
if i.usesType("pq.NullTime") && !overrideNullTime {
208-
pkg["github.com/lib/pq"] = struct{}{}
233+
pkg[ImportSpec{Path: "github.com/lib/pq"}] = struct{}{}
209234
}
210235

211236
_, overrideUUID := overrideTypes["uuid.UUID"]
212237
if i.usesType("uuid.UUID") && !overrideUUID {
213-
pkg["github.com/google/uuid"] = struct{}{}
238+
pkg[ImportSpec{Path: "github.com/google/uuid"}] = struct{}{}
214239
}
215240

216-
for goType, importPath := range overrideTypes {
217-
if _, ok := std[importPath]; !ok && i.usesType(goType) {
218-
pkg[importPath] = struct{}{}
241+
for _, o := range i.Settings.Overrides {
242+
if o.GoBasicType {
243+
continue
244+
}
245+
_, alreadyImported := std[o.GoImportPath]
246+
hasPackageAlias := o.GoPackage != ""
247+
if (!alreadyImported || hasPackageAlias) && i.usesType(o.GoTypeName) {
248+
pkg[ImportSpec{Path: o.GoImportPath, ID: o.GoPackage}] = struct{}{}
219249
}
220250
}
221251

222-
pkgs := make([]string, 0, len(pkg))
223-
for p, _ := range pkg {
224-
pkgs = append(pkgs, p)
252+
pkgs := make([]ImportSpec, 0, len(pkg))
253+
for spec, _ := range pkg {
254+
pkgs = append(pkgs, spec)
225255
}
226256

227-
stds := make([]string, 0, len(std))
228-
for s, _ := range std {
229-
stds = append(stds, s)
257+
stds := make([]ImportSpec, 0, len(std))
258+
for path, _ := range std {
259+
stds = append(stds, ImportSpec{Path: path})
230260
}
231261

232-
sort.Strings(stds)
233-
sort.Strings(pkgs)
262+
sort.Slice(stds, func(i, j int) bool { return stds[i].Path < stds[j].Path })
263+
sort.Slice(pkgs, func(i, j int) bool { return pkgs[i].Path < pkgs[j].Path })
234264
return fileImports{stds, pkgs}
235265
}
236266

@@ -327,45 +357,50 @@ func (i *importer) queryImports(filename string) fileImports {
327357
std["net"] = struct{}{}
328358
}
329359

330-
pkg := make(map[string]struct{})
360+
pkg := make(map[ImportSpec]struct{})
331361
overrideTypes := map[string]string{}
332362
for _, o := range i.Settings.Overrides {
333363
if o.GoBasicType {
334364
continue
335365
}
336-
overrideTypes[o.GoTypeName] = o.GoPackage
366+
overrideTypes[o.GoTypeName] = o.GoImportPath
337367
}
338368

339369
if sliceScan() {
340-
pkg["github.com/lib/pq"] = struct{}{}
370+
pkg[ImportSpec{Path: "github.com/lib/pq"}] = struct{}{}
341371
}
342372
_, overrideNullTime := overrideTypes["pq.NullTime"]
343373
if uses("pq.NullTime") && !overrideNullTime {
344-
pkg["github.com/lib/pq"] = struct{}{}
374+
pkg[ImportSpec{Path: "github.com/lib/pq"}] = struct{}{}
345375
}
346376
_, overrideUUID := overrideTypes["uuid.UUID"]
347377
if uses("uuid.UUID") && !overrideUUID {
348-
pkg["github.com/google/uuid"] = struct{}{}
378+
pkg[ImportSpec{Path: "github.com/google/uuid"}] = struct{}{}
349379
}
350380

351381
// Custom imports
352-
for goType, importPath := range overrideTypes {
353-
if _, ok := std[importPath]; !ok && uses(goType) {
354-
pkg[importPath] = struct{}{}
382+
for _, o := range i.Settings.Overrides {
383+
if o.GoBasicType {
384+
continue
385+
}
386+
_, alreadyImported := std[o.GoImportPath]
387+
hasPackageAlias := o.GoPackage != ""
388+
if (!alreadyImported || hasPackageAlias) && uses(o.GoTypeName) {
389+
pkg[ImportSpec{Path: o.GoImportPath, ID: o.GoPackage}] = struct{}{}
355390
}
356391
}
357392

358-
pkgs := make([]string, 0, len(pkg))
359-
for p, _ := range pkg {
360-
pkgs = append(pkgs, p)
393+
pkgs := make([]ImportSpec, 0, len(pkg))
394+
for spec, _ := range pkg {
395+
pkgs = append(pkgs, spec)
361396
}
362397

363-
stds := make([]string, 0, len(std))
364-
for s, _ := range std {
365-
stds = append(stds, s)
398+
stds := make([]ImportSpec, 0, len(std))
399+
for path, _ := range std {
400+
stds = append(stds, ImportSpec{Path: path})
366401
}
367402

368-
sort.Strings(stds)
369-
sort.Strings(pkgs)
403+
sort.Slice(stds, func(i, j int) bool { return stds[i].Path < stds[j].Path })
404+
sort.Slice(pkgs, func(i, j int) bool { return pkgs[i].Path < pkgs[j].Path })
370405
return fileImports{stds, pkgs}
371406
}

0 commit comments

Comments
 (0)