diff --git a/internal/codegen/golang/gen.go b/internal/codegen/golang/gen.go index 81ea486f5c..4aa38d8cb4 100644 --- a/internal/codegen/golang/gen.go +++ b/internal/codegen/golang/gen.go @@ -26,7 +26,7 @@ package {{.Package}} import ( {{range imports .SourceName}} - {{range .}}"{{.}}" + {{range .}}{{.}} {{end}} {{end}} ) @@ -137,7 +137,7 @@ package {{.Package}} import ( {{range imports .SourceName}} - {{range .}}"{{.}}" + {{range .}}{{.}} {{end}} {{end}} ) @@ -175,7 +175,7 @@ package {{.Package}} import ( {{range imports .SourceName}} - {{range .}}"{{.}}" + {{range .}}{{.}} {{end}} {{end}} ) @@ -226,7 +226,7 @@ package {{.Package}} import ( {{range imports .SourceName}} - {{range .}}"{{.}}" + {{range .}}{{.}} {{end}} {{end}} ) diff --git a/internal/codegen/golang/imports.go b/internal/codegen/golang/imports.go index 188dc668f5..63900ba763 100644 --- a/internal/codegen/golang/imports.go +++ b/internal/codegen/golang/imports.go @@ -1,6 +1,7 @@ package golang import ( + "fmt" "sort" "strings" @@ -9,35 +10,51 @@ import ( ) type fileImports struct { - Std []string - Dep []string + Std []ImportSpec + Dep []ImportSpec } -func mergeImports(imps ...fileImports) [][]string { +type ImportSpec struct { + ID string + Path string +} + +func (s ImportSpec) String() string { + if s.ID != "" { + return fmt.Sprintf("%s \"%s\"", s.ID, s.Path) + } else { + return fmt.Sprintf("\"%s\"", s.Path) + } +} + +func mergeImports(imps ...fileImports) [][]ImportSpec { if len(imps) == 1 { - return [][]string{imps[0].Std, imps[0].Dep} + return [][]ImportSpec{ + imps[0].Std, + imps[0].Dep, + } } - var stds, pkgs []string + var stds, pkgs []ImportSpec seenStd := map[string]struct{}{} seenPkg := map[string]struct{}{} for i := range imps { - for _, std := range imps[i].Std { - if _, ok := seenStd[std]; ok { + for _, spec := range imps[i].Std { + if _, ok := seenStd[spec.Path]; ok { continue } - stds = append(stds, std) - seenStd[std] = struct{}{} + stds = append(stds, spec) + seenStd[spec.Path] = struct{}{} } - for _, pkg := range imps[i].Dep { - if _, ok := seenPkg[pkg]; ok { + for _, spec := range imps[i].Dep { + if _, ok := seenPkg[spec.Path]; ok { continue } - pkgs = append(pkgs, pkg) - seenPkg[pkg] = struct{}{} + pkgs = append(pkgs, spec) + seenPkg[spec.Path] = struct{}{} } } - return [][]string{stds, pkgs} + return [][]ImportSpec{stds, pkgs} } type importer struct { @@ -70,7 +87,7 @@ func (i *importer) usesArrays() bool { return false } -func (i *importer) Imports(filename string) [][]string { +func (i *importer) Imports(filename string) [][]ImportSpec { switch filename { case "db.go": return mergeImports(i.dbImports()) @@ -84,9 +101,12 @@ func (i *importer) Imports(filename string) [][]string { } func (i *importer) dbImports() fileImports { - std := []string{"context", "database/sql"} + std := []ImportSpec{ + {Path: "context"}, + {Path: "database/sql"}, + } if i.Settings.Go.EmitPreparedQueries { - std = append(std, "fmt") + std = append(std, ImportSpec{Path: "fmt"}) } return fileImports{Std: std} } @@ -132,43 +152,48 @@ func (i *importer) interfaceImports() fileImports { std["net"] = struct{}{} } - pkg := make(map[string]struct{}) + pkg := make(map[ImportSpec]struct{}) overrideTypes := map[string]string{} for _, o := range i.Settings.Overrides { if o.GoBasicType { continue } - overrideTypes[o.GoTypeName] = o.GoPackage + overrideTypes[o.GoTypeName] = o.GoImportPath } _, overrideNullTime := overrideTypes["pq.NullTime"] if uses("pq.NullTime") && !overrideNullTime { - pkg["github.com/lib/pq"] = struct{}{} + pkg[ImportSpec{Path: "github.com/lib/pq"}] = struct{}{} } _, overrideUUID := overrideTypes["uuid.UUID"] if uses("uuid.UUID") && !overrideUUID { - pkg["github.com/google/uuid"] = struct{}{} + pkg[ImportSpec{Path: "github.com/google/uuid"}] = struct{}{} } // Custom imports - for goType, importPath := range overrideTypes { - if _, ok := std[importPath]; !ok && uses(goType) { - pkg[importPath] = struct{}{} + for _, o := range i.Settings.Overrides { + if o.GoBasicType { + continue + } + _, alreadyImported := std[o.GoImportPath] + hasPackageAlias := o.GoPackage != "" + if (!alreadyImported || hasPackageAlias) && uses(o.GoTypeName) { + pkg[ImportSpec{Path: o.GoImportPath, ID: o.GoPackage}] = struct{}{} } } - pkgs := make([]string, 0, len(pkg)) - for p, _ := range pkg { - pkgs = append(pkgs, p) + pkgs := make([]ImportSpec, 0, len(pkg)) + for spec, _ := range pkg { + pkgs = append(pkgs, spec) } - stds := make([]string, 0, len(std)) - for s, _ := range std { - stds = append(stds, s) + stds := make([]ImportSpec, 0, len(std)) + for path, _ := range std { + stds = append(stds, ImportSpec{Path: path}) } - sort.Strings(stds) - sort.Strings(pkgs) + sort.Slice(stds, func(i, j int) bool { return stds[i].Path < stds[j].Path }) + sort.Slice(pkgs, func(i, j int) bool { return pkgs[i].Path < pkgs[j].Path }) return fileImports{stds, pkgs} } @@ -194,43 +219,48 @@ func (i *importer) modelImports() fileImports { } // Custom imports - pkg := make(map[string]struct{}) + pkg := make(map[ImportSpec]struct{}) overrideTypes := map[string]string{} for _, o := range i.Settings.Overrides { if o.GoBasicType { continue } - overrideTypes[o.GoTypeName] = o.GoPackage + overrideTypes[o.GoTypeName] = o.GoImportPath } _, overrideNullTime := overrideTypes["pq.NullTime"] if i.usesType("pq.NullTime") && !overrideNullTime { - pkg["github.com/lib/pq"] = struct{}{} + pkg[ImportSpec{Path: "github.com/lib/pq"}] = struct{}{} } _, overrideUUID := overrideTypes["uuid.UUID"] if i.usesType("uuid.UUID") && !overrideUUID { - pkg["github.com/google/uuid"] = struct{}{} + pkg[ImportSpec{Path: "github.com/google/uuid"}] = struct{}{} } - for goType, importPath := range overrideTypes { - if _, ok := std[importPath]; !ok && i.usesType(goType) { - pkg[importPath] = struct{}{} + for _, o := range i.Settings.Overrides { + if o.GoBasicType { + continue + } + _, alreadyImported := std[o.GoImportPath] + hasPackageAlias := o.GoPackage != "" + if (!alreadyImported || hasPackageAlias) && i.usesType(o.GoTypeName) { + pkg[ImportSpec{Path: o.GoImportPath, ID: o.GoPackage}] = struct{}{} } } - pkgs := make([]string, 0, len(pkg)) - for p, _ := range pkg { - pkgs = append(pkgs, p) + pkgs := make([]ImportSpec, 0, len(pkg)) + for spec, _ := range pkg { + pkgs = append(pkgs, spec) } - stds := make([]string, 0, len(std)) - for s, _ := range std { - stds = append(stds, s) + stds := make([]ImportSpec, 0, len(std)) + for path, _ := range std { + stds = append(stds, ImportSpec{Path: path}) } - sort.Strings(stds) - sort.Strings(pkgs) + sort.Slice(stds, func(i, j int) bool { return stds[i].Path < stds[j].Path }) + sort.Slice(pkgs, func(i, j int) bool { return pkgs[i].Path < pkgs[j].Path }) return fileImports{stds, pkgs} } @@ -327,45 +357,50 @@ func (i *importer) queryImports(filename string) fileImports { std["net"] = struct{}{} } - pkg := make(map[string]struct{}) + pkg := make(map[ImportSpec]struct{}) overrideTypes := map[string]string{} for _, o := range i.Settings.Overrides { if o.GoBasicType { continue } - overrideTypes[o.GoTypeName] = o.GoPackage + overrideTypes[o.GoTypeName] = o.GoImportPath } if sliceScan() { - pkg["github.com/lib/pq"] = struct{}{} + pkg[ImportSpec{Path: "github.com/lib/pq"}] = struct{}{} } _, overrideNullTime := overrideTypes["pq.NullTime"] if uses("pq.NullTime") && !overrideNullTime { - pkg["github.com/lib/pq"] = struct{}{} + pkg[ImportSpec{Path: "github.com/lib/pq"}] = struct{}{} } _, overrideUUID := overrideTypes["uuid.UUID"] if uses("uuid.UUID") && !overrideUUID { - pkg["github.com/google/uuid"] = struct{}{} + pkg[ImportSpec{Path: "github.com/google/uuid"}] = struct{}{} } // Custom imports - for goType, importPath := range overrideTypes { - if _, ok := std[importPath]; !ok && uses(goType) { - pkg[importPath] = struct{}{} + for _, o := range i.Settings.Overrides { + if o.GoBasicType { + continue + } + _, alreadyImported := std[o.GoImportPath] + hasPackageAlias := o.GoPackage != "" + if (!alreadyImported || hasPackageAlias) && uses(o.GoTypeName) { + pkg[ImportSpec{Path: o.GoImportPath, ID: o.GoPackage}] = struct{}{} } } - pkgs := make([]string, 0, len(pkg)) - for p, _ := range pkg { - pkgs = append(pkgs, p) + pkgs := make([]ImportSpec, 0, len(pkg)) + for spec, _ := range pkg { + pkgs = append(pkgs, spec) } - stds := make([]string, 0, len(std)) - for s, _ := range std { - stds = append(stds, s) + stds := make([]ImportSpec, 0, len(std)) + for path, _ := range std { + stds = append(stds, ImportSpec{Path: path}) } - sort.Strings(stds) - sort.Strings(pkgs) + sort.Slice(stds, func(i, j int) bool { return stds[i].Path < stds[j].Path }) + sort.Slice(pkgs, func(i, j int) bool { return pkgs[i].Path < pkgs[j].Path }) return fileImports{stds, pkgs} } diff --git a/internal/config/config.go b/internal/config/config.go index d2e7c2f987..f3cef93c33 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -5,7 +5,6 @@ import ( "encoding/json" "errors" "fmt" - "go/types" "io" "os" "strings" @@ -131,7 +130,7 @@ type SQLKotlin struct { type Override struct { // name of the golang type to use, e.g. `github.com/segmentio/ksuid.KSUID` - GoType string `json:"go_type" yaml:"go_type"` + GoType GoType `json:"go_type" yaml:"go_type"` // fully qualified name of the Go type, e.g. `github.com/segmentio/ksuid.KSUID` DBType string `json:"db_type" yaml:"db_type"` @@ -148,11 +147,12 @@ type Override struct { // fully qualified name of the column, e.g. `accounts.id` Column string `json:"column" yaml:"column"` - ColumnName string - Table core.FQN - GoTypeName string - GoPackage string - GoBasicType bool + ColumnName string + Table core.FQN + GoImportPath string + GoPackage string + GoTypeName string + GoBasicType bool } func (o *Override) Parse() error { @@ -199,55 +199,14 @@ func (o *Override) Parse() error { } // validate GoType - lastDot := strings.LastIndex(o.GoType, ".") - lastSlash := strings.LastIndex(o.GoType, "/") - typename := o.GoType - if lastDot == -1 && lastSlash == -1 { - // if the type name has no slash and no dot, validate that the type is a basic Go type - var found bool - for _, typ := range types.Typ { - info := typ.Info() - if info == 0 { - continue - } - if info&types.IsUntyped != 0 { - continue - } - if typename == typ.Name() { - found = true - } - } - if !found { - return fmt.Errorf("Package override `go_type` specifier %q is not a Go basic type e.g. 'string'", o.GoType) - } - o.GoBasicType = true - } else { - // assume the type lives in a Go package - if lastDot == -1 { - return fmt.Errorf("Package override `go_type` specifier %q is not the proper format, expected 'package.type', e.g. 'github.com/segmentio/ksuid.KSUID'", o.GoType) - } - if lastSlash == -1 { - return fmt.Errorf("Package override `go_type` specifier %q is not the proper format, expected 'package.type', e.g. 'github.com/segmentio/ksuid.KSUID'", o.GoType) - } - typename = o.GoType[lastSlash+1:] - if strings.HasPrefix(typename, "go-") { - // a package name beginning with "go-" will give syntax errors in - // generated code. We should do the right thing and get the actual - // import name, but in lieu of that, stripping the leading "go-" may get - // us what we want. - typename = typename[len("go-"):] - } - if strings.HasSuffix(typename, "-go") { - typename = typename[:len(typename)-len("-go")] - } - o.GoPackage = o.GoType[:lastDot] - } - o.GoTypeName = typename - isPointer := o.GoType[0] == '*' - if isPointer { - o.GoPackage = o.GoPackage[1:] - o.GoTypeName = "*" + o.GoTypeName + parsed, err := o.GoType.Parse() + if err != nil { + return err } + o.GoImportPath = parsed.ImportPath + o.GoPackage = parsed.Package + o.GoTypeName = parsed.TypeName + o.GoBasicType = parsed.BasicType return nil } diff --git a/internal/config/config_test.go b/internal/config/config_test.go index ef12a178b4..3249d93398 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -74,7 +74,7 @@ func TestTypeOverrides(t *testing.T) { { Override{ DBType: "uuid", - GoType: "github.com/segmentio/ksuid.KSUID", + GoType: GoType{Spec: "github.com/segmentio/ksuid.KSUID"}, }, "github.com/segmentio/ksuid", "ksuid.KSUID", @@ -94,7 +94,7 @@ func TestTypeOverrides(t *testing.T) { { Override{ DBType: "citext", - GoType: "string", + GoType: GoType{Spec: "string"}, }, "", "string", @@ -102,16 +102,16 @@ func TestTypeOverrides(t *testing.T) { }, } { tt := test - t.Run(tt.override.GoType, func(t *testing.T) { + t.Run(tt.override.GoType.Spec, func(t *testing.T) { if err := tt.override.Parse(); err != nil { t.Fatalf("override parsing failed; %s", err) } + if diff := cmp.Diff(tt.pkg, tt.override.GoImportPath); diff != "" { + t.Errorf("package mismatch;\n%s", diff) + } if diff := cmp.Diff(tt.typeName, tt.override.GoTypeName); diff != "" { t.Errorf("type name mismatch;\n%s", diff) } - if diff := cmp.Diff(tt.pkg, tt.override.GoPackage); diff != "" { - t.Errorf("package mismatch;\n%s", diff) - } if diff := cmp.Diff(tt.basic, tt.override.GoBasicType); diff != "" { t.Errorf("basic mismatch;\n%s", diff) } @@ -124,20 +124,20 @@ func TestTypeOverrides(t *testing.T) { { Override{ DBType: "uuid", - GoType: "Pointer", + GoType: GoType{Spec: "Pointer"}, }, "Package override `go_type` specifier \"Pointer\" is not a Go basic type e.g. 'string'", }, { Override{ DBType: "uuid", - GoType: "untyped rune", + GoType: GoType{Spec: "untyped rune"}, }, "Package override `go_type` specifier \"untyped rune\" is not a Go basic type e.g. 'string'", }, } { tt := test - t.Run(tt.override.GoType, func(t *testing.T) { + t.Run(tt.override.GoType.Spec, func(t *testing.T) { err := tt.override.Parse() if err == nil { t.Fatalf("expected pars to fail; got nil") diff --git a/internal/config/go_type.go b/internal/config/go_type.go new file mode 100644 index 0000000000..8a2b3fa06a --- /dev/null +++ b/internal/config/go_type.go @@ -0,0 +1,160 @@ +package config + +import ( + "encoding/json" + "fmt" + "go/types" + "regexp" + "strings" +) + +type GoType struct { + Path string `json:"import" yaml:"import"` + Package string `json:"package" yaml:"package"` + Name string `json:"type" yaml:"type"` + Pointer bool `json:"pointer" yaml:"pointer"` + Spec string + BuiltIn bool +} + +type ParsedGoType struct { + ImportPath string + Package string + TypeName string + BasicType bool +} + +func (o *GoType) UnmarshalJSON(data []byte) error { + var spec string + if err := json.Unmarshal(data, &spec); err == nil { + *o = GoType{Spec: spec} + return nil + } + type alias GoType + var a alias + if err := json.Unmarshal(data, &a); err != nil { + return err + } + *o = GoType(a) + return nil +} + +func (o *GoType) UnmarshalYAML(unmarshal func(interface{}) error) error { + var spec string + if err := unmarshal(&spec); err == nil { + *o = GoType{Spec: spec} + return nil + } + type alias GoType + var a alias + if err := unmarshal(&a); err != nil { + return err + } + *o = GoType(a) + return nil +} + +var validIdentifier = regexp.MustCompile(`^[a-zA-Z0-9_]+$`) +var versionNumber = regexp.MustCompile(`^v[0-9]+$`) +var invalidIdentifier = regexp.MustCompile(`[^a-zA-Z0-9_]`) + +func generatePackageID(importPath string) (string, bool) { + parts := strings.Split(importPath, "/") + name := parts[len(parts)-1] + // If the last part of the import path is a valid identifier, assume that's the package name + if versionNumber.MatchString(name) && len(parts) >= 2 { + name = parts[len(parts)-2] + return invalidIdentifier.ReplaceAllString(strings.ToLower(name), "_"), true + } + if validIdentifier.MatchString(name) { + return name, false + } + return invalidIdentifier.ReplaceAllString(strings.ToLower(name), "_"), true +} + +// validate GoType +func (gt GoType) Parse() (*ParsedGoType, error) { + var o ParsedGoType + + if gt.Spec == "" { + // TODO: Validation + if gt.Path == "" && gt.Package != "" { + return nil, fmt.Errorf("Package override `go_type`: package name requires an import path") + } + var pkg string + var pkgNeedsAlias bool + + if gt.Package == "" && gt.Path != "" { + pkg, pkgNeedsAlias = generatePackageID(gt.Path) + if pkgNeedsAlias { + o.Package = pkg + } + } else { + pkg = gt.Package + o.Package = gt.Package + } + + o.ImportPath = gt.Path + o.TypeName = gt.Name + o.BasicType = gt.Path == "" && gt.Package == "" + if pkg != "" { + o.TypeName = pkg + "." + o.TypeName + } + if gt.Pointer { + o.TypeName = "*" + o.TypeName + } + return &o, nil + } + + input := gt.Spec + lastDot := strings.LastIndex(input, ".") + lastSlash := strings.LastIndex(input, "/") + typename := input + if lastDot == -1 && lastSlash == -1 { + // if the type name has no slash and no dot, validate that the type is a basic Go type + var found bool + for _, typ := range types.Typ { + info := typ.Info() + if info == 0 { + continue + } + if info&types.IsUntyped != 0 { + continue + } + if typename == typ.Name() { + found = true + } + } + if !found { + return nil, fmt.Errorf("Package override `go_type` specifier %q is not a Go basic type e.g. 'string'", input) + } + o.BasicType = true + } else { + // assume the type lives in a Go package + if lastDot == -1 { + return nil, fmt.Errorf("Package override `go_type` specifier %q is not the proper format, expected 'package.type', e.g. 'github.com/segmentio/ksuid.KSUID'", input) + } + if lastSlash == -1 { + return nil, fmt.Errorf("Package override `go_type` specifier %q is not the proper format, expected 'package.type', e.g. 'github.com/segmentio/ksuid.KSUID'", input) + } + typename = input[lastSlash+1:] + if strings.HasPrefix(typename, "go-") { + // a package name beginning with "go-" will give syntax errors in + // generated code. We should do the right thing and get the actual + // import name, but in lieu of that, stripping the leading "go-" may get + // us what we want. + typename = typename[len("go-"):] + } + if strings.HasSuffix(typename, "-go") { + typename = typename[:len(typename)-len("-go")] + } + o.ImportPath = input[:lastDot] + } + o.TypeName = typename + isPointer := input[0] == '*' + if isPointer { + o.ImportPath = o.ImportPath[1:] + o.TypeName = "*" + o.TypeName + } + return &o, nil +} diff --git a/internal/endtoend/testdata/overrides_go_types/mysql/go/db.go b/internal/endtoend/testdata/overrides_go_types/mysql/go/db.go new file mode 100644 index 0000000000..6f21c94ed1 --- /dev/null +++ b/internal/endtoend/testdata/overrides_go_types/mysql/go/db.go @@ -0,0 +1,29 @@ +// Code generated by sqlc. DO NOT EDIT. + +package override + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/overrides_go_types/mysql/go/models.go b/internal/endtoend/testdata/overrides_go_types/mysql/go/models.go new file mode 100644 index 0000000000..9b987955d3 --- /dev/null +++ b/internal/endtoend/testdata/overrides_go_types/mysql/go/models.go @@ -0,0 +1,13 @@ +// Code generated by sqlc. DO NOT EDIT. + +package override + +import ( + "github.com/kyleconroy/sqlc-testdata/pkg" +) + +type Foo struct { + Other string + Total int64 + Retyped pkg.CustomType +} diff --git a/internal/endtoend/testdata/overrides_go_types/mysql/query.sql b/internal/endtoend/testdata/overrides_go_types/mysql/query.sql new file mode 100644 index 0000000000..e0ac49d1ec --- /dev/null +++ b/internal/endtoend/testdata/overrides_go_types/mysql/query.sql @@ -0,0 +1 @@ +SELECT 1; diff --git a/internal/endtoend/testdata/overrides_go_types/mysql/schema.sql b/internal/endtoend/testdata/overrides_go_types/mysql/schema.sql new file mode 100644 index 0000000000..c0c5fc47dc --- /dev/null +++ b/internal/endtoend/testdata/overrides_go_types/mysql/schema.sql @@ -0,0 +1,5 @@ +CREATE TABLE foo ( + other text NOT NULL, + total bigint NOT NULL, + retyped text NOT NULL +); diff --git a/internal/endtoend/testdata/overrides_go_types/mysql/sqlc.json b/internal/endtoend/testdata/overrides_go_types/mysql/sqlc.json new file mode 100644 index 0000000000..23cd7caffd --- /dev/null +++ b/internal/endtoend/testdata/overrides_go_types/mysql/sqlc.json @@ -0,0 +1,18 @@ +{ + "version": "1", + "packages": [ + { + "path": "go", + "name": "override", + "engine": "mysql:beta", + "schema": "schema.sql", + "queries": "query.sql", + "overrides": [ + { + "go_type": "github.com/kyleconroy/sqlc-testdata/pkg.CustomType", + "column": "foo.retyped" + } + ] + } + ] +} diff --git a/internal/endtoend/testdata/overrides_go_types/postgresql/go/db.go b/internal/endtoend/testdata/overrides_go_types/postgresql/go/db.go new file mode 100644 index 0000000000..6f21c94ed1 --- /dev/null +++ b/internal/endtoend/testdata/overrides_go_types/postgresql/go/db.go @@ -0,0 +1,29 @@ +// Code generated by sqlc. DO NOT EDIT. + +package override + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/overrides_go_types/postgresql/go/models.go b/internal/endtoend/testdata/overrides_go_types/postgresql/go/models.go new file mode 100644 index 0000000000..6a07b893c1 --- /dev/null +++ b/internal/endtoend/testdata/overrides_go_types/postgresql/go/models.go @@ -0,0 +1,20 @@ +// Code generated by sqlc. DO NOT EDIT. + +package override + +import ( + orm "database/sql" + "github.com/gofrs/uuid" + fuid "github.com/gofrs/uuid" + null "github.com/volatiletech/null/v8" + null_v4 "gopkg.in/guregu/null.v4" +) + +type Foo struct { + ID uuid.UUID + OtherID fuid.UUID + Age orm.NullInt32 + Balance null.Float32 + Bio null_v4.String + About *string +} diff --git a/internal/endtoend/testdata/overrides_go_types/postgresql/go/query.sql.go b/internal/endtoend/testdata/overrides_go_types/postgresql/go/query.sql.go new file mode 100644 index 0000000000..75e73b73d0 --- /dev/null +++ b/internal/endtoend/testdata/overrides_go_types/postgresql/go/query.sql.go @@ -0,0 +1,44 @@ +// Code generated by sqlc. DO NOT EDIT. +// source: query.sql + +package override + +import ( + "context" + + "github.com/gofrs/uuid" +) + +const loadFoo = `-- name: LoadFoo :many +SELECT id, other_id, age, balance, bio, about FROM foo WHERE id = $1 +` + +func (q *Queries) LoadFoo(ctx context.Context, id uuid.UUID) ([]Foo, error) { + rows, err := q.db.QueryContext(ctx, loadFoo, id) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Foo + for rows.Next() { + var i Foo + if err := rows.Scan( + &i.ID, + &i.OtherID, + &i.Age, + &i.Balance, + &i.Bio, + &i.About, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/overrides_go_types/postgresql/query.sql b/internal/endtoend/testdata/overrides_go_types/postgresql/query.sql new file mode 100644 index 0000000000..192dc6cb94 --- /dev/null +++ b/internal/endtoend/testdata/overrides_go_types/postgresql/query.sql @@ -0,0 +1,2 @@ +-- name: LoadFoo :many +SELECT * FROM foo WHERE id = $1; diff --git a/internal/endtoend/testdata/overrides_go_types/postgresql/schema.sql b/internal/endtoend/testdata/overrides_go_types/postgresql/schema.sql new file mode 100644 index 0000000000..4e0d1f5af7 --- /dev/null +++ b/internal/endtoend/testdata/overrides_go_types/postgresql/schema.sql @@ -0,0 +1,8 @@ +CREATE TABLE foo ( + id uuid NOT NULL, + other_id uuid NOT NULL, + age integer, + balance double, + bio text, + about text +); diff --git a/internal/endtoend/testdata/overrides_go_types/postgresql/sqlc.json b/internal/endtoend/testdata/overrides_go_types/postgresql/sqlc.json new file mode 100644 index 0000000000..8d3b0c2223 --- /dev/null +++ b/internal/endtoend/testdata/overrides_go_types/postgresql/sqlc.json @@ -0,0 +1,62 @@ +{ + "version": "1", + "packages": [ + { + "path": "go", + "name": "override", + "engine": "postgresql", + "schema": "schema.sql", + "queries": "query.sql", + "overrides": [ + { + "column": "foo.id", + "go_type": { + "import": "github.com/gofrs/uuid", + "type": "UUID" + }, + }, + { + "column": "foo.other_id", + "go_type": { + "import": "github.com/gofrs/uuid", + "package": "fuid", + "type": "UUID" + }, + }, + { + "column": "foo.age", + "nullable": true, + "go_type": { + "import": "database/sql", + "package": "orm", + "type": "NullInt32" + }, + }, + { + "column": "foo.balance", + "nullable": true, + "go_type": { + "import": "github.com/volatiletech/null/v8", + "type": "Float32" + }, + }, + { + "column": "foo.bio", + "nullable": true, + "go_type": { + "import": "gopkg.in/guregu/null.v4", + "type": "String" + }, + }, + { + "column": "foo.about", + "nullable": true, + "go_type": { + "type": "string", + "pointer": true + } + } + ] + } + ] +}