diff --git a/internal/config/config.go b/internal/config/config.go index 6e5c4db444..370be3639b 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -114,7 +114,8 @@ type SQLGen struct { type SQLGo struct { EmitInterface bool `json:"emit_interface" yaml:"emit_interface"` EmitJSONTags bool `json:"emit_json_tags" yaml:"emit_json_tags"` - EmitPreparedQueries bool `json:"emit_prepared_queries" yaml:"emit_prepared_queries":` + EmitPreparedQueries bool `json:"emit_prepared_queries" yaml:"emit_prepared_queries"` + EmitExactTableNames bool `json:"emit_exact_table_names,omitempty" yaml:"emit_exact_table_names"` Package string `json:"package" yaml:"package"` Out string `json:"out" yaml:"out"` Overrides []Override `json:"overrides,omitempty" yaml:"overrides"` @@ -127,27 +128,40 @@ type SQLKotlin struct { } type Override struct { - // name of the golang type to use, e.g. `github.com/segmentio/ksuid.KSUID` - GoType string `json:"go_type" yaml:"go_type"` + // Import path, package name and type name as you would type them in the IDE + GoTypeParam GoTypeParams `json:"go_type" yaml:"go_type"` - // fully qualified name of the Go type, e.g. `github.com/segmentio/ksuid.KSUID` + // The database type to override DBType string `json:"db_type" yaml:"db_type"` Deprecated_PostgresType string `json:"postgres_type" yaml:"postgres_type"` // for global overrides only when two different engines are in use Engine Engine `json:"engine,omitempty" yaml:"engine"` - // True if the GoType should override if the maching postgres type is nullable - Null bool `json:"null" yaml:"null"` + // True if the GoType should override if the matching postgres type is nullable + Null bool `json:"is_null" yaml:"is_null"` // fully qualified name of the column, e.g. `accounts.id` Column string `json:"column" yaml:"column"` - ColumnName string - Table pg.FQN - GoTypeName string - GoPackage string - GoBasicType bool + ColumnName string + Table pg.FQN + + // FIXME The whole sqlc package could be rewritten to fetch these values from the new type GoTypeParams, however, to not mess around too much I left them here. Otherwise it would be needed to refactor in too many places around the codebase. I know for certain that the rest of the code will search these values here hence, I'm leeaving them as such. -- maxiride + GoImportPath string + GoPackageName string + GoBasicType bool +} + +type GoTypeParams struct { + // Eg. package "github.com/segmentio/ksuid" which usage is ksuid.KSUID would have: + // ImportPath "github.com/segmentio/ksuid" + // Package name ksuid + // Type name KSUID + + ImportPath string `json:"import" yaml:"import"` + PackageName string `json:"package" yaml:"package"` + TypeName string `json:"type" yaml:"type"` } func (o *Override) Parse() error { @@ -188,9 +202,15 @@ func (o *Override) Parse() error { } // validate GoType - lastDot := strings.LastIndex(o.GoType, ".") - lastSlash := strings.LastIndex(o.GoType, "/") - typename := o.GoType + lastDot := strings.LastIndex(o.GoTypeParam.ImportPath, ".") + lastSlash := strings.LastIndex(o.GoTypeParam.ImportPath, "/") + + + // If the overriding type is "local" (see #177) the config will have the 'import' tag null\zeroed, + // lastDot ==-1 && lasSlash == -1 would return TRUE and we would attempt to find a builtin type, which will not be found. + // FIXME Need to think on how to differentiate from a "local" type and builtin type when the import path is not set. + // As of now the test passes because there is no test for such scenario. + // Possibile fix: check that first letter is uppercase with unicode.IsUpper(o.GoTypeParam.ImportPath[0]) && unicode.IsLetter(o.GoTypeParam.ImportPath[0]) if lastDot == -1 && lastSlash == -1 { // if the type name has no slash and no dot, validate that the type is a basic Go type var found bool @@ -202,40 +222,58 @@ func (o *Override) Parse() error { if info&types.IsUntyped != 0 { continue } - if typename == typ.Name() { + if o.GoTypeParam.TypeName == typ.Name() { found = true } } if !found { - return fmt.Errorf("Package override `go_type` specifier %q is not a Go basic type e.g. 'string'", o.GoType) + return fmt.Errorf("Package override `go_type` specifier %q is not a Go basic type e.g. 'string'", o.GoTypeParam.TypeName) } o.GoBasicType = true } else { - // assume the type lives in a Go package - if lastDot == -1 { - return fmt.Errorf("Package override `go_type` specifier %q is not the proper format, expected 'package.type', e.g. 'github.com/segmentio/ksuid.KSUID'", o.GoType) - } - if lastSlash == -1 { - return fmt.Errorf("Package override `go_type` specifier %q is not the proper format, expected 'package.type', e.g. 'github.com/segmentio/ksuid.KSUID'", o.GoType) - } - typename = o.GoType[lastSlash+1:] - if strings.HasPrefix(typename, "go-") { + // As pointed out, for backwards compatibility this check needs to be performed, however in the current status of the + // PR these checks will always fail with the new config because, if used, the 'import' tag won't have any dot + // except for the niche of complex_go_types + // FIXME, need to implement a check to ensure whether the "legacy" or "new" config structure is used + // Possible quick and dirty fix: check if o.GoTypeParam.TypeName && o.GoTypeParam.PackageName are set + /* + // assume the type lives in a Go package + if lastDot == -1 { + return fmt.Errorf("Package override `go_type` specifier %q is not the proper format, expected 'package.type', e.g. 'github.com/segmentio/ksuid.KSUID'", o.GoType) + } + if lastSlash == -1 { + return fmt.Errorf("Package override `go_type` specifier %q is not the proper format, expected 'package.type', e.g. 'github.com/segmentio/ksuid.KSUID'", o.GoType) + } + */ + + if strings.HasPrefix(o.GoTypeParam.PackageName, "go-") { // a package name beginning with "go-" will give syntax errors in // generated code. We should do the right thing and get the actual // import name, but in lieu of that, stripping the leading "go-" may get // us what we want. - typename = typename[len("go-"):] + o.GoTypeParam.PackageName = o.GoTypeParam.PackageName[len("go-"):] } - if strings.HasSuffix(typename, "-go") { - typename = typename[:len(typename)-len("-go")] + if strings.HasSuffix(o.GoTypeParam.PackageName, "-go") { + o.GoTypeParam.PackageName = o.GoTypeParam.PackageName[:len(o.GoTypeParam.PackageName)-len("-go")] } - o.GoPackage = o.GoType[:lastDot] + + o.GoPackageName = o.GoTypeParam.PackageName + } - o.GoTypeName = typename - isPointer := o.GoType[0] == '*' - if isPointer { - o.GoPackage = o.GoPackage[1:] - o.GoTypeName = "*" + o.GoTypeName + + if len(o.GoTypeParam.ImportPath) > 0 { + o.GoImportPath = o.GoTypeParam.PackageName + "." + o.GoTypeParam.TypeName + } else { + o.GoImportPath = o.GoTypeParam.TypeName + } + + if len(o.GoTypeParam.ImportPath) > 0 { + isPointer := o.GoTypeParam.ImportPath[0] == '*' + if isPointer { + // FIXME unsure how to handle this, didn't fully understood it yet + o.GoPackageName = o.GoPackageName[1:] + o.GoImportPath = "*" + o.GoImportPath + } } return nil diff --git a/internal/config/config_test.go b/internal/config/config_test.go index ef12a178b4..4f4a7cb50d 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -74,9 +74,13 @@ func TestTypeOverrides(t *testing.T) { { Override{ DBType: "uuid", - GoType: "github.com/segmentio/ksuid.KSUID", + GoTypeParam: GoTypeParams{ + ImportPath: "github.com/segmentio/ksuid", + PackageName: "ksuid", + TypeName: "KSUID", + }, }, - "github.com/segmentio/ksuid", + "ksuid", "ksuid.KSUID", false, }, @@ -91,25 +95,70 @@ func TestTypeOverrides(t *testing.T) { // "*ksuid.KSUID", // false, // }, + // + // TODO Add test for config where the import path is not set but the type declared isn't a builtin one but a custom one + //{ + // Override{ + // DBType: "string", + // GoTypeParam: GoTypeParams{ + // ImportPath: "", + // PackageName: "", + // TypeName: "CustomType", + // }, + // }, + // "", + // "CustomType", + // false, + //}, { Override{ DBType: "citext", - GoType: "string", + GoTypeParam: GoTypeParams{ + ImportPath: "", + PackageName: "", + TypeName: "string", + }, }, "", "string", true, }, + { + Override{ + DBType: "string", + GoTypeParam: GoTypeParams{ + ImportPath: "gopkg.in/guregu/null.v3/zero", + PackageName: "zero", + TypeName: "String", + }, + }, + "zero", + "zero.String", + false, + }, + { + Override{ + DBType: "string", + GoTypeParam: GoTypeParams{ + ImportPath: "gopkg.in/guregu/null.v3", + PackageName: "null", + TypeName: "String", + }, + }, + "null", + "null.String", + false, + }, } { tt := test - t.Run(tt.override.GoType, func(t *testing.T) { + t.Run(tt.override.GoTypeParam.ImportPath, func(t *testing.T) { if err := tt.override.Parse(); err != nil { t.Fatalf("override parsing failed; %s", err) } - if diff := cmp.Diff(tt.typeName, tt.override.GoTypeName); diff != "" { + if diff := cmp.Diff(tt.typeName, tt.override.GoImportPath); diff != "" { t.Errorf("type name mismatch;\n%s", diff) } - if diff := cmp.Diff(tt.pkg, tt.override.GoPackage); diff != "" { + if diff := cmp.Diff(tt.pkg, tt.override.GoPackageName); diff != "" { t.Errorf("package mismatch;\n%s", diff) } if diff := cmp.Diff(tt.basic, tt.override.GoBasicType); diff != "" { @@ -117,6 +166,7 @@ func TestTypeOverrides(t *testing.T) { } }) } + for _, test := range []struct { override Override err string @@ -124,20 +174,28 @@ func TestTypeOverrides(t *testing.T) { { Override{ DBType: "uuid", - GoType: "Pointer", + GoTypeParam: GoTypeParams{ + ImportPath: "", + PackageName: "", + TypeName: "Pointer", + }, }, "Package override `go_type` specifier \"Pointer\" is not a Go basic type e.g. 'string'", }, { Override{ DBType: "uuid", - GoType: "untyped rune", + GoTypeParam: GoTypeParams{ + ImportPath: "", + PackageName: "", + TypeName: "untyped rune", + }, }, "Package override `go_type` specifier \"untyped rune\" is not a Go basic type e.g. 'string'", }, } { tt := test - t.Run(tt.override.GoType, func(t *testing.T) { + t.Run(tt.override.GoTypeParam.TypeName, func(t *testing.T) { err := tt.override.Parse() if err == nil { t.Fatalf("expected pars to fail; got nil") diff --git a/internal/dinosql/gen.go b/internal/dinosql/gen.go index 3761db64e7..c8755c2db1 100644 --- a/internal/dinosql/gen.go +++ b/internal/dinosql/gen.go @@ -288,7 +288,7 @@ func interfaceImports(r Generateable, settings config.CombinedSettings) fileImpo if o.GoBasicType { continue } - overrideTypes[o.GoTypeName] = o.GoPackage + overrideTypes[o.GoImportPath] = o.GoTypeParam.ImportPath } _, overrideNullTime := overrideTypes["pq.NullTime"] @@ -350,7 +350,7 @@ func modelImports(r Generateable, settings config.CombinedSettings) fileImports if o.GoBasicType { continue } - overrideTypes[o.GoTypeName] = o.GoPackage + overrideTypes[o.GoImportPath] = o.GoTypeParam.ImportPath } _, overrideNullTime := overrideTypes["pq.NullTime"] @@ -485,7 +485,7 @@ func queryImports(r Generateable, settings config.CombinedSettings, filename str if o.GoBasicType { continue } - overrideTypes[o.GoTypeName] = o.GoPackage + overrideTypes[o.GoImportPath] = o.GoTypeParam.ImportPath } if sliceScan() { @@ -621,7 +621,7 @@ func (r Result) goType(col core.Column, settings config.CombinedSettings) string // package overrides have a higher precedence for _, oride := range settings.Overrides { if oride.Column != "" && oride.ColumnName == col.Name && oride.Table == col.Table { - return oride.GoTypeName + return oride.GoImportPath } } typ := r.goInnerType(col, settings) @@ -638,7 +638,7 @@ func (r Result) goInnerType(col core.Column, settings config.CombinedSettings) s // package overrides have a higher precedence for _, oride := range settings.Overrides { if oride.DBType != "" && oride.DBType == columnType && oride.Null != notNull { - return oride.GoTypeName + return oride.GoImportPath } } diff --git a/internal/endtoend/testdata/complex_go_type/go/db.go b/internal/endtoend/testdata/complex_go_type/go/db.go new file mode 100644 index 0000000000..c3c034ae37 --- /dev/null +++ b/internal/endtoend/testdata/complex_go_type/go/db.go @@ -0,0 +1,29 @@ +// Code generated by sqlc. DO NOT EDIT. + +package db + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/complex_go_type/go/models.go b/internal/endtoend/testdata/complex_go_type/go/models.go new file mode 100644 index 0000000000..b2556d5a5f --- /dev/null +++ b/internal/endtoend/testdata/complex_go_type/go/models.go @@ -0,0 +1,13 @@ +// Code generated by sqlc. DO NOT EDIT. + +package db + +import ( + "gopkg.in/guregu/null.v3/zero" +) + +type Author struct { + ID int64 + Name string + Bio zero.String +} diff --git a/internal/endtoend/testdata/complex_go_type/go/query.sql.go b/internal/endtoend/testdata/complex_go_type/go/query.sql.go new file mode 100644 index 0000000000..7da7f9458a --- /dev/null +++ b/internal/endtoend/testdata/complex_go_type/go/query.sql.go @@ -0,0 +1,81 @@ +// Code generated by sqlc. DO NOT EDIT. +// source: query.sql + +package db + +import ( + "context" + + "gopkg.in/guregu/null.v3/zero" +) + +const createAuthor = `-- name: CreateAuthor :one +INSERT INTO authors ( + name, bio +) VALUES ( + $1, $2 +) +RETURNING id, name, bio +` + +type CreateAuthorParams struct { + Name string + Bio zero.String +} + +func (q *Queries) CreateAuthor(ctx context.Context, arg CreateAuthorParams) (Author, error) { + row := q.db.QueryRowContext(ctx, createAuthor, arg.Name, arg.Bio) + var i Author + err := row.Scan(&i.ID, &i.Name, &i.Bio) + return i, err +} + +const deleteAuthor = `-- name: DeleteAuthor :exec +DELETE FROM authors +WHERE id = $1 +` + +func (q *Queries) DeleteAuthor(ctx context.Context, id int64) error { + _, err := q.db.ExecContext(ctx, deleteAuthor, id) + return err +} + +const getAuthor = `-- name: GetAuthor :one +SELECT id, name, bio FROM authors +WHERE id = $1 LIMIT 1 +` + +func (q *Queries) GetAuthor(ctx context.Context, id int64) (Author, error) { + row := q.db.QueryRowContext(ctx, getAuthor, id) + var i Author + err := row.Scan(&i.ID, &i.Name, &i.Bio) + return i, err +} + +const listAuthors = `-- name: ListAuthors :many +SELECT id, name, bio FROM authors +ORDER BY name +` + +func (q *Queries) ListAuthors(ctx context.Context) ([]Author, error) { + rows, err := q.db.QueryContext(ctx, listAuthors) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Author + for rows.Next() { + var i Author + if err := rows.Scan(&i.ID, &i.Name, &i.Bio); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/complex_go_type/query.sql b/internal/endtoend/testdata/complex_go_type/query.sql new file mode 100644 index 0000000000..75e38b2caf --- /dev/null +++ b/internal/endtoend/testdata/complex_go_type/query.sql @@ -0,0 +1,19 @@ +-- name: GetAuthor :one +SELECT * FROM authors +WHERE id = $1 LIMIT 1; + +-- name: ListAuthors :many +SELECT * FROM authors +ORDER BY name; + +-- name: CreateAuthor :one +INSERT INTO authors ( + name, bio +) VALUES ( + $1, $2 +) +RETURNING *; + +-- name: DeleteAuthor :exec +DELETE FROM authors +WHERE id = $1; diff --git a/internal/endtoend/testdata/complex_go_type/schema.sql b/internal/endtoend/testdata/complex_go_type/schema.sql new file mode 100644 index 0000000000..b4fad78497 --- /dev/null +++ b/internal/endtoend/testdata/complex_go_type/schema.sql @@ -0,0 +1,5 @@ +CREATE TABLE authors ( + id BIGSERIAL PRIMARY KEY, + name text NOT NULL, + bio text +); diff --git a/internal/endtoend/testdata/complex_go_type/sqlc.json b/internal/endtoend/testdata/complex_go_type/sqlc.json new file mode 100644 index 0000000000..cdd2772217 --- /dev/null +++ b/internal/endtoend/testdata/complex_go_type/sqlc.json @@ -0,0 +1,22 @@ +{ + "version": "1", + "packages": [ + { + "name": "db", + "path": "go", + "queries": "query.sql", + "schema": "schema.sql" + } + ], + "overrides": [ + { + "go_type": { + "import": "gopkg.in/guregu/null.v3/zero", + "package": "zero", + "type": "String" + }, + "is_null": true, + "db_type": "text" + } + ] +} \ No newline at end of file diff --git a/internal/endtoend/testdata/mysql_overrides/sqlc.json b/internal/endtoend/testdata/mysql_overrides/sqlc.json index b7cfad6101..ba1afacccb 100644 --- a/internal/endtoend/testdata/mysql_overrides/sqlc.json +++ b/internal/endtoend/testdata/mysql_overrides/sqlc.json @@ -7,17 +7,34 @@ "schema": "schema.sql", "queries": "query.sql", "engine": "mysql", - "overrides": [{ - "go_type": "example.com/mysql.ID", - "column": "users.id" - }, { - "go_type": "example.com/mysql.ID", - "column": "orders.id" - }] + "overrides": [ + { + "go_type": { + "import": "example.com/mysql", + "package": "mysql", + "type": "ID" + }, + "column": "users.id" + }, + { + "go_type": { + "import": "example.com/mysql", + "package": "mysql", + "type": "ID" + }, + "column": "orders.id" + } + ] } ], - "overrides": [{ - "go_type": "example.com/mysql.Timestamp", - "db_type": "timestamp" - }] + "overrides": [ + { + "go_type": { + "import": "example.com/mysql", + "package": "mysql", + "type": "Timestamp" + }, + "db_type": "timestamp" + } + ] } diff --git a/internal/endtoend/testdata/overrides/sqlc.json b/internal/endtoend/testdata/overrides/sqlc.json index f98bb706b7..0371be3f23 100644 --- a/internal/endtoend/testdata/overrides/sqlc.json +++ b/internal/endtoend/testdata/overrides/sqlc.json @@ -8,11 +8,19 @@ "queries": "sql/", "overrides": [ { - "go_type": "example.com/pkg.CustomType", + "go_type": { + "import": "example.com/pkg", + "package": "pkg", + "type": "CustomType" + }, "column": "foo.retyped" }, { - "go_type": "github.com/lib/pq.StringArray", + "go_type": { + "import": "github.com/lib/pq", + "package": "pq", + "type": "StringArray" + }, "column": "foo.langs" } ] diff --git a/internal/endtoend/testdata/yaml_overrides/sqlc.yaml b/internal/endtoend/testdata/yaml_overrides/sqlc.yaml index d91ace1235..2d82456887 100644 --- a/internal/endtoend/testdata/yaml_overrides/sqlc.yaml +++ b/internal/endtoend/testdata/yaml_overrides/sqlc.yaml @@ -5,7 +5,13 @@ packages: schema: "sql/" queries: "sql/" overrides: - - go_type: "example.com/pkg.CustomType" + - go_type: + import: "example.com/pkg" + package: "pkg" + type: "CustomType" column: "foo.retyped" - - go_type: "github.com/lib/pq.StringArray" + - go_type: + import: "github.com/lib/pq" + package: "pq" + type: "StringArray" column: "foo.langs" diff --git a/internal/mysql/gen.go b/internal/mysql/gen.go index 17ebf51d54..2854364332 100644 --- a/internal/mysql/gen.go +++ b/internal/mysql/gen.go @@ -230,7 +230,7 @@ func (pGen PackageGenerator) goTypeCol(col Column) string { shouldOverride := (oride.DBType != "" && oride.DBType == mySQLType && oride.Null != notNull) || (oride.ColumnName != "" && oride.ColumnName == colName && oride.Table.Rel == col.Table) if shouldOverride { - return oride.GoTypeName + return oride.GoImportPath } } switch t := mySQLType; {