Skip to content

[PART 2] pgx v5 support (#1823) #1874

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Nov 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/reference/config.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ The `gen` mapping supports the following keys:
- `out`:
- Output directory for generated code.
- `sql_package`:
- Either `pgx/v4` or `database/sql`. Defaults to `database/sql`.
- Either `pgx/v4`, `pgx/v5` or `database/sql`. Defaults to `database/sql`.
- `emit_db_tags`:
- If true, add DB tags to generated structs. Defaults to `false`.
- `emit_prepared_queries`:
Expand Down Expand Up @@ -363,7 +363,7 @@ Each mapping in the `packages` collection has the following keys:
- `engine`:
- Either `postgresql` or `mysql`. Defaults to `postgresql`.
- `sql_package`:
- Either `pgx/v4` or `database/sql`. Defaults to `database/sql`.
- Either `pgx/v4`, `pgx/v5` or `database/sql`. Defaults to `database/sql`.
- `emit_db_tags`:
- If true, add DB tags to generated structs. Defaults to `false`.
- `emit_prepared_queries`:
Expand Down
51 changes: 48 additions & 3 deletions docs/reference/datatypes.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
## Arrays

PostgreSQL [arrays](https://www.postgresql.org/docs/current/arrays.html) are
materialized as Go slices. Currently, only one-dimensional arrays are
supported.
materialized as Go slices. Currently, the `pgx/v5` sql package only supports multidimensional arrays.

```sql
CREATE TABLE places (
Expand All @@ -26,6 +25,7 @@ type Place struct {

All PostgreSQL time and date types are returned as `time.Time` structs. For
null time or date values, the `NullTime` type from `database/sql` is used.
The `pgx/v5` sql package uses the appropriate pgx types.

```sql
CREATE TABLE authors (
Expand Down Expand Up @@ -86,7 +86,7 @@ type Store struct {
## Null

For structs, null values are represented using the appropriate type from the
`database/sql` package.
`database/sql` or `pgx` package.

```sql
CREATE TABLE authors (
Expand Down Expand Up @@ -132,3 +132,48 @@ type Author struct {
ID uuid.UUID
}
```

## JSON

By default, sqlc will generate the `[]byte`, `pgtype.JSON` or `json.RawMessage` for JSON column type.
But if you use the `pgx/v5` sql package then you can specify a some struct instead of default type.
The `pgx` implementation will marshall/unmarshall the struct automatically.

```go
package dto

type BookData struct {
Genres []string `json:"genres"`
Title string `json:"title"`
Published bool `json:"published"`
}
```

```sql
CREATE TABLE books (
data jsonb
);
```

```json
{
"overrides": [
{
"column": "books.data",
"go_type": "*example.com/db/dto.BookData"
}
]
}
```

```go
package db

import (
"example.com/db/dto"
)

type Book struct {
Data *dto.BookData
}
```
6 changes: 3 additions & 3 deletions docs/reference/query-annotations.md
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ func (q *Queries) GetAuthor(ctx context.Context, id int64) (Author, error) {

## `:batchexec`

__NOTE: This command only works with PostgreSQL using the `pgx` driver and outputting Go code.__
__NOTE: This command only works with PostgreSQL using the `pgx/v4` and `pgx/v5` drivers and outputting Go code.__

The generated method will return a batch object. The batch object will have
the following methods:
Expand Down Expand Up @@ -147,7 +147,7 @@ func (b *DeleteBookBatchResults) Close() error {

## `:batchmany`

__NOTE: This command only works with PostgreSQL using the `pgx` driver and outputting Go code.__
__NOTE: This command only works with PostgreSQL using the `pgx/v4` and `pgx/v5` drivers and outputting Go code.__

The generated method will return a batch object. The batch object will have
the following methods:
Expand Down Expand Up @@ -183,7 +183,7 @@ func (b *BooksByTitleYearBatchResults) Close() error {

## `:batchone`

__NOTE: This command only works with PostgreSQL using the `pgx` driver and outputting Go code.__
__NOTE: This command only works with PostgreSQL using the `pgx/v4` and `pgx/v5` drivers and outputting Go code.__

The generated method will return a batch object. The batch object will have
the following methods:
Expand Down
13 changes: 12 additions & 1 deletion internal/cmd/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@ import (

"github.com/spf13/cobra"
"github.com/spf13/pflag"
yaml "gopkg.in/yaml.v3"
"gopkg.in/yaml.v3"

"github.com/kyleconroy/sqlc/internal/codegen/golang"
"github.com/kyleconroy/sqlc/internal/config"
"github.com/kyleconroy/sqlc/internal/debug"
"github.com/kyleconroy/sqlc/internal/info"
Expand Down Expand Up @@ -112,6 +113,16 @@ func ParseEnv(c *cobra.Command) Env {
}
}

func (e *Env) Validate(cfg *config.Config) error {
for _, sql := range cfg.SQL {
if sql.Gen.Go != nil && sql.Gen.Go.SQLPackage == golang.SQLPackagePGXV5 && !e.ExperimentalFeatures {
return fmt.Errorf("'pgx/v5' golang sql package requires enabled '--experimental' flag")
}
}

return nil
}

func getConfigPath(stderr io.Writer, f *pflag.Flag) (string, string) {
if f != nil && f.Changed {
file := f.Value.String()
Expand Down
9 changes: 7 additions & 2 deletions internal/cmd/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,11 @@ func Generate(ctx context.Context, e Env, dir, filename string, stderr io.Writer
return nil, err
}

if err := e.Validate(conf); err != nil {
fmt.Fprintf(stderr, "error validating %s: %s\n", base, err)
return nil, err
}

output := map[string]string{}
errored := false

Expand Down Expand Up @@ -194,7 +199,7 @@ func Generate(ctx context.Context, e Env, dir, filename string, stderr io.Writer
trace.Logf(ctx, "", "name=%s dir=%s plugin=%s", name, dir, lang)
}

result, failed := parse(ctx, e, name, dir, sql.SQL, combo, parseOpts, stderr)
result, failed := parse(ctx, name, dir, sql.SQL, combo, parseOpts, stderr)
if failed {
if packageRegion != nil {
packageRegion.End()
Expand Down Expand Up @@ -233,7 +238,7 @@ func Generate(ctx context.Context, e Env, dir, filename string, stderr io.Writer
return output, nil
}

func parse(ctx context.Context, e Env, name, dir string, sql config.SQL, combo config.CombinedSettings, parserOpts opts.Parser, stderr io.Writer) (*compiler.Result, bool) {
func parse(ctx context.Context, name, dir string, sql config.SQL, combo config.CombinedSettings, parserOpts opts.Parser, stderr io.Writer) (*compiler.Result, bool) {
if debug.Traced {
defer trace.StartRegion(ctx, "parse").End()
}
Expand Down
35 changes: 28 additions & 7 deletions internal/codegen/golang/driver.go
Original file line number Diff line number Diff line change
@@ -1,20 +1,41 @@
package golang

import (
"github.com/kyleconroy/sqlc/internal/plugin"
)

type SQLDriver int

const (
SQLPackagePGXV4 string = "pgx/v4"
SQLPackagePGXV5 string = "pgx/v5"
SQLPackageStandard string = "database/sql"
)

const (
SQLDriverPGXV4 SQLDriver = iota
SQLDriverPGXV5
SQLDriverLibPQ
)

func parseDriver(settings *plugin.Settings) SQLDriver {
if settings.Go.SqlPackage == "pgx/v4" {
func parseDriver(sqlPackage string) SQLDriver {
switch sqlPackage {
case SQLPackagePGXV4:
return SQLDriverPGXV4
} else {
case SQLPackagePGXV5:
return SQLDriverPGXV5
default:
return SQLDriverLibPQ
}
}

func (d SQLDriver) IsPGX() bool {
return d == SQLDriverPGXV4 || d == SQLDriverPGXV5
}

func (d SQLDriver) Package() string {
switch d {
case SQLDriverPGXV4:
return SQLPackagePGXV4
case SQLDriverPGXV5:
return SQLPackagePGXV5
default:
return SQLPackageStandard
}
}
8 changes: 4 additions & 4 deletions internal/codegen/golang/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import (
type tmplCtx struct {
Q string
Package string
SQLPackage SQLPackage
SQLDriver SQLDriver
Enums []Enum
Structs []Struct
GoQueries []Query
Expand Down Expand Up @@ -91,7 +91,7 @@ func generate(req *plugin.CodeGenRequest, enums []Enum, structs []Struct, querie
EmitAllEnumValues: golang.EmitAllEnumValues,
UsesCopyFrom: usesCopyFrom(queries),
UsesBatch: usesBatch(queries),
SQLPackage: SQLPackageFromString(golang.SqlPackage),
SQLDriver: parseDriver(golang.SqlPackage),
Q: "`",
Package: golang.Package,
GoQueries: queries,
Expand All @@ -100,11 +100,11 @@ func generate(req *plugin.CodeGenRequest, enums []Enum, structs []Struct, querie
SqlcVersion: req.SqlcVersion,
}

if tctx.UsesCopyFrom && tctx.SQLPackage != SQLPackagePGX {
if tctx.UsesCopyFrom && !tctx.SQLDriver.IsPGX() {
return nil, errors.New(":copyfrom is only supported by pgx")
}

if tctx.UsesBatch && tctx.SQLPackage != SQLPackagePGX {
if tctx.UsesBatch && !tctx.SQLDriver.IsPGX() {
return nil, errors.New(":batch* commands are only supported by pgx")
}

Expand Down
3 changes: 3 additions & 0 deletions internal/codegen/golang/go_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ func goType(req *plugin.CodeGenRequest, col *plugin.Column) string {
}
typ := goInnerType(req, col)
if col.IsArray {
if parseDriver(req.Settings.Go.SqlPackage) == SQLDriverPGXV5 {
return "pgtype.Array[" + typ + "]"
}
return "[]" + typ
}
return typ
Expand Down
54 changes: 27 additions & 27 deletions internal/codegen/golang/imports.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ func (i *importer) Imports(filename string) [][]ImportSpec {
case copyfromFileName:
return mergeImports(i.copyfromImports())
case batchFileName:
return mergeImports(i.batchImports(filename))
return mergeImports(i.batchImports())
default:
return mergeImports(i.queryImports(filename))
}
Expand All @@ -114,11 +114,14 @@ func (i *importer) dbImports() fileImports {
{Path: "context"},
}

sqlpkg := SQLPackageFromString(i.Settings.Go.SqlPackage)
sqlpkg := parseDriver(i.Settings.Go.SqlPackage)
switch sqlpkg {
case SQLPackagePGX:
case SQLDriverPGXV4:
pkg = append(pkg, ImportSpec{Path: "github.com/jackc/pgconn"})
pkg = append(pkg, ImportSpec{Path: "github.com/jackc/pgx/v4"})
case SQLDriverPGXV5:
pkg = append(pkg, ImportSpec{Path: "github.com/jackc/pgx/v5/pgconn"})
pkg = append(pkg, ImportSpec{Path: "github.com/jackc/pgx/v5"})
default:
std = append(std, ImportSpec{Path: "database/sql"})
if i.Settings.Go.EmitPreparedQueries {
Expand All @@ -136,22 +139,8 @@ var stdlibTypes = map[string]string{
"time.Time": "time",
"net.IP": "net",
"net.HardwareAddr": "net",
}

var pgtypeTypes = map[string]struct{}{
"pgtype.CIDR": {},
"pgtype.Daterange": {},
"pgtype.Inet": {},
"pgtype.Int4range": {},
"pgtype.Int8range": {},
"pgtype.JSON": {},
"pgtype.JSONB": {},
"pgtype.Hstore": {},
"pgtype.Macaddr": {},
"pgtype.Numeric": {},
"pgtype.Numrange": {},
"pgtype.Tsrange": {},
"pgtype.Tstzrange": {},
"netip.Addr": "net/netip",
"netip.Prefix": "net/netip",
}

var pqtypeTypes = map[string]struct{}{
Expand All @@ -169,12 +158,14 @@ func buildImports(settings *plugin.Settings, queries []Query, uses func(string)
std["database/sql"] = struct{}{}
}

sqlpkg := SQLPackageFromString(settings.Go.SqlPackage)
sqlpkg := parseDriver(settings.Go.SqlPackage)
for _, q := range queries {
if q.Cmd == metadata.CmdExecResult {
switch sqlpkg {
case SQLPackagePGX:
case SQLDriverPGXV4:
pkg[ImportSpec{Path: "github.com/jackc/pgconn"}] = struct{}{}
case SQLDriverPGXV5:
pkg[ImportSpec{Path: "github.com/jackc/pgx/v5/pgconn"}] = struct{}{}
default:
std["database/sql"] = struct{}{}
}
Expand All @@ -187,15 +178,18 @@ func buildImports(settings *plugin.Settings, queries []Query, uses func(string)
}
}

for typeName, _ := range pgtypeTypes {
if uses(typeName) {
if uses("pgtype.") {
if sqlpkg == SQLDriverPGXV5 {
pkg[ImportSpec{Path: "github.com/jackc/pgx/v5/pgtype"}] = struct{}{}
} else {
pkg[ImportSpec{Path: "github.com/jackc/pgtype"}] = struct{}{}
}
}

for typeName, _ := range pqtypeTypes {
if uses(typeName) {
pkg[ImportSpec{Path: "github.com/tabbed/pqtype"}] = struct{}{}
break
}
}

Expand Down Expand Up @@ -373,8 +367,8 @@ func (i *importer) queryImports(filename string) fileImports {
std["context"] = struct{}{}
}

sqlpkg := SQLPackageFromString(i.Settings.Go.SqlPackage)
if sliceScan() && sqlpkg != SQLPackagePGX {
sqlpkg := parseDriver(i.Settings.Go.SqlPackage)
if sliceScan() && !sqlpkg.IsPGX() {
pkg[ImportSpec{Path: "github.com/lib/pq"}] = struct{}{}
}

Expand Down Expand Up @@ -409,7 +403,7 @@ func (i *importer) copyfromImports() fileImports {
return sortedImports(std, pkg)
}

func (i *importer) batchImports(filename string) fileImports {
func (i *importer) batchImports() fileImports {
batchQueries := make([]Query, 0, len(i.Queries))
for _, q := range i.Queries {
if usesBatch([]Query{q}) {
Expand Down Expand Up @@ -452,7 +446,13 @@ func (i *importer) batchImports(filename string) fileImports {

std["context"] = struct{}{}
std["errors"] = struct{}{}
pkg[ImportSpec{Path: "github.com/jackc/pgx/v4"}] = struct{}{}
sqlpkg := parseDriver(i.Settings.Go.SqlPackage)
switch sqlpkg {
case SQLDriverPGXV4:
pkg[ImportSpec{Path: "github.com/jackc/pgx/v4"}] = struct{}{}
case SQLDriverPGXV5:
pkg[ImportSpec{Path: "github.com/jackc/pgx/v5"}] = struct{}{}
}

return sortedImports(std, pkg)
}
Loading