From 3b484a60b28314dfb13c3c597d2eb36cf59ab189 Mon Sep 17 00:00:00 2001 From: Miha Vrhovnik Date: Wed, 7 Oct 2020 13:31:03 +0200 Subject: [PATCH] Add option to use pointers instead of sql.Null* types --- internal/codegen/golang/mysql_type.go | 32 +++++---- internal/codegen/golang/postgresql_type.go | 77 ++++++++++++---------- internal/codegen/golang/sqlite_type.go | 12 ++-- internal/config/config.go | 1 + internal/config/v_one.go | 2 + 5 files changed, 71 insertions(+), 53 deletions(-) diff --git a/internal/codegen/golang/mysql_type.go b/internal/codegen/golang/mysql_type.go index 4be59fc7f6..9bf8f9ecdf 100644 --- a/internal/codegen/golang/mysql_type.go +++ b/internal/codegen/golang/mysql_type.go @@ -12,24 +12,28 @@ import ( func mysqlType(r *compiler.Result, col *compiler.Column, settings config.CombinedSettings) string { columnType := col.DataType notNull := col.NotNull || col.IsArray + pointer := "" + if !notNull && settings.Go.UsePointers { + pointer = "*" + } switch columnType { case "varchar", "text", "char", "tinytext", "mediumtext", "longtext": - if notNull { - return "string" + if notNull || pointer != "" { + return pointer + "string" } return "sql.NullString" case "int", "integer", "smallint", "mediumint", "year": - if notNull { - return "int32" + if notNull || pointer != "" { + return pointer + "int32" } return "sql.NullInt32" case "bigint": - if notNull { - return "int64" + if notNull || pointer != "" { + return pointer + "int64" } return "sql.NullInt64" @@ -37,14 +41,14 @@ func mysqlType(r *compiler.Result, col *compiler.Column, settings config.Combine return "[]byte" case "double", "double precision", "real": - if notNull { - return "float64" + if notNull || pointer != "" { + return pointer + "float64" } return "sql.NullFloat64" case "decimal", "dec", "fixed": - if notNull { - return "string" + if notNull || pointer != "" { + return pointer + "string" } return "sql.NullString" @@ -53,14 +57,14 @@ func mysqlType(r *compiler.Result, col *compiler.Column, settings config.Combine return "string" case "date", "timestamp", "datetime", "time": - if notNull { - return "time.Time" + if notNull || pointer != "" { + return pointer + "time.Time" } return "sql.NullTime" case "boolean", "bool", "tinyint": - if notNull { - return "bool" + if notNull || pointer != "" { + return pointer + "bool" } return "sql.NullBool" diff --git a/internal/codegen/golang/postgresql_type.go b/internal/codegen/golang/postgresql_type.go index 9c82bf3632..849a84d855 100644 --- a/internal/codegen/golang/postgresql_type.go +++ b/internal/codegen/golang/postgresql_type.go @@ -12,47 +12,51 @@ import ( func postgresType(r *compiler.Result, col *compiler.Column, settings config.CombinedSettings) string { columnType := col.DataType notNull := col.NotNull || col.IsArray + pointer := "" + if !notNull && settings.Go.UsePointers { + pointer = "*" + } switch columnType { case "serial", "serial4", "pg_catalog.serial4": - if notNull { - return "int32" + if notNull || pointer != "" { + return pointer + "int32" } return "sql.NullInt32" case "bigserial", "serial8", "pg_catalog.serial8": - if notNull { - return "int64" + if notNull || pointer != "" { + return pointer + "int64" } return "sql.NullInt64" case "smallserial", "serial2", "pg_catalog.serial2": - return "int16" + return pointer + "int16" case "integer", "int", "int4", "pg_catalog.int4": - if notNull { - return "int32" + if notNull || pointer != "" { + return pointer + "int32" } return "sql.NullInt32" case "bigint", "int8", "pg_catalog.int8": - if notNull { - return "int64" + if notNull || pointer != "" { + return pointer + "int64" } return "sql.NullInt64" case "smallint", "int2", "pg_catalog.int2": - return "int16" + return pointer + "int16" case "float", "double precision", "float8", "pg_catalog.float8": - if notNull { - return "float64" + if notNull || pointer != "" { + return pointer + "float64" } return "sql.NullFloat64" case "real", "float4", "pg_catalog.float4": - if notNull { - return "float32" + if notNull || pointer != "" { + return pointer + "float32" } return "sql.NullFloat64" // TODO: Change to sql.NullFloat32 after updating the go.mod file @@ -61,14 +65,14 @@ func postgresType(r *compiler.Result, col *compiler.Column, settings config.Comb // returns numerics as strings. // // https://github.com/lib/pq/issues/648 - if notNull { - return "string" + if notNull || pointer != "" { + return pointer + "string" } return "sql.NullString" case "boolean", "bool", "pg_catalog.bool": - if notNull { - return "bool" + if notNull || pointer != "" { + return pointer + "bool" } return "sql.NullBool" @@ -79,37 +83,37 @@ func postgresType(r *compiler.Result, col *compiler.Column, settings config.Comb return "[]byte" case "date": - if notNull { - return "time.Time" + if notNull || pointer != "" { + return pointer + "time.Time" } return "sql.NullTime" case "pg_catalog.time", "pg_catalog.timetz": - if notNull { - return "time.Time" + if notNull || pointer != "" { + return pointer + "time.Time" } return "sql.NullTime" case "pg_catalog.timestamp", "pg_catalog.timestamptz", "timestamptz": - if notNull { - return "time.Time" + if notNull || pointer != "" { + return pointer + "time.Time" } return "sql.NullTime" case "text", "pg_catalog.varchar", "pg_catalog.bpchar", "string": - if notNull { - return "string" + if notNull || pointer != "" { + return pointer + "string" } return "sql.NullString" case "uuid": - return "uuid.UUID" + return pointer + "uuid.UUID" case "inet", "cidr": - return "net.IP" + return pointer + "net.IP" case "macaddr", "macaddr8": - return "net.HardwareAddr" + return pointer + "net.HardwareAddr" case "ltree", "lquery", "ltxtquery": // This module implements a data type ltree for representing labels @@ -117,18 +121,21 @@ func postgresType(r *compiler.Result, col *compiler.Column, settings config.Comb // facilities for searching through label trees are provided. // // https://www.postgresql.org/docs/current/ltree.html - if notNull { - return "string" + if notNull || pointer != "" { + return pointer + "string" } return "sql.NullString" case "interval", "pg_catalog.interval": - if notNull { - return "int64" + if notNull || pointer != "" { + return pointer + "int64" } return "sql.NullInt64" case "void": + if pointer != "" { + return "*bool" + } // A void value always returns NULL. Since there is no built-in NULL // value into the SQL package, we'll use sql.NullBool return "sql.NullBool" @@ -160,8 +167,8 @@ func postgresType(r *compiler.Result, col *compiler.Column, settings config.Comb return StructName(schema.Name+"_"+t.Name, settings) } case *catalog.CompositeType: - if notNull { - return "string" + if notNull || pointer != "" { + return pointer + "string" } return "sql.NullString" } diff --git a/internal/codegen/golang/sqlite_type.go b/internal/codegen/golang/sqlite_type.go index 505a74740c..ace6c0c25f 100644 --- a/internal/codegen/golang/sqlite_type.go +++ b/internal/codegen/golang/sqlite_type.go @@ -11,12 +11,16 @@ import ( func sqliteType(r *compiler.Result, col *compiler.Column, settings config.CombinedSettings) string { dt := col.DataType notNull := col.NotNull || col.IsArray + pointer := "" + if !notNull && settings.Go.UsePointers { + pointer = "*" + } switch dt { case "integer": - if notNull { - return "int32" + if notNull || pointer != "" { + return pointer + "int32" } return "sql.NullInt32" @@ -28,8 +32,8 @@ func sqliteType(r *compiler.Result, col *compiler.Column, settings config.Combin switch { case strings.HasPrefix(dt, "varchar"): - if notNull { - return "string" + if notNull || pointer != "" { + return pointer + "string" } return "sql.NullString" diff --git a/internal/config/config.go b/internal/config/config.go index d2e7c2f987..fe960aa0fc 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -117,6 +117,7 @@ type SQLGo struct { EmitPreparedQueries bool `json:"emit_prepared_queries" yaml:"emit_prepared_queries"` EmitExactTableNames bool `json:"emit_exact_table_names,omitempty" yaml:"emit_exact_table_names"` EmitEmptySlices bool `json:"emit_empty_slices,omitempty" yaml:"emit_empty_slices"` + UsePointers bool `json:"use_pointers,omitempty" yaml:"use_pointers"` Package string `json:"package" yaml:"package"` Out string `json:"out" yaml:"out"` Overrides []Override `json:"overrides,omitempty" yaml:"overrides"` diff --git a/internal/config/v_one.go b/internal/config/v_one.go index 55d6b3ae32..a10c761291 100644 --- a/internal/config/v_one.go +++ b/internal/config/v_one.go @@ -27,6 +27,7 @@ type v1PackageSettings struct { EmitPreparedQueries bool `json:"emit_prepared_queries" yaml:"emit_prepared_queries"` EmitExactTableNames bool `json:"emit_exact_table_names,omitempty" yaml:"emit_exact_table_names"` EmitEmptySlices bool `json:"emit_empty_slices,omitempty" yaml:"emit_empty_slices"` + UsePointers bool `json:"use_pointers,omitempty" yaml:"use_pointers"` Overrides []Override `json:"overrides" yaml:"overrides"` } @@ -109,6 +110,7 @@ func (c *V1GenerateSettings) Translate() Config { EmitPreparedQueries: pkg.EmitPreparedQueries, EmitExactTableNames: pkg.EmitExactTableNames, EmitEmptySlices: pkg.EmitEmptySlices, + UsePointers: pkg.UsePointers, Package: pkg.Name, Out: pkg.Path, Overrides: pkg.Overrides,