Skip to content

Commit f890bb2

Browse files
committed
internal/dinosql: Use more database/sql null types
Use sql.NullInt32 and sql.NullTime
1 parent 9211826 commit f890bb2

File tree

2 files changed

+49
-23
lines changed

2 files changed

+49
-23
lines changed

internal/dinosql/gen.go

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -490,25 +490,31 @@ func (r Result) goInnerType(col core.Column) string {
490490

491491
switch columnType {
492492
case "serial", "pg_catalog.serial4":
493-
return "int32"
493+
if notNull {
494+
return "int32"
495+
}
496+
return "sql.NullInt32"
494497

495498
case "bigserial", "pg_catalog.serial8":
496499
if notNull {
497500
return "int64"
498501
}
499-
return "sql.NullInt64" // unnecessay else
502+
return "sql.NullInt64"
500503

501504
case "smallserial", "pg_catalog.serial2":
502505
return "int16"
503506

504507
case "integer", "int", "pg_catalog.int4":
505-
return "int32"
508+
if notNull {
509+
return "int32"
510+
}
511+
return "sql.NullInt32"
506512

507513
case "bigint", "pg_catalog.int8":
508514
if notNull {
509515
return "int64"
510516
}
511-
return "sql.NullInt64" // unnecessary else
517+
return "sql.NullInt64"
512518

513519
case "smallint", "pg_catalog.int2":
514520
return "int16"
@@ -517,19 +523,19 @@ func (r Result) goInnerType(col core.Column) string {
517523
if notNull {
518524
return "float64"
519525
}
520-
return "sql.NullFloat64" // unnecessary else
526+
return "sql.NullFloat64"
521527

522528
case "real", "pg_catalog.float4":
523529
if notNull {
524530
return "float32"
525-
} // unnecessary else
526-
return "sql.NullFloat64" // IMPORTANT: Change to sql.NullFloat32 after updating the go.mod file
531+
}
532+
return "sql.NullFloat64" // TODO: Change to sql.NullFloat32 after updating the go.mod file
527533

528534
case "bool", "pg_catalog.bool":
529535
if notNull {
530536
return "bool"
531537
}
532-
return "sql.NullBool" // unnecessary else
538+
return "sql.NullBool"
533539

534540
case "jsonb":
535541
return "json.RawMessage"
@@ -541,13 +547,13 @@ func (r Result) goInnerType(col core.Column) string {
541547
if notNull {
542548
return "time.Time"
543549
}
544-
return "pq.NullTime" // unnecessary else
550+
return "sql.NullTime"
545551

546552
case "text", "pg_catalog.varchar", "pg_catalog.bpchar":
547553
if notNull {
548554
return "string"
549555
}
550-
return "sql.NullString" // unnecessary else
556+
return "sql.NullString"
551557

552558
case "uuid":
553559
return "uuid.UUID"

internal/dinosql/gen_test.go

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -78,19 +78,39 @@ func TestColumnsToStruct(t *testing.T) {
7878

7979
func TestInnerType(t *testing.T) {
8080
r := Result{}
81-
for _, tc := range []struct {
82-
col pg.Column
83-
expected string
84-
}{
85-
{
86-
pg.Column{Name: "created", DataType: "timestamptz", NotNull: true},
87-
"time.Time",
88-
},
89-
} {
90-
tt := tc
91-
t.Run(tt.col.Name+"-"+tt.col.DataType, func(t *testing.T) {
92-
if diff := cmp.Diff(tt.expected, r.goType(tt.col)); diff != "" {
93-
t.Errorf("struct mismatch: \n%s", diff)
81+
types := map[string]string{
82+
"timestamptz": "time.Time",
83+
"integer": "int32",
84+
"int": "int32",
85+
"pg_catalog.int4": "int32",
86+
}
87+
for k, v := range types {
88+
dbType := k
89+
goType := v
90+
t.Run(k+"-"+v, func(t *testing.T) {
91+
col := pg.Column{DataType: dbType, NotNull: true}
92+
if goType != r.goType(col) {
93+
t.Errorf("expected Go type for %s to be %s, not %s", dbType, goType, r.goType(col))
94+
}
95+
})
96+
}
97+
}
98+
99+
func TestNullInnerType(t *testing.T) {
100+
r := Result{}
101+
types := map[string]string{
102+
"timestamptz": "sql.NullTime",
103+
"integer": "sql.NullInt32",
104+
"int": "sql.NullInt32",
105+
"pg_catalog.int4": "sql.NullInt32",
106+
}
107+
for k, v := range types {
108+
dbType := k
109+
goType := v
110+
t.Run(k+"-"+v, func(t *testing.T) {
111+
col := pg.Column{DataType: dbType, NotNull: false}
112+
if goType != r.goType(col) {
113+
t.Errorf("expected Go type for %s to be %s, not %s", dbType, goType, r.goType(col))
94114
}
95115
})
96116
}

0 commit comments

Comments
 (0)