diff --git a/internal/dinosql/gen.go b/internal/dinosql/gen.go index d26691750e..794e07f8b6 100644 --- a/internal/dinosql/gen.go +++ b/internal/dinosql/gen.go @@ -490,25 +490,31 @@ func (r Result) goInnerType(col core.Column) string { switch columnType { case "serial", "pg_catalog.serial4": - return "int32" + if notNull { + return "int32" + } + return "sql.NullInt32" case "bigserial", "pg_catalog.serial8": if notNull { return "int64" } - return "sql.NullInt64" // unnecessay else + return "sql.NullInt64" case "smallserial", "pg_catalog.serial2": return "int16" case "integer", "int", "pg_catalog.int4": - return "int32" + if notNull { + return "int32" + } + return "sql.NullInt32" case "bigint", "pg_catalog.int8": if notNull { return "int64" } - return "sql.NullInt64" // unnecessary else + return "sql.NullInt64" case "smallint", "pg_catalog.int2": return "int16" @@ -517,19 +523,19 @@ func (r Result) goInnerType(col core.Column) string { if notNull { return "float64" } - return "sql.NullFloat64" // unnecessary else + return "sql.NullFloat64" case "real", "pg_catalog.float4": if notNull { return "float32" - } // unnecessary else - return "sql.NullFloat64" // IMPORTANT: Change to sql.NullFloat32 after updating the go.mod file + } + return "sql.NullFloat64" // TODO: Change to sql.NullFloat32 after updating the go.mod file case "bool", "pg_catalog.bool": if notNull { return "bool" } - return "sql.NullBool" // unnecessary else + return "sql.NullBool" case "jsonb": return "json.RawMessage" @@ -541,13 +547,13 @@ func (r Result) goInnerType(col core.Column) string { if notNull { return "time.Time" } - return "pq.NullTime" // unnecessary else + return "sql.NullTime" case "text", "pg_catalog.varchar", "pg_catalog.bpchar": if notNull { return "string" } - return "sql.NullString" // unnecessary else + return "sql.NullString" case "uuid": return "uuid.UUID" diff --git a/internal/dinosql/gen_test.go b/internal/dinosql/gen_test.go index 16846b84cc..c65950f2b1 100644 --- a/internal/dinosql/gen_test.go +++ b/internal/dinosql/gen_test.go @@ -78,19 +78,39 @@ func TestColumnsToStruct(t *testing.T) { func TestInnerType(t *testing.T) { r := Result{} - for _, tc := range []struct { - col pg.Column - expected string - }{ - { - pg.Column{Name: "created", DataType: "timestamptz", NotNull: true}, - "time.Time", - }, - } { - tt := tc - t.Run(tt.col.Name+"-"+tt.col.DataType, func(t *testing.T) { - if diff := cmp.Diff(tt.expected, r.goType(tt.col)); diff != "" { - t.Errorf("struct mismatch: \n%s", diff) + types := map[string]string{ + "timestamptz": "time.Time", + "integer": "int32", + "int": "int32", + "pg_catalog.int4": "int32", + } + for k, v := range types { + dbType := k + goType := v + t.Run(k+"-"+v, func(t *testing.T) { + col := pg.Column{DataType: dbType, NotNull: true} + if goType != r.goType(col) { + t.Errorf("expected Go type for %s to be %s, not %s", dbType, goType, r.goType(col)) + } + }) + } +} + +func TestNullInnerType(t *testing.T) { + r := Result{} + types := map[string]string{ + "timestamptz": "sql.NullTime", + "integer": "sql.NullInt32", + "int": "sql.NullInt32", + "pg_catalog.int4": "sql.NullInt32", + } + for k, v := range types { + dbType := k + goType := v + t.Run(k+"-"+v, func(t *testing.T) { + col := pg.Column{DataType: dbType, NotNull: false} + if goType != r.goType(col) { + t.Errorf("expected Go type for %s to be %s, not %s", dbType, goType, r.goType(col)) } }) }