Skip to content

Commit 0dfbfb6

Browse files
committed
1 parent aa5e345 commit 0dfbfb6

File tree

16 files changed

+214
-133
lines changed

16 files changed

+214
-133
lines changed

internal/driver.go

Lines changed: 7 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,14 @@
11
package golang
22

3-
type SQLDriver string
3+
import "github.com/sqlc-dev/sqlc-gen-go/internal/opts"
44

5-
const (
6-
SQLPackagePGXV4 string = "pgx/v4"
7-
SQLPackagePGXV5 string = "pgx/v5"
8-
SQLPackageStandard string = "database/sql"
9-
)
10-
11-
const (
12-
SQLDriverPGXV4 SQLDriver = "github.com/jackc/pgx/v4"
13-
SQLDriverPGXV5 = "github.com/jackc/pgx/v5"
14-
SQLDriverLibPQ = "github.com/lib/pq"
15-
SQLDriverGoSQLDriverMySQL = "github.com/go-sql-driver/mysql"
16-
)
17-
18-
func parseDriver(sqlPackage string) SQLDriver {
5+
func parseDriver(sqlPackage string) opts.SQLDriver {
196
switch sqlPackage {
20-
case SQLPackagePGXV4:
21-
return SQLDriverPGXV4
22-
case SQLPackagePGXV5:
23-
return SQLDriverPGXV5
24-
default:
25-
return SQLDriverLibPQ
26-
}
27-
}
28-
29-
func (d SQLDriver) IsPGX() bool {
30-
return d == SQLDriverPGXV4 || d == SQLDriverPGXV5
31-
}
32-
33-
func (d SQLDriver) IsGoSQLDriverMySQL() bool {
34-
return d == SQLDriverGoSQLDriverMySQL
35-
}
36-
37-
func (d SQLDriver) Package() string {
38-
switch d {
39-
case SQLDriverPGXV4:
40-
return SQLPackagePGXV4
41-
case SQLDriverPGXV5:
42-
return SQLPackagePGXV5
7+
case opts.SQLPackagePGXV4:
8+
return opts.SQLDriverPGXV4
9+
case opts.SQLPackagePGXV5:
10+
return opts.SQLDriverPGXV5
4311
default:
44-
return SQLPackageStandard
12+
return opts.SQLDriverLibPQ
4513
}
4614
}

internal/field.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@ import (
66
"sort"
77
"strings"
88

9-
"github.com/sqlc-dev/plugin-sdk-go/plugin"
109
"github.com/sqlc-dev/sqlc-gen-go/internal/opts"
10+
"github.com/sqlc-dev/plugin-sdk-go/plugin"
1111
)
1212

1313
type Field struct {

internal/gen.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,16 @@ import (
1010
"strings"
1111
"text/template"
1212

13+
"github.com/sqlc-dev/sqlc-gen-go/internal/opts"
14+
"github.com/sqlc-dev/plugin-sdk-go/sdk"
1315
"github.com/sqlc-dev/plugin-sdk-go/metadata"
1416
"github.com/sqlc-dev/plugin-sdk-go/plugin"
15-
"github.com/sqlc-dev/plugin-sdk-go/sdk"
16-
"github.com/sqlc-dev/sqlc-gen-go/internal/opts"
1717
)
1818

1919
type tmplCtx struct {
2020
Q string
2121
Package string
22-
SQLDriver SQLDriver
22+
SQLDriver opts.SQLDriver
2323
Enums []Enum
2424
Structs []Struct
2525
GoQueries []Query
@@ -189,15 +189,15 @@ func generate(req *plugin.GenerateRequest, options *opts.Options, enums []Enum,
189189
OmitSqlcVersion: options.OmitSqlcVersion,
190190
}
191191

192-
if tctx.UsesCopyFrom && !tctx.SQLDriver.IsPGX() && options.SqlDriver != SQLDriverGoSQLDriverMySQL {
192+
if tctx.UsesCopyFrom && !tctx.SQLDriver.IsPGX() && options.SqlDriver != opts.SQLDriverGoSQLDriverMySQL {
193193
return nil, errors.New(":copyfrom is only supported by pgx and github.com/go-sql-driver/mysql")
194194
}
195195

196-
if tctx.UsesCopyFrom && options.SqlDriver == SQLDriverGoSQLDriverMySQL {
196+
if tctx.UsesCopyFrom && options.SqlDriver == opts.SQLDriverGoSQLDriverMySQL {
197197
if err := checkNoTimesForMySQLCopyFrom(queries); err != nil {
198198
return nil, err
199199
}
200-
tctx.SQLDriver = SQLDriverGoSQLDriverMySQL
200+
tctx.SQLDriver = opts.SQLDriverGoSQLDriverMySQL
201201
}
202202

203203
if tctx.UsesBatch && !tctx.SQLDriver.IsPGX() {
@@ -209,6 +209,7 @@ func generate(req *plugin.GenerateRequest, options *opts.Options, enums []Enum,
209209
"comment": sdk.DoubleSlashComment,
210210
"escape": sdk.EscapeBacktick,
211211
"imports": i.Imports,
212+
"hasImports": i.HasImports,
212213
"hasPrefix": strings.HasPrefix,
213214

214215
// These methods are Go specific, they do not belong in the codegen package

internal/go_type.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@ package golang
33
import (
44
"strings"
55

6-
"github.com/sqlc-dev/plugin-sdk-go/plugin"
7-
"github.com/sqlc-dev/plugin-sdk-go/sdk"
86
"github.com/sqlc-dev/sqlc-gen-go/internal/opts"
7+
"github.com/sqlc-dev/plugin-sdk-go/sdk"
8+
"github.com/sqlc-dev/plugin-sdk-go/plugin"
99
)
1010

1111
func addExtraGoStructTags(tags map[string]string, req *plugin.GenerateRequest, options *opts.Options, col *plugin.Column) {

internal/imports.go

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ import (
55
"sort"
66
"strings"
77

8-
"github.com/sqlc-dev/plugin-sdk-go/metadata"
98
"github.com/sqlc-dev/sqlc-gen-go/internal/opts"
9+
"github.com/sqlc-dev/plugin-sdk-go/metadata"
1010
)
1111

1212
type fileImports struct {
@@ -75,6 +75,11 @@ func (i *importer) usesType(typ string) bool {
7575
return false
7676
}
7777

78+
func (i *importer) HasImports(filename string) bool {
79+
imports := i.Imports(filename)
80+
return len(imports[0]) != 0 || len(imports[1]) != 0
81+
}
82+
7883
func (i *importer) Imports(filename string) [][]ImportSpec {
7984
dbFileName := "db.go"
8085
if i.Options.OutputDbFileName != "" {
@@ -121,10 +126,10 @@ func (i *importer) dbImports() fileImports {
121126

122127
sqlpkg := parseDriver(i.Options.SqlPackage)
123128
switch sqlpkg {
124-
case SQLDriverPGXV4:
129+
case opts.SQLDriverPGXV4:
125130
pkg = append(pkg, ImportSpec{Path: "github.com/jackc/pgconn"})
126131
pkg = append(pkg, ImportSpec{Path: "github.com/jackc/pgx/v4"})
127-
case SQLDriverPGXV5:
132+
case opts.SQLDriverPGXV5:
128133
pkg = append(pkg, ImportSpec{Path: "github.com/jackc/pgx/v5/pgconn"})
129134
pkg = append(pkg, ImportSpec{Path: "github.com/jackc/pgx/v5"})
130135
default:
@@ -167,9 +172,9 @@ func buildImports(options *opts.Options, queries []Query, uses func(string) bool
167172
for _, q := range queries {
168173
if q.Cmd == metadata.CmdExecResult {
169174
switch sqlpkg {
170-
case SQLDriverPGXV4:
175+
case opts.SQLDriverPGXV4:
171176
pkg[ImportSpec{Path: "github.com/jackc/pgconn"}] = struct{}{}
172-
case SQLDriverPGXV5:
177+
case opts.SQLDriverPGXV5:
173178
pkg[ImportSpec{Path: "github.com/jackc/pgx/v5/pgconn"}] = struct{}{}
174179
default:
175180
std["database/sql"] = struct{}{}
@@ -184,7 +189,7 @@ func buildImports(options *opts.Options, queries []Query, uses func(string) bool
184189
}
185190

186191
if uses("pgtype.") {
187-
if sqlpkg == SQLDriverPGXV5 {
192+
if sqlpkg == opts.SQLDriverPGXV5 {
188193
pkg[ImportSpec{Path: "github.com/jackc/pgx/v5/pgtype"}] = struct{}{}
189194
} else {
190195
pkg[ImportSpec{Path: "github.com/jackc/pgtype"}] = struct{}{}
@@ -424,7 +429,7 @@ func (i *importer) copyfromImports() fileImports {
424429
})
425430

426431
std["context"] = struct{}{}
427-
if i.Options.SqlDriver == SQLDriverGoSQLDriverMySQL {
432+
if i.Options.SqlDriver == opts.SQLDriverGoSQLDriverMySQL {
428433
std["io"] = struct{}{}
429434
std["fmt"] = struct{}{}
430435
std["sync/atomic"] = struct{}{}
@@ -476,9 +481,9 @@ func (i *importer) batchImports() fileImports {
476481
std["errors"] = struct{}{}
477482
sqlpkg := parseDriver(i.Options.SqlPackage)
478483
switch sqlpkg {
479-
case SQLDriverPGXV4:
484+
case opts.SQLDriverPGXV4:
480485
pkg[ImportSpec{Path: "github.com/jackc/pgx/v4"}] = struct{}{}
481-
case SQLDriverPGXV5:
486+
case opts.SQLDriverPGXV5:
482487
pkg[ImportSpec{Path: "github.com/jackc/pgx/v5"}] = struct{}{}
483488
}
484489

internal/mysql_type.go

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@ package golang
33
import (
44
"log"
55

6-
"github.com/sqlc-dev/plugin-sdk-go/plugin"
6+
"github.com/sqlc-dev/sqlc-gen-go/internal/opts"
77
"github.com/sqlc-dev/plugin-sdk-go/sdk"
88
"github.com/sqlc-dev/sqlc-gen-go/internal/debug"
9-
"github.com/sqlc-dev/sqlc-gen-go/internal/opts"
9+
"github.com/sqlc-dev/plugin-sdk-go/plugin"
1010
)
1111

1212
func mysqlType(req *plugin.GenerateRequest, options *opts.Options, col *plugin.Column) string {
@@ -31,14 +31,31 @@ func mysqlType(req *plugin.GenerateRequest, options *opts.Options, col *plugin.C
3131
} else {
3232
if notNull {
3333
if unsigned {
34-
return "uint32"
34+
return "uint8"
3535
}
36-
return "int32"
36+
return "int8"
37+
}
38+
// The database/sql package does not have a sql.NullInt8 type, so we
39+
// use the smallest type they have which is NullInt16
40+
return "sql.NullInt16"
41+
}
42+
43+
case "year":
44+
if notNull {
45+
return "int16"
46+
}
47+
return "sql.NullInt16"
48+
49+
case "smallint":
50+
if notNull {
51+
if unsigned {
52+
return "uint16"
3753
}
38-
return "sql.NullInt32"
54+
return "int16"
3955
}
56+
return "sql.NullInt16"
4057

41-
case "int", "integer", "smallint", "mediumint", "year":
58+
case "int", "integer", "mediumint":
4259
if notNull {
4360
if unsigned {
4461
return "uint32"

internal/opts/enum.go

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
package opts
2+
3+
import "fmt"
4+
5+
type SQLDriver string
6+
7+
const (
8+
SQLPackagePGXV4 string = "pgx/v4"
9+
SQLPackagePGXV5 string = "pgx/v5"
10+
SQLPackageStandard string = "database/sql"
11+
)
12+
13+
var validPackages = map[string]struct{}{
14+
string(SQLPackagePGXV4): {},
15+
string(SQLPackagePGXV5): {},
16+
string(SQLPackageStandard): {},
17+
}
18+
19+
func validatePackage(sqlPackage string) error {
20+
if _, found := validPackages[sqlPackage]; !found {
21+
return fmt.Errorf("unknown SQL package: %s", sqlPackage)
22+
}
23+
return nil
24+
}
25+
26+
const (
27+
SQLDriverPGXV4 SQLDriver = "github.com/jackc/pgx/v4"
28+
SQLDriverPGXV5 = "github.com/jackc/pgx/v5"
29+
SQLDriverLibPQ = "github.com/lib/pq"
30+
SQLDriverGoSQLDriverMySQL = "github.com/go-sql-driver/mysql"
31+
)
32+
33+
var validDrivers = map[string]struct{}{
34+
string(SQLDriverPGXV4): {},
35+
string(SQLDriverPGXV5): {},
36+
string(SQLDriverLibPQ): {},
37+
string(SQLDriverGoSQLDriverMySQL): {},
38+
}
39+
40+
func validateDriver(sqlDriver string) error {
41+
if _, found := validDrivers[sqlDriver]; !found {
42+
return fmt.Errorf("unknown SQL driver: %s", sqlDriver)
43+
}
44+
return nil
45+
}
46+
47+
func (d SQLDriver) IsPGX() bool {
48+
return d == SQLDriverPGXV4 || d == SQLDriverPGXV5
49+
}
50+
51+
func (d SQLDriver) IsGoSQLDriverMySQL() bool {
52+
return d == SQLDriverGoSQLDriverMySQL
53+
}
54+
55+
func (d SQLDriver) Package() string {
56+
switch d {
57+
case SQLDriverPGXV4:
58+
return SQLPackagePGXV4
59+
case SQLDriverPGXV5:
60+
return SQLPackagePGXV5
61+
default:
62+
return SQLPackageStandard
63+
}
64+
}

internal/opts/options.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,18 @@ func parseOpts(req *plugin.GenerateRequest) (*Options, error) {
9494
}
9595
}
9696

97+
if options.SqlPackage != "" {
98+
if err := validatePackage(options.SqlPackage); err != nil {
99+
return nil, fmt.Errorf("invalid options: %s", err)
100+
}
101+
}
102+
103+
if options.SqlDriver != "" {
104+
if err := validateDriver(options.SqlDriver); err != nil {
105+
return nil, fmt.Errorf("invalid options: %s", err)
106+
}
107+
}
108+
97109
if options.QueryParameterLimit == nil {
98110
options.QueryParameterLimit = new(int32)
99111
*options.QueryParameterLimit = 1

0 commit comments

Comments
 (0)