From 4462dae59f2ce56e8419c9a3d9d9a2490e4a6e53 Mon Sep 17 00:00:00 2001 From: Kyle Conroy Date: Sat, 14 Nov 2020 11:08:40 -0800 Subject: [PATCH 1/5] Refactor go_type parsing into function --- internal/config/config.go | 56 ++++-------------------------- internal/config/go_type.go | 70 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 77 insertions(+), 49 deletions(-) create mode 100644 internal/config/go_type.go diff --git a/internal/config/config.go b/internal/config/config.go index d2e7c2f987..3a77ef8872 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -5,7 +5,6 @@ import ( "encoding/json" "errors" "fmt" - "go/types" "io" "os" "strings" @@ -199,56 +198,15 @@ func (o *Override) Parse() error { } // validate GoType - lastDot := strings.LastIndex(o.GoType, ".") - lastSlash := strings.LastIndex(o.GoType, "/") - typename := o.GoType - if lastDot == -1 && lastSlash == -1 { - // if the type name has no slash and no dot, validate that the type is a basic Go type - var found bool - for _, typ := range types.Typ { - info := typ.Info() - if info == 0 { - continue - } - if info&types.IsUntyped != 0 { - continue - } - if typename == typ.Name() { - found = true - } - } - if !found { - return fmt.Errorf("Package override `go_type` specifier %q is not a Go basic type e.g. 'string'", o.GoType) - } - o.GoBasicType = true - } else { - // assume the type lives in a Go package - if lastDot == -1 { - return fmt.Errorf("Package override `go_type` specifier %q is not the proper format, expected 'package.type', e.g. 'github.com/segmentio/ksuid.KSUID'", o.GoType) - } - if lastSlash == -1 { - return fmt.Errorf("Package override `go_type` specifier %q is not the proper format, expected 'package.type', e.g. 'github.com/segmentio/ksuid.KSUID'", o.GoType) - } - typename = o.GoType[lastSlash+1:] - if strings.HasPrefix(typename, "go-") { - // a package name beginning with "go-" will give syntax errors in - // generated code. We should do the right thing and get the actual - // import name, but in lieu of that, stripping the leading "go-" may get - // us what we want. - typename = typename[len("go-"):] - } - if strings.HasSuffix(typename, "-go") { - typename = typename[:len(typename)-len("-go")] - } - o.GoPackage = o.GoType[:lastDot] - } - o.GoTypeName = typename - isPointer := o.GoType[0] == '*' - if isPointer { - o.GoPackage = o.GoPackage[1:] - o.GoTypeName = "*" + o.GoTypeName + goType, err := ParseGoType(o.GoType) + if err != nil { + return err } + o.GoBasicType = goType.BuiltIn + o.GoPackage = goType.Path + o.GoTypeName = goType.Name + return nil } diff --git a/internal/config/go_type.go b/internal/config/go_type.go new file mode 100644 index 0000000000..75f3b29fed --- /dev/null +++ b/internal/config/go_type.go @@ -0,0 +1,70 @@ +package config + +import ( + "fmt" + "go/types" + "strings" +) + +type GoType struct { + Path string + Package string + Name string + Pointer bool + BuiltIn bool +} + +func ParseGoType(input string) (*GoType, error) { + // validate GoType + lastDot := strings.LastIndex(input, ".") + lastSlash := strings.LastIndex(input, "/") + typename := input + var o GoType + if lastDot == -1 && lastSlash == -1 { + // if the type name has no slash and no dot, validate that the type is a basic Go type + var found bool + for _, typ := range types.Typ { + info := typ.Info() + if info == 0 { + continue + } + if info&types.IsUntyped != 0 { + continue + } + if typename == typ.Name() { + found = true + } + } + if !found { + return nil, fmt.Errorf("Package override `go_type` specifier %q is not a Go basic type e.g. 'string'", input) + } + o.BuiltIn = true + } else { + // assume the type lives in a Go package + if lastDot == -1 { + return nil, fmt.Errorf("Package override `go_type` specifier %q is not the proper format, expected 'package.type', e.g. 'github.com/segmentio/ksuid.KSUID'", input) + } + if lastSlash == -1 { + return nil, fmt.Errorf("Package override `go_type` specifier %q is not the proper format, expected 'package.type', e.g. 'github.com/segmentio/ksuid.KSUID'", input) + } + typename = input[lastSlash+1:] + if strings.HasPrefix(typename, "go-") { + // a package name beginning with "go-" will give syntax errors in + // generated code. We should do the right thing and get the actual + // import name, but in lieu of that, stripping the leading "go-" may get + // us what we want. + typename = typename[len("go-"):] + } + if strings.HasSuffix(typename, "-go") { + typename = typename[:len(typename)-len("-go")] + } + o.Path = input[:lastDot] + } + o.Name = typename + isPointer := input[0] == '*' + if isPointer { + o.Path = o.Path[1:] + o.Name = "*" + o.Name + } + return &o, nil +} From f7cf4623176649c70b1ee77891a9a2d5265bd7d5 Mon Sep 17 00:00:00 2001 From: Kyle Conroy Date: Sat, 14 Nov 2020 12:21:45 -0800 Subject: [PATCH 2/5] Add tests for two types of overrides --- internal/codegen/golang/imports.go | 6 +- internal/config/config.go | 23 ++--- internal/config/config_test.go | 18 ++-- internal/config/go_type.go | 86 ++++++++++++++++--- .../overrides_go_types/mysql/go/db.go | 29 +++++++ .../overrides_go_types/mysql/go/models.go | 13 +++ .../overrides_go_types/mysql/query.sql | 1 + .../overrides_go_types/mysql/schema.sql | 5 ++ .../overrides_go_types/mysql/sqlc.json | 18 ++++ .../overrides_go_types/postgresql/go/db.go | 29 +++++++ .../postgresql/go/models.go | 12 +++ .../postgresql/go/query.sql.go | 37 ++++++++ .../overrides_go_types/postgresql/query.sql | 2 + .../overrides_go_types/postgresql/schema.sql | 4 + .../overrides_go_types/postgresql/sqlc.json | 30 +++++++ 15 files changed, 278 insertions(+), 35 deletions(-) create mode 100644 internal/endtoend/testdata/overrides_go_types/mysql/go/db.go create mode 100644 internal/endtoend/testdata/overrides_go_types/mysql/go/models.go create mode 100644 internal/endtoend/testdata/overrides_go_types/mysql/query.sql create mode 100644 internal/endtoend/testdata/overrides_go_types/mysql/schema.sql create mode 100644 internal/endtoend/testdata/overrides_go_types/mysql/sqlc.json create mode 100644 internal/endtoend/testdata/overrides_go_types/postgresql/go/db.go create mode 100644 internal/endtoend/testdata/overrides_go_types/postgresql/go/models.go create mode 100644 internal/endtoend/testdata/overrides_go_types/postgresql/go/query.sql.go create mode 100644 internal/endtoend/testdata/overrides_go_types/postgresql/query.sql create mode 100644 internal/endtoend/testdata/overrides_go_types/postgresql/schema.sql create mode 100644 internal/endtoend/testdata/overrides_go_types/postgresql/sqlc.json diff --git a/internal/codegen/golang/imports.go b/internal/codegen/golang/imports.go index 188dc668f5..7db77f922a 100644 --- a/internal/codegen/golang/imports.go +++ b/internal/codegen/golang/imports.go @@ -138,7 +138,7 @@ func (i *importer) interfaceImports() fileImports { if o.GoBasicType { continue } - overrideTypes[o.GoTypeName] = o.GoPackage + overrideTypes[o.GoTypeName] = o.GoImportPath } _, overrideNullTime := overrideTypes["pq.NullTime"] @@ -200,7 +200,7 @@ func (i *importer) modelImports() fileImports { if o.GoBasicType { continue } - overrideTypes[o.GoTypeName] = o.GoPackage + overrideTypes[o.GoTypeName] = o.GoImportPath } _, overrideNullTime := overrideTypes["pq.NullTime"] @@ -333,7 +333,7 @@ func (i *importer) queryImports(filename string) fileImports { if o.GoBasicType { continue } - overrideTypes[o.GoTypeName] = o.GoPackage + overrideTypes[o.GoTypeName] = o.GoImportPath } if sliceScan() { diff --git a/internal/config/config.go b/internal/config/config.go index 3a77ef8872..f3cef93c33 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -130,7 +130,7 @@ type SQLKotlin struct { type Override struct { // name of the golang type to use, e.g. `github.com/segmentio/ksuid.KSUID` - GoType string `json:"go_type" yaml:"go_type"` + GoType GoType `json:"go_type" yaml:"go_type"` // fully qualified name of the Go type, e.g. `github.com/segmentio/ksuid.KSUID` DBType string `json:"db_type" yaml:"db_type"` @@ -147,11 +147,12 @@ type Override struct { // fully qualified name of the column, e.g. `accounts.id` Column string `json:"column" yaml:"column"` - ColumnName string - Table core.FQN - GoTypeName string - GoPackage string - GoBasicType bool + ColumnName string + Table core.FQN + GoImportPath string + GoPackage string + GoTypeName string + GoBasicType bool } func (o *Override) Parse() error { @@ -198,14 +199,14 @@ func (o *Override) Parse() error { } // validate GoType - goType, err := ParseGoType(o.GoType) + parsed, err := o.GoType.Parse() if err != nil { return err } - - o.GoBasicType = goType.BuiltIn - o.GoPackage = goType.Path - o.GoTypeName = goType.Name + o.GoImportPath = parsed.ImportPath + o.GoPackage = parsed.Package + o.GoTypeName = parsed.TypeName + o.GoBasicType = parsed.BasicType return nil } diff --git a/internal/config/config_test.go b/internal/config/config_test.go index ef12a178b4..3249d93398 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -74,7 +74,7 @@ func TestTypeOverrides(t *testing.T) { { Override{ DBType: "uuid", - GoType: "github.com/segmentio/ksuid.KSUID", + GoType: GoType{Spec: "github.com/segmentio/ksuid.KSUID"}, }, "github.com/segmentio/ksuid", "ksuid.KSUID", @@ -94,7 +94,7 @@ func TestTypeOverrides(t *testing.T) { { Override{ DBType: "citext", - GoType: "string", + GoType: GoType{Spec: "string"}, }, "", "string", @@ -102,16 +102,16 @@ func TestTypeOverrides(t *testing.T) { }, } { tt := test - t.Run(tt.override.GoType, func(t *testing.T) { + t.Run(tt.override.GoType.Spec, func(t *testing.T) { if err := tt.override.Parse(); err != nil { t.Fatalf("override parsing failed; %s", err) } + if diff := cmp.Diff(tt.pkg, tt.override.GoImportPath); diff != "" { + t.Errorf("package mismatch;\n%s", diff) + } if diff := cmp.Diff(tt.typeName, tt.override.GoTypeName); diff != "" { t.Errorf("type name mismatch;\n%s", diff) } - if diff := cmp.Diff(tt.pkg, tt.override.GoPackage); diff != "" { - t.Errorf("package mismatch;\n%s", diff) - } if diff := cmp.Diff(tt.basic, tt.override.GoBasicType); diff != "" { t.Errorf("basic mismatch;\n%s", diff) } @@ -124,20 +124,20 @@ func TestTypeOverrides(t *testing.T) { { Override{ DBType: "uuid", - GoType: "Pointer", + GoType: GoType{Spec: "Pointer"}, }, "Package override `go_type` specifier \"Pointer\" is not a Go basic type e.g. 'string'", }, { Override{ DBType: "uuid", - GoType: "untyped rune", + GoType: GoType{Spec: "untyped rune"}, }, "Package override `go_type` specifier \"untyped rune\" is not a Go basic type e.g. 'string'", }, } { tt := test - t.Run(tt.override.GoType, func(t *testing.T) { + t.Run(tt.override.GoType.Spec, func(t *testing.T) { err := tt.override.Parse() if err == nil { t.Fatalf("expected pars to fail; got nil") diff --git a/internal/config/go_type.go b/internal/config/go_type.go index 75f3b29fed..37d955d00e 100644 --- a/internal/config/go_type.go +++ b/internal/config/go_type.go @@ -1,25 +1,87 @@ package config import ( + "encoding/json" "fmt" "go/types" "strings" ) type GoType struct { - Path string - Package string - Name string - Pointer bool + Path string `json:"import" yaml:"import"` + Package string `json:"package" yaml:"package"` + Name string `json:"type" yaml:"type"` + Pointer bool `json:"pointer" yaml:"pointer"` + Spec string BuiltIn bool } -func ParseGoType(input string) (*GoType, error) { - // validate GoType +type ParsedGoType struct { + ImportPath string + Package string + TypeName string + BasicType bool +} + +func (o *GoType) UnmarshalJSON(data []byte) error { + var spec string + if err := json.Unmarshal(data, &spec); err == nil { + *o = GoType{Spec: spec} + return nil + } + type alias GoType + var a alias + if err := json.Unmarshal(data, &a); err != nil { + return err + } + *o = GoType(a) + return nil +} + +func (o *GoType) UnmarshalYAML(unmarshal func(interface{}) error) error { + var spec string + if err := unmarshal(&spec); err == nil { + *o = GoType{Spec: spec} + return nil + } + type alias GoType + var a alias + if err := unmarshal(&a); err != nil { + return err + } + *o = GoType(a) + return nil +} + +// validate GoType +func (gt GoType) Parse() (*ParsedGoType, error) { + var o ParsedGoType + + if gt.Spec == "" { + // TODO: Validation + if gt.Path != "" && gt.Package == "" { + return nil, fmt.Errorf("Package override `go_type`: package name required when using an import path") + } + if gt.Path == "" && gt.Package != "" { + return nil, fmt.Errorf("Package override `go_type`: package name requires an import path") + } + o.ImportPath = gt.Path + o.Package = gt.Package + o.TypeName = gt.Name + o.BasicType = gt.Path == "" && gt.Package == "" + if gt.Package != "" { + o.TypeName = gt.Package + "." + o.TypeName + } + if gt.Pointer { + o.TypeName = "*" + o.TypeName + } + return &o, nil + } + + input := gt.Spec lastDot := strings.LastIndex(input, ".") lastSlash := strings.LastIndex(input, "/") typename := input - var o GoType if lastDot == -1 && lastSlash == -1 { // if the type name has no slash and no dot, validate that the type is a basic Go type var found bool @@ -38,7 +100,7 @@ func ParseGoType(input string) (*GoType, error) { if !found { return nil, fmt.Errorf("Package override `go_type` specifier %q is not a Go basic type e.g. 'string'", input) } - o.BuiltIn = true + o.BasicType = true } else { // assume the type lives in a Go package if lastDot == -1 { @@ -58,13 +120,13 @@ func ParseGoType(input string) (*GoType, error) { if strings.HasSuffix(typename, "-go") { typename = typename[:len(typename)-len("-go")] } - o.Path = input[:lastDot] + o.ImportPath = input[:lastDot] } - o.Name = typename + o.TypeName = typename isPointer := input[0] == '*' if isPointer { - o.Path = o.Path[1:] - o.Name = "*" + o.Name + o.ImportPath = o.ImportPath[1:] + o.TypeName = "*" + o.TypeName } return &o, nil } diff --git a/internal/endtoend/testdata/overrides_go_types/mysql/go/db.go b/internal/endtoend/testdata/overrides_go_types/mysql/go/db.go new file mode 100644 index 0000000000..6f21c94ed1 --- /dev/null +++ b/internal/endtoend/testdata/overrides_go_types/mysql/go/db.go @@ -0,0 +1,29 @@ +// Code generated by sqlc. DO NOT EDIT. + +package override + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/overrides_go_types/mysql/go/models.go b/internal/endtoend/testdata/overrides_go_types/mysql/go/models.go new file mode 100644 index 0000000000..9b987955d3 --- /dev/null +++ b/internal/endtoend/testdata/overrides_go_types/mysql/go/models.go @@ -0,0 +1,13 @@ +// Code generated by sqlc. DO NOT EDIT. + +package override + +import ( + "github.com/kyleconroy/sqlc-testdata/pkg" +) + +type Foo struct { + Other string + Total int64 + Retyped pkg.CustomType +} diff --git a/internal/endtoend/testdata/overrides_go_types/mysql/query.sql b/internal/endtoend/testdata/overrides_go_types/mysql/query.sql new file mode 100644 index 0000000000..e0ac49d1ec --- /dev/null +++ b/internal/endtoend/testdata/overrides_go_types/mysql/query.sql @@ -0,0 +1 @@ +SELECT 1; diff --git a/internal/endtoend/testdata/overrides_go_types/mysql/schema.sql b/internal/endtoend/testdata/overrides_go_types/mysql/schema.sql new file mode 100644 index 0000000000..c0c5fc47dc --- /dev/null +++ b/internal/endtoend/testdata/overrides_go_types/mysql/schema.sql @@ -0,0 +1,5 @@ +CREATE TABLE foo ( + other text NOT NULL, + total bigint NOT NULL, + retyped text NOT NULL +); diff --git a/internal/endtoend/testdata/overrides_go_types/mysql/sqlc.json b/internal/endtoend/testdata/overrides_go_types/mysql/sqlc.json new file mode 100644 index 0000000000..23cd7caffd --- /dev/null +++ b/internal/endtoend/testdata/overrides_go_types/mysql/sqlc.json @@ -0,0 +1,18 @@ +{ + "version": "1", + "packages": [ + { + "path": "go", + "name": "override", + "engine": "mysql:beta", + "schema": "schema.sql", + "queries": "query.sql", + "overrides": [ + { + "go_type": "github.com/kyleconroy/sqlc-testdata/pkg.CustomType", + "column": "foo.retyped" + } + ] + } + ] +} diff --git a/internal/endtoend/testdata/overrides_go_types/postgresql/go/db.go b/internal/endtoend/testdata/overrides_go_types/postgresql/go/db.go new file mode 100644 index 0000000000..6f21c94ed1 --- /dev/null +++ b/internal/endtoend/testdata/overrides_go_types/postgresql/go/db.go @@ -0,0 +1,29 @@ +// Code generated by sqlc. DO NOT EDIT. + +package override + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/overrides_go_types/postgresql/go/models.go b/internal/endtoend/testdata/overrides_go_types/postgresql/go/models.go new file mode 100644 index 0000000000..f9a66d547c --- /dev/null +++ b/internal/endtoend/testdata/overrides_go_types/postgresql/go/models.go @@ -0,0 +1,12 @@ +// Code generated by sqlc. DO NOT EDIT. + +package override + +import ( + "github.com/gofrs/uuid" +) + +type Foo struct { + ID uuid.UUID + About *string +} diff --git a/internal/endtoend/testdata/overrides_go_types/postgresql/go/query.sql.go b/internal/endtoend/testdata/overrides_go_types/postgresql/go/query.sql.go new file mode 100644 index 0000000000..7ce6dcf3d9 --- /dev/null +++ b/internal/endtoend/testdata/overrides_go_types/postgresql/go/query.sql.go @@ -0,0 +1,37 @@ +// Code generated by sqlc. DO NOT EDIT. +// source: query.sql + +package override + +import ( + "context" + + "github.com/gofrs/uuid" +) + +const loadFoo = `-- name: LoadFoo :many +SELECT id, about FROM foo WHERE id = $1 +` + +func (q *Queries) LoadFoo(ctx context.Context, id uuid.UUID) ([]Foo, error) { + rows, err := q.db.QueryContext(ctx, loadFoo, id) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Foo + for rows.Next() { + var i Foo + if err := rows.Scan(&i.ID, &i.About); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/overrides_go_types/postgresql/query.sql b/internal/endtoend/testdata/overrides_go_types/postgresql/query.sql new file mode 100644 index 0000000000..192dc6cb94 --- /dev/null +++ b/internal/endtoend/testdata/overrides_go_types/postgresql/query.sql @@ -0,0 +1,2 @@ +-- name: LoadFoo :many +SELECT * FROM foo WHERE id = $1; diff --git a/internal/endtoend/testdata/overrides_go_types/postgresql/schema.sql b/internal/endtoend/testdata/overrides_go_types/postgresql/schema.sql new file mode 100644 index 0000000000..c319d9a5b8 --- /dev/null +++ b/internal/endtoend/testdata/overrides_go_types/postgresql/schema.sql @@ -0,0 +1,4 @@ +CREATE TABLE foo ( + id uuid NOT NULL, + about text +); diff --git a/internal/endtoend/testdata/overrides_go_types/postgresql/sqlc.json b/internal/endtoend/testdata/overrides_go_types/postgresql/sqlc.json new file mode 100644 index 0000000000..ca991a7d54 --- /dev/null +++ b/internal/endtoend/testdata/overrides_go_types/postgresql/sqlc.json @@ -0,0 +1,30 @@ +{ + "version": "1", + "packages": [ + { + "path": "go", + "name": "override", + "engine": "postgresql", + "schema": "schema.sql", + "queries": "query.sql", + "overrides": [ + { + "db_type": "uuid", + "go_type": { + "import": "github.com/gofrs/uuid", + "package": "uuid", + "type": "UUID" + }, + }, + { + "nullable": true, + "db_type": "text", + "go_type": { + "type": "string", + "pointer": true + } + } + ] + } + ] +} From 14a7e0f828dab2acfc79ef23eb9d4e50ee075863 Mon Sep 17 00:00:00 2001 From: Kyle Conroy Date: Sat, 14 Nov 2020 18:37:49 -0800 Subject: [PATCH 3/5] Handle edge cases --- internal/codegen/golang/gen.go | 8 +- internal/codegen/golang/imports.go | 155 +++++++++++------- internal/config/go_type.go | 26 ++- .../postgresql/go/models.go | 19 ++- .../postgresql/go/query.sql.go | 46 +++++- .../overrides_go_types/postgresql/query.sql | 8 +- .../overrides_go_types/postgresql/schema.sql | 15 +- .../overrides_go_types/postgresql/sqlc.json | 44 ++++- 8 files changed, 237 insertions(+), 84 deletions(-) diff --git a/internal/codegen/golang/gen.go b/internal/codegen/golang/gen.go index 81ea486f5c..4aa38d8cb4 100644 --- a/internal/codegen/golang/gen.go +++ b/internal/codegen/golang/gen.go @@ -26,7 +26,7 @@ package {{.Package}} import ( {{range imports .SourceName}} - {{range .}}"{{.}}" + {{range .}}{{.}} {{end}} {{end}} ) @@ -137,7 +137,7 @@ package {{.Package}} import ( {{range imports .SourceName}} - {{range .}}"{{.}}" + {{range .}}{{.}} {{end}} {{end}} ) @@ -175,7 +175,7 @@ package {{.Package}} import ( {{range imports .SourceName}} - {{range .}}"{{.}}" + {{range .}}{{.}} {{end}} {{end}} ) @@ -226,7 +226,7 @@ package {{.Package}} import ( {{range imports .SourceName}} - {{range .}}"{{.}}" + {{range .}}{{.}} {{end}} {{end}} ) diff --git a/internal/codegen/golang/imports.go b/internal/codegen/golang/imports.go index 7db77f922a..63900ba763 100644 --- a/internal/codegen/golang/imports.go +++ b/internal/codegen/golang/imports.go @@ -1,6 +1,7 @@ package golang import ( + "fmt" "sort" "strings" @@ -9,35 +10,51 @@ import ( ) type fileImports struct { - Std []string - Dep []string + Std []ImportSpec + Dep []ImportSpec } -func mergeImports(imps ...fileImports) [][]string { +type ImportSpec struct { + ID string + Path string +} + +func (s ImportSpec) String() string { + if s.ID != "" { + return fmt.Sprintf("%s \"%s\"", s.ID, s.Path) + } else { + return fmt.Sprintf("\"%s\"", s.Path) + } +} + +func mergeImports(imps ...fileImports) [][]ImportSpec { if len(imps) == 1 { - return [][]string{imps[0].Std, imps[0].Dep} + return [][]ImportSpec{ + imps[0].Std, + imps[0].Dep, + } } - var stds, pkgs []string + var stds, pkgs []ImportSpec seenStd := map[string]struct{}{} seenPkg := map[string]struct{}{} for i := range imps { - for _, std := range imps[i].Std { - if _, ok := seenStd[std]; ok { + for _, spec := range imps[i].Std { + if _, ok := seenStd[spec.Path]; ok { continue } - stds = append(stds, std) - seenStd[std] = struct{}{} + stds = append(stds, spec) + seenStd[spec.Path] = struct{}{} } - for _, pkg := range imps[i].Dep { - if _, ok := seenPkg[pkg]; ok { + for _, spec := range imps[i].Dep { + if _, ok := seenPkg[spec.Path]; ok { continue } - pkgs = append(pkgs, pkg) - seenPkg[pkg] = struct{}{} + pkgs = append(pkgs, spec) + seenPkg[spec.Path] = struct{}{} } } - return [][]string{stds, pkgs} + return [][]ImportSpec{stds, pkgs} } type importer struct { @@ -70,7 +87,7 @@ func (i *importer) usesArrays() bool { return false } -func (i *importer) Imports(filename string) [][]string { +func (i *importer) Imports(filename string) [][]ImportSpec { switch filename { case "db.go": return mergeImports(i.dbImports()) @@ -84,9 +101,12 @@ func (i *importer) Imports(filename string) [][]string { } func (i *importer) dbImports() fileImports { - std := []string{"context", "database/sql"} + std := []ImportSpec{ + {Path: "context"}, + {Path: "database/sql"}, + } if i.Settings.Go.EmitPreparedQueries { - std = append(std, "fmt") + std = append(std, ImportSpec{Path: "fmt"}) } return fileImports{Std: std} } @@ -132,7 +152,7 @@ func (i *importer) interfaceImports() fileImports { std["net"] = struct{}{} } - pkg := make(map[string]struct{}) + pkg := make(map[ImportSpec]struct{}) overrideTypes := map[string]string{} for _, o := range i.Settings.Overrides { if o.GoBasicType { @@ -143,32 +163,37 @@ func (i *importer) interfaceImports() fileImports { _, overrideNullTime := overrideTypes["pq.NullTime"] if uses("pq.NullTime") && !overrideNullTime { - pkg["github.com/lib/pq"] = struct{}{} + pkg[ImportSpec{Path: "github.com/lib/pq"}] = struct{}{} } _, overrideUUID := overrideTypes["uuid.UUID"] if uses("uuid.UUID") && !overrideUUID { - pkg["github.com/google/uuid"] = struct{}{} + pkg[ImportSpec{Path: "github.com/google/uuid"}] = struct{}{} } // Custom imports - for goType, importPath := range overrideTypes { - if _, ok := std[importPath]; !ok && uses(goType) { - pkg[importPath] = struct{}{} + for _, o := range i.Settings.Overrides { + if o.GoBasicType { + continue + } + _, alreadyImported := std[o.GoImportPath] + hasPackageAlias := o.GoPackage != "" + if (!alreadyImported || hasPackageAlias) && uses(o.GoTypeName) { + pkg[ImportSpec{Path: o.GoImportPath, ID: o.GoPackage}] = struct{}{} } } - pkgs := make([]string, 0, len(pkg)) - for p, _ := range pkg { - pkgs = append(pkgs, p) + pkgs := make([]ImportSpec, 0, len(pkg)) + for spec, _ := range pkg { + pkgs = append(pkgs, spec) } - stds := make([]string, 0, len(std)) - for s, _ := range std { - stds = append(stds, s) + stds := make([]ImportSpec, 0, len(std)) + for path, _ := range std { + stds = append(stds, ImportSpec{Path: path}) } - sort.Strings(stds) - sort.Strings(pkgs) + sort.Slice(stds, func(i, j int) bool { return stds[i].Path < stds[j].Path }) + sort.Slice(pkgs, func(i, j int) bool { return pkgs[i].Path < pkgs[j].Path }) return fileImports{stds, pkgs} } @@ -194,7 +219,7 @@ func (i *importer) modelImports() fileImports { } // Custom imports - pkg := make(map[string]struct{}) + pkg := make(map[ImportSpec]struct{}) overrideTypes := map[string]string{} for _, o := range i.Settings.Overrides { if o.GoBasicType { @@ -205,32 +230,37 @@ func (i *importer) modelImports() fileImports { _, overrideNullTime := overrideTypes["pq.NullTime"] if i.usesType("pq.NullTime") && !overrideNullTime { - pkg["github.com/lib/pq"] = struct{}{} + pkg[ImportSpec{Path: "github.com/lib/pq"}] = struct{}{} } _, overrideUUID := overrideTypes["uuid.UUID"] if i.usesType("uuid.UUID") && !overrideUUID { - pkg["github.com/google/uuid"] = struct{}{} + pkg[ImportSpec{Path: "github.com/google/uuid"}] = struct{}{} } - for goType, importPath := range overrideTypes { - if _, ok := std[importPath]; !ok && i.usesType(goType) { - pkg[importPath] = struct{}{} + for _, o := range i.Settings.Overrides { + if o.GoBasicType { + continue + } + _, alreadyImported := std[o.GoImportPath] + hasPackageAlias := o.GoPackage != "" + if (!alreadyImported || hasPackageAlias) && i.usesType(o.GoTypeName) { + pkg[ImportSpec{Path: o.GoImportPath, ID: o.GoPackage}] = struct{}{} } } - pkgs := make([]string, 0, len(pkg)) - for p, _ := range pkg { - pkgs = append(pkgs, p) + pkgs := make([]ImportSpec, 0, len(pkg)) + for spec, _ := range pkg { + pkgs = append(pkgs, spec) } - stds := make([]string, 0, len(std)) - for s, _ := range std { - stds = append(stds, s) + stds := make([]ImportSpec, 0, len(std)) + for path, _ := range std { + stds = append(stds, ImportSpec{Path: path}) } - sort.Strings(stds) - sort.Strings(pkgs) + sort.Slice(stds, func(i, j int) bool { return stds[i].Path < stds[j].Path }) + sort.Slice(pkgs, func(i, j int) bool { return pkgs[i].Path < pkgs[j].Path }) return fileImports{stds, pkgs} } @@ -327,7 +357,7 @@ func (i *importer) queryImports(filename string) fileImports { std["net"] = struct{}{} } - pkg := make(map[string]struct{}) + pkg := make(map[ImportSpec]struct{}) overrideTypes := map[string]string{} for _, o := range i.Settings.Overrides { if o.GoBasicType { @@ -337,35 +367,40 @@ func (i *importer) queryImports(filename string) fileImports { } if sliceScan() { - pkg["github.com/lib/pq"] = struct{}{} + pkg[ImportSpec{Path: "github.com/lib/pq"}] = struct{}{} } _, overrideNullTime := overrideTypes["pq.NullTime"] if uses("pq.NullTime") && !overrideNullTime { - pkg["github.com/lib/pq"] = struct{}{} + pkg[ImportSpec{Path: "github.com/lib/pq"}] = struct{}{} } _, overrideUUID := overrideTypes["uuid.UUID"] if uses("uuid.UUID") && !overrideUUID { - pkg["github.com/google/uuid"] = struct{}{} + pkg[ImportSpec{Path: "github.com/google/uuid"}] = struct{}{} } // Custom imports - for goType, importPath := range overrideTypes { - if _, ok := std[importPath]; !ok && uses(goType) { - pkg[importPath] = struct{}{} + for _, o := range i.Settings.Overrides { + if o.GoBasicType { + continue + } + _, alreadyImported := std[o.GoImportPath] + hasPackageAlias := o.GoPackage != "" + if (!alreadyImported || hasPackageAlias) && uses(o.GoTypeName) { + pkg[ImportSpec{Path: o.GoImportPath, ID: o.GoPackage}] = struct{}{} } } - pkgs := make([]string, 0, len(pkg)) - for p, _ := range pkg { - pkgs = append(pkgs, p) + pkgs := make([]ImportSpec, 0, len(pkg)) + for spec, _ := range pkg { + pkgs = append(pkgs, spec) } - stds := make([]string, 0, len(std)) - for s, _ := range std { - stds = append(stds, s) + stds := make([]ImportSpec, 0, len(std)) + for path, _ := range std { + stds = append(stds, ImportSpec{Path: path}) } - sort.Strings(stds) - sort.Strings(pkgs) + sort.Slice(stds, func(i, j int) bool { return stds[i].Path < stds[j].Path }) + sort.Slice(pkgs, func(i, j int) bool { return pkgs[i].Path < pkgs[j].Path }) return fileImports{stds, pkgs} } diff --git a/internal/config/go_type.go b/internal/config/go_type.go index 37d955d00e..b7ecb581ae 100644 --- a/internal/config/go_type.go +++ b/internal/config/go_type.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "go/types" + "regexp" "strings" ) @@ -53,18 +54,37 @@ func (o *GoType) UnmarshalYAML(unmarshal func(interface{}) error) error { return nil } +var validIdentifier = regexp.MustCompile(`^[a-zA-Z0-9_]+$`) +var versionNumber = regexp.MustCompile(`^v[0-9]+$`) +var invalidIdentifier = regexp.MustCompile(`[^a-zA-Z0-9_]`) + +func generatePackageID(importPath string) string { + parts := strings.Split(importPath, "/") + name := parts[len(parts)-1] + fmt.Println("parts", parts) + // If the last part of the import path is a valid identifier, assume that's the package name + if versionNumber.MatchString(name) && len(parts) >= 2 { + name = parts[len(parts)-2] + return invalidIdentifier.ReplaceAllString(strings.ToLower(name), "_") + } + if validIdentifier.MatchString(name) { + return "" + } + return invalidIdentifier.ReplaceAllString(strings.ToLower(name), "_") +} + // validate GoType func (gt GoType) Parse() (*ParsedGoType, error) { var o ParsedGoType if gt.Spec == "" { // TODO: Validation - if gt.Path != "" && gt.Package == "" { - return nil, fmt.Errorf("Package override `go_type`: package name required when using an import path") - } if gt.Path == "" && gt.Package != "" { return nil, fmt.Errorf("Package override `go_type`: package name requires an import path") } + if gt.Package == "" && gt.Path != "" { + gt.Package = generatePackageID(gt.Path) + } o.ImportPath = gt.Path o.Package = gt.Package o.TypeName = gt.Name diff --git a/internal/endtoend/testdata/overrides_go_types/postgresql/go/models.go b/internal/endtoend/testdata/overrides_go_types/postgresql/go/models.go index f9a66d547c..d5d96fb69c 100644 --- a/internal/endtoend/testdata/overrides_go_types/postgresql/go/models.go +++ b/internal/endtoend/testdata/overrides_go_types/postgresql/go/models.go @@ -3,10 +3,25 @@ package override import ( + "database/sql" + + orm "database/sql" "github.com/gofrs/uuid" + fuid "github.com/gofrs/uuid" + null "github.com/guregu/null/v4" + null_v4 "gopkg.in/guregu/null.v4" ) -type Foo struct { +type NewStyle struct { + ID UUID + OtherID fuid.UUID + Age orm.NullInt32 + Balance null.Float + Bio null_v4.String + About *string +} + +type OldStyle struct { ID uuid.UUID - About *string + About sql.NullString } diff --git a/internal/endtoend/testdata/overrides_go_types/postgresql/go/query.sql.go b/internal/endtoend/testdata/overrides_go_types/postgresql/go/query.sql.go index 7ce6dcf3d9..a9eba5243e 100644 --- a/internal/endtoend/testdata/overrides_go_types/postgresql/go/query.sql.go +++ b/internal/endtoend/testdata/overrides_go_types/postgresql/go/query.sql.go @@ -9,19 +9,53 @@ import ( "github.com/gofrs/uuid" ) -const loadFoo = `-- name: LoadFoo :many -SELECT id, about FROM foo WHERE id = $1 +const loadNewStyle = `-- name: LoadNewStyle :many +SELECT id, other_id, age, balance, bio, about FROM new_style WHERE id = $1 ` -func (q *Queries) LoadFoo(ctx context.Context, id uuid.UUID) ([]Foo, error) { - rows, err := q.db.QueryContext(ctx, loadFoo, id) +func (q *Queries) LoadNewStyle(ctx context.Context, id UUID) ([]NewStyle, error) { + rows, err := q.db.QueryContext(ctx, loadNewStyle, id) if err != nil { return nil, err } defer rows.Close() - var items []Foo + var items []NewStyle for rows.Next() { - var i Foo + var i NewStyle + if err := rows.Scan( + &i.ID, + &i.OtherID, + &i.Age, + &i.Balance, + &i.Bio, + &i.About, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const loadOldStyle = `-- name: LoadOldStyle :many +SELECT id, about FROM old_style WHERE id = $1 +` + +func (q *Queries) LoadOldStyle(ctx context.Context, id uuid.UUID) ([]OldStyle, error) { + rows, err := q.db.QueryContext(ctx, loadOldStyle, id) + if err != nil { + return nil, err + } + defer rows.Close() + var items []OldStyle + for rows.Next() { + var i OldStyle if err := rows.Scan(&i.ID, &i.About); err != nil { return nil, err } diff --git a/internal/endtoend/testdata/overrides_go_types/postgresql/query.sql b/internal/endtoend/testdata/overrides_go_types/postgresql/query.sql index 192dc6cb94..bd1ac2c272 100644 --- a/internal/endtoend/testdata/overrides_go_types/postgresql/query.sql +++ b/internal/endtoend/testdata/overrides_go_types/postgresql/query.sql @@ -1,2 +1,6 @@ --- name: LoadFoo :many -SELECT * FROM foo WHERE id = $1; +-- name: LoadNewStyle :many +SELECT * FROM new_style WHERE id = $1; + +-- name: LoadOldStyle :many +SELECT * FROM old_style WHERE id = $1; + diff --git a/internal/endtoend/testdata/overrides_go_types/postgresql/schema.sql b/internal/endtoend/testdata/overrides_go_types/postgresql/schema.sql index c319d9a5b8..9dbc60fb6b 100644 --- a/internal/endtoend/testdata/overrides_go_types/postgresql/schema.sql +++ b/internal/endtoend/testdata/overrides_go_types/postgresql/schema.sql @@ -1,4 +1,13 @@ -CREATE TABLE foo ( - id uuid NOT NULL, - about text +CREATE TABLE new_style ( + id uuid NOT NULL, + other_id uuid NOT NULL, + age integer, + balance double, + bio text, + about text +); + +CREATE TABLE old_style ( + id uuid NOT NULL, + about text ); diff --git a/internal/endtoend/testdata/overrides_go_types/postgresql/sqlc.json b/internal/endtoend/testdata/overrides_go_types/postgresql/sqlc.json index ca991a7d54..422c40fac0 100644 --- a/internal/endtoend/testdata/overrides_go_types/postgresql/sqlc.json +++ b/internal/endtoend/testdata/overrides_go_types/postgresql/sqlc.json @@ -9,21 +9,57 @@ "queries": "query.sql", "overrides": [ { - "db_type": "uuid", + "column": "new_style.id", "go_type": { "import": "github.com/gofrs/uuid", - "package": "uuid", "type": "UUID" }, }, { + "column": "new_style.other_id", + "go_type": { + "import": "github.com/gofrs/uuid", + "package": "fuid", + "type": "UUID" + }, + }, + { + "column": "new_style.age", + "nullable": true, + "go_type": { + "import": "database/sql", + "package": "orm", + "type": "NullInt32" + }, + }, + { + "column": "new_style.balance", + "nullable": true, + "go_type": { + "import": "github.com/guregu/null/v4", + "type": "Float" + }, + }, + { + "column": "new_style.bio", + "nullable": true, + "go_type": { + "import": "gopkg.in/guregu/null.v4", + "type": "String" + }, + }, + { + "column": "new_style.about", "nullable": true, - "db_type": "text", "go_type": { "type": "string", "pointer": true } - } + }, + { + "column": "old_style.id", + "go_type": "github.com/gofrs/uuid.UUID", + }, ] } ] From 99a2c7260f989b86dd93287a8f0b544d81a798bd Mon Sep 17 00:00:00 2001 From: Kyle Conroy Date: Sat, 14 Nov 2020 18:40:17 -0800 Subject: [PATCH 4/5] remove println stmt --- internal/config/go_type.go | 1 - 1 file changed, 1 deletion(-) diff --git a/internal/config/go_type.go b/internal/config/go_type.go index b7ecb581ae..19d4d577c2 100644 --- a/internal/config/go_type.go +++ b/internal/config/go_type.go @@ -61,7 +61,6 @@ var invalidIdentifier = regexp.MustCompile(`[^a-zA-Z0-9_]`) func generatePackageID(importPath string) string { parts := strings.Split(importPath, "/") name := parts[len(parts)-1] - fmt.Println("parts", parts) // If the last part of the import path is a valid identifier, assume that's the package name if versionNumber.MatchString(name) && len(parts) >= 2 { name = parts[len(parts)-2] From 1a5d404a2c2658fabbb45900c9db0cfe01cf6677 Mon Sep 17 00:00:00 2001 From: Kyle Conroy Date: Sat, 14 Nov 2020 20:00:31 -0800 Subject: [PATCH 5/5] Fix imports for new-style overrides --- internal/config/go_type.go | 25 ++++++++---- .../postgresql/go/models.go | 15 ++----- .../postgresql/go/query.sql.go | 39 +++---------------- .../overrides_go_types/postgresql/query.sql | 8 +--- .../overrides_go_types/postgresql/schema.sql | 7 +--- .../overrides_go_types/postgresql/sqlc.json | 22 +++++------ 6 files changed, 39 insertions(+), 77 deletions(-) diff --git a/internal/config/go_type.go b/internal/config/go_type.go index 19d4d577c2..8a2b3fa06a 100644 --- a/internal/config/go_type.go +++ b/internal/config/go_type.go @@ -58,18 +58,18 @@ var validIdentifier = regexp.MustCompile(`^[a-zA-Z0-9_]+$`) var versionNumber = regexp.MustCompile(`^v[0-9]+$`) var invalidIdentifier = regexp.MustCompile(`[^a-zA-Z0-9_]`) -func generatePackageID(importPath string) string { +func generatePackageID(importPath string) (string, bool) { parts := strings.Split(importPath, "/") name := parts[len(parts)-1] // If the last part of the import path is a valid identifier, assume that's the package name if versionNumber.MatchString(name) && len(parts) >= 2 { name = parts[len(parts)-2] - return invalidIdentifier.ReplaceAllString(strings.ToLower(name), "_") + return invalidIdentifier.ReplaceAllString(strings.ToLower(name), "_"), true } if validIdentifier.MatchString(name) { - return "" + return name, false } - return invalidIdentifier.ReplaceAllString(strings.ToLower(name), "_") + return invalidIdentifier.ReplaceAllString(strings.ToLower(name), "_"), true } // validate GoType @@ -81,15 +81,24 @@ func (gt GoType) Parse() (*ParsedGoType, error) { if gt.Path == "" && gt.Package != "" { return nil, fmt.Errorf("Package override `go_type`: package name requires an import path") } + var pkg string + var pkgNeedsAlias bool + if gt.Package == "" && gt.Path != "" { - gt.Package = generatePackageID(gt.Path) + pkg, pkgNeedsAlias = generatePackageID(gt.Path) + if pkgNeedsAlias { + o.Package = pkg + } + } else { + pkg = gt.Package + o.Package = gt.Package } + o.ImportPath = gt.Path - o.Package = gt.Package o.TypeName = gt.Name o.BasicType = gt.Path == "" && gt.Package == "" - if gt.Package != "" { - o.TypeName = gt.Package + "." + o.TypeName + if pkg != "" { + o.TypeName = pkg + "." + o.TypeName } if gt.Pointer { o.TypeName = "*" + o.TypeName diff --git a/internal/endtoend/testdata/overrides_go_types/postgresql/go/models.go b/internal/endtoend/testdata/overrides_go_types/postgresql/go/models.go index d5d96fb69c..6a07b893c1 100644 --- a/internal/endtoend/testdata/overrides_go_types/postgresql/go/models.go +++ b/internal/endtoend/testdata/overrides_go_types/postgresql/go/models.go @@ -3,25 +3,18 @@ package override import ( - "database/sql" - orm "database/sql" "github.com/gofrs/uuid" fuid "github.com/gofrs/uuid" - null "github.com/guregu/null/v4" + null "github.com/volatiletech/null/v8" null_v4 "gopkg.in/guregu/null.v4" ) -type NewStyle struct { - ID UUID +type Foo struct { + ID uuid.UUID OtherID fuid.UUID Age orm.NullInt32 - Balance null.Float + Balance null.Float32 Bio null_v4.String About *string } - -type OldStyle struct { - ID uuid.UUID - About sql.NullString -} diff --git a/internal/endtoend/testdata/overrides_go_types/postgresql/go/query.sql.go b/internal/endtoend/testdata/overrides_go_types/postgresql/go/query.sql.go index a9eba5243e..75e73b73d0 100644 --- a/internal/endtoend/testdata/overrides_go_types/postgresql/go/query.sql.go +++ b/internal/endtoend/testdata/overrides_go_types/postgresql/go/query.sql.go @@ -9,19 +9,19 @@ import ( "github.com/gofrs/uuid" ) -const loadNewStyle = `-- name: LoadNewStyle :many -SELECT id, other_id, age, balance, bio, about FROM new_style WHERE id = $1 +const loadFoo = `-- name: LoadFoo :many +SELECT id, other_id, age, balance, bio, about FROM foo WHERE id = $1 ` -func (q *Queries) LoadNewStyle(ctx context.Context, id UUID) ([]NewStyle, error) { - rows, err := q.db.QueryContext(ctx, loadNewStyle, id) +func (q *Queries) LoadFoo(ctx context.Context, id uuid.UUID) ([]Foo, error) { + rows, err := q.db.QueryContext(ctx, loadFoo, id) if err != nil { return nil, err } defer rows.Close() - var items []NewStyle + var items []Foo for rows.Next() { - var i NewStyle + var i Foo if err := rows.Scan( &i.ID, &i.OtherID, @@ -42,30 +42,3 @@ func (q *Queries) LoadNewStyle(ctx context.Context, id UUID) ([]NewStyle, error) } return items, nil } - -const loadOldStyle = `-- name: LoadOldStyle :many -SELECT id, about FROM old_style WHERE id = $1 -` - -func (q *Queries) LoadOldStyle(ctx context.Context, id uuid.UUID) ([]OldStyle, error) { - rows, err := q.db.QueryContext(ctx, loadOldStyle, id) - if err != nil { - return nil, err - } - defer rows.Close() - var items []OldStyle - for rows.Next() { - var i OldStyle - if err := rows.Scan(&i.ID, &i.About); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Close(); err != nil { - return nil, err - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} diff --git a/internal/endtoend/testdata/overrides_go_types/postgresql/query.sql b/internal/endtoend/testdata/overrides_go_types/postgresql/query.sql index bd1ac2c272..192dc6cb94 100644 --- a/internal/endtoend/testdata/overrides_go_types/postgresql/query.sql +++ b/internal/endtoend/testdata/overrides_go_types/postgresql/query.sql @@ -1,6 +1,2 @@ --- name: LoadNewStyle :many -SELECT * FROM new_style WHERE id = $1; - --- name: LoadOldStyle :many -SELECT * FROM old_style WHERE id = $1; - +-- name: LoadFoo :many +SELECT * FROM foo WHERE id = $1; diff --git a/internal/endtoend/testdata/overrides_go_types/postgresql/schema.sql b/internal/endtoend/testdata/overrides_go_types/postgresql/schema.sql index 9dbc60fb6b..4e0d1f5af7 100644 --- a/internal/endtoend/testdata/overrides_go_types/postgresql/schema.sql +++ b/internal/endtoend/testdata/overrides_go_types/postgresql/schema.sql @@ -1,4 +1,4 @@ -CREATE TABLE new_style ( +CREATE TABLE foo ( id uuid NOT NULL, other_id uuid NOT NULL, age integer, @@ -6,8 +6,3 @@ CREATE TABLE new_style ( bio text, about text ); - -CREATE TABLE old_style ( - id uuid NOT NULL, - about text -); diff --git a/internal/endtoend/testdata/overrides_go_types/postgresql/sqlc.json b/internal/endtoend/testdata/overrides_go_types/postgresql/sqlc.json index 422c40fac0..8d3b0c2223 100644 --- a/internal/endtoend/testdata/overrides_go_types/postgresql/sqlc.json +++ b/internal/endtoend/testdata/overrides_go_types/postgresql/sqlc.json @@ -9,14 +9,14 @@ "queries": "query.sql", "overrides": [ { - "column": "new_style.id", + "column": "foo.id", "go_type": { "import": "github.com/gofrs/uuid", "type": "UUID" }, }, { - "column": "new_style.other_id", + "column": "foo.other_id", "go_type": { "import": "github.com/gofrs/uuid", "package": "fuid", @@ -24,7 +24,7 @@ }, }, { - "column": "new_style.age", + "column": "foo.age", "nullable": true, "go_type": { "import": "database/sql", @@ -33,15 +33,15 @@ }, }, { - "column": "new_style.balance", + "column": "foo.balance", "nullable": true, "go_type": { - "import": "github.com/guregu/null/v4", - "type": "Float" + "import": "github.com/volatiletech/null/v8", + "type": "Float32" }, }, { - "column": "new_style.bio", + "column": "foo.bio", "nullable": true, "go_type": { "import": "gopkg.in/guregu/null.v4", @@ -49,17 +49,13 @@ }, }, { - "column": "new_style.about", + "column": "foo.about", "nullable": true, "go_type": { "type": "string", "pointer": true } - }, - { - "column": "old_style.id", - "go_type": "github.com/gofrs/uuid.UUID", - }, + } ] } ]