From cb114d311ba30ce45c875995d6e6e291e3d3a7cd Mon Sep 17 00:00:00 2001 From: Jordan Pittier Date: Thu, 30 Mar 2023 11:57:02 +0200 Subject: [PATCH 1/2] Customizable batch output file name (add OutputBatchFileName field) This commit adds the possibility to customize the batch output file name. Example configuration: ``` version: "1" packages: - name: db path: internal/db queries: internal/db schema: migrations engine: postgresql output_batch_file_name: batch_gen.go output_models_file_name: model_gen.go ``` --- docs/reference/config.md | 5 ++ internal/cmd/shim.go | 1 + internal/codegen/golang/gen.go | 3 ++ internal/codegen/golang/imports.go | 3 ++ internal/config/config.go | 1 + internal/config/v_one.go | 2 + .../testdata/codegen_json/gen/codegen.json | 3 +- .../output_file_names/pgx/v4/sqlc.json | 1 + .../output_file_names/pgx/v5/sqlc.json | 1 + .../output_file_names/stdlib/sqlc.json | 1 + internal/plugin/codegen.pb.go | 8 ++++ internal/plugin/codegen_vtproto.pb.go | 48 +++++++++++++++++++ protos/plugin/codegen.proto | 1 + 13 files changed, 77 insertions(+), 1 deletion(-) diff --git a/docs/reference/config.md b/docs/reference/config.md index 2d663709ec..32d2955ec2 100644 --- a/docs/reference/config.md +++ b/docs/reference/config.md @@ -113,6 +113,8 @@ The `gen` mapping supports the following keys: that returns all valid enum values. - `json_tags_case_style`: - `camel` for camelCase, `pascal` for PascalCase, `snake` for snake_case or `none` to use the column name in the DB. Defaults to `none`. +- `output_batch_file_name`: + - Customize the name of the batch file. Defaults to `batch.go`. - `output_db_file_name`: - Customize the name of the db file. Defaults to `db.go`. - `output_models_file_name`: @@ -343,6 +345,7 @@ packages: emit_enum_valid_method: false emit_all_enum_values: false json_tags_case_style: "camel" + output_batch_file_name: "batch.go" output_db_file_name: "db.go" output_models_file_name: "models.go" output_querier_file_name: "querier.go" @@ -394,6 +397,8 @@ Each mapping in the `packages` collection has the following keys: that returns all valid enum values. - `json_tags_case_style`: - `camel` for camelCase, `pascal` for PascalCase, `snake` for snake_case or `none` to use the column name in the DB. Defaults to `none`. +- `output_batch_file_name`: + - Customize the name of the batch file. Defaults to `batch.go`. - `output_db_file_name`: - Customize the name of the db file. Defaults to `db.go`. - `output_models_file_name`: diff --git a/internal/cmd/shim.go b/internal/cmd/shim.go index 66fedd322c..49a7718f8b 100644 --- a/internal/cmd/shim.go +++ b/internal/cmd/shim.go @@ -94,6 +94,7 @@ func pluginGoCode(s config.SQLGo) *plugin.GoCode { Out: s.Out, SqlPackage: s.SQLPackage, OutputDbFileName: s.OutputDBFileName, + OutputBatchFileName: s.OutputBatchFileName, OutputModelsFileName: s.OutputModelsFileName, OutputQuerierFileName: s.OutputQuerierFileName, OutputFilesSuffix: s.OutputFilesSuffix, diff --git a/internal/codegen/golang/gen.go b/internal/codegen/golang/gen.go index dd802f0150..f148abe0e7 100644 --- a/internal/codegen/golang/gen.go +++ b/internal/codegen/golang/gen.go @@ -155,6 +155,9 @@ func generate(req *plugin.CodeGenRequest, enums []Enum, structs []Struct, querie // TODO(Jille): Make this configurable. batchFileName := "batch.go" + if golang.OutputBatchFileName != "" { + batchFileName = golang.OutputBatchFileName + } if err := execute(dbFileName, "dbFile"); err != nil { return nil, err diff --git a/internal/codegen/golang/imports.go b/internal/codegen/golang/imports.go index 6da4b9c5e5..209d4724d0 100644 --- a/internal/codegen/golang/imports.go +++ b/internal/codegen/golang/imports.go @@ -91,6 +91,9 @@ func (i *importer) Imports(filename string) [][]ImportSpec { } copyfromFileName := "copyfrom.go" batchFileName := "batch.go" + if i.Settings.Go.OutputBatchFileName != "" { + batchFileName = i.Settings.Go.OutputBatchFileName + } switch filename { case dbFileName: diff --git a/internal/config/config.go b/internal/config/config.go index 0dfd57bcc9..9cfc5d0402 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -128,6 +128,7 @@ type SQLGo struct { Overrides []Override `json:"overrides,omitempty" yaml:"overrides"` Rename map[string]string `json:"rename,omitempty" yaml:"rename"` SQLPackage string `json:"sql_package" yaml:"sql_package"` + OutputBatchFileName string `json:"output_batch_file_name,omitempty" yaml:"output_batch_file_name"` OutputDBFileName string `json:"output_db_file_name,omitempty" yaml:"output_db_file_name"` OutputModelsFileName string `json:"output_models_file_name,omitempty" yaml:"output_models_file_name"` OutputQuerierFileName string `json:"output_querier_file_name,omitempty" yaml:"output_querier_file_name"` diff --git a/internal/config/v_one.go b/internal/config/v_one.go index b44ff659e4..ca4165d62e 100644 --- a/internal/config/v_one.go +++ b/internal/config/v_one.go @@ -38,6 +38,7 @@ type v1PackageSettings struct { JSONTagsCaseStyle string `json:"json_tags_case_style,omitempty" yaml:"json_tags_case_style"` SQLPackage string `json:"sql_package" yaml:"sql_package"` Overrides []Override `json:"overrides" yaml:"overrides"` + OutputBatchFileName string `json:"output_batch_file_name,omitempty" yaml:"output_batch_file_name"` OutputDBFileName string `json:"output_db_file_name,omitempty" yaml:"output_db_file_name"` OutputModelsFileName string `json:"output_models_file_name,omitempty" yaml:"output_models_file_name"` OutputQuerierFileName string `json:"output_querier_file_name,omitempty" yaml:"output_querier_file_name"` @@ -139,6 +140,7 @@ func (c *V1GenerateSettings) Translate() Config { SQLPackage: pkg.SQLPackage, Overrides: pkg.Overrides, JSONTagsCaseStyle: pkg.JSONTagsCaseStyle, + OutputBatchFileName: pkg.OutputBatchFileName, OutputDBFileName: pkg.OutputDBFileName, OutputModelsFileName: pkg.OutputModelsFileName, OutputQuerierFileName: pkg.OutputQuerierFileName, diff --git a/internal/endtoend/testdata/codegen_json/gen/codegen.json b/internal/endtoend/testdata/codegen_json/gen/codegen.json index 4018d68a4b..d5d1481c6b 100644 --- a/internal/endtoend/testdata/codegen_json/gen/codegen.json +++ b/internal/endtoend/testdata/codegen_json/gen/codegen.json @@ -37,7 +37,8 @@ "emit_enum_valid_method": false, "emit_all_enum_values": false, "inflection_exclude_table_names": [], - "emit_pointers_for_null_types": false + "emit_pointers_for_null_types": false, + "output_batch_file_name": "" }, "json": { "out": "gen", diff --git a/internal/endtoend/testdata/output_file_names/pgx/v4/sqlc.json b/internal/endtoend/testdata/output_file_names/pgx/v4/sqlc.json index ecf9722960..3926a6d114 100644 --- a/internal/endtoend/testdata/output_file_names/pgx/v4/sqlc.json +++ b/internal/endtoend/testdata/output_file_names/pgx/v4/sqlc.json @@ -9,6 +9,7 @@ "schema": "query.sql", "queries": "query.sql", "emit_interface": true, + "output_batch_file_name": "batch_gen.go", "output_db_file_name": "db_gen.go", "output_models_file_name": "models_gen.go", "output_querier_file_name": "querier_gen.go" diff --git a/internal/endtoend/testdata/output_file_names/pgx/v5/sqlc.json b/internal/endtoend/testdata/output_file_names/pgx/v5/sqlc.json index 0819626d87..105a8a2ef4 100644 --- a/internal/endtoend/testdata/output_file_names/pgx/v5/sqlc.json +++ b/internal/endtoend/testdata/output_file_names/pgx/v5/sqlc.json @@ -9,6 +9,7 @@ "schema": "query.sql", "queries": "query.sql", "emit_interface": true, + "output_batch_file_name": "batch_gen.go", "output_db_file_name": "db_gen.go", "output_models_file_name": "models_gen.go", "output_querier_file_name": "querier_gen.go" diff --git a/internal/endtoend/testdata/output_file_names/stdlib/sqlc.json b/internal/endtoend/testdata/output_file_names/stdlib/sqlc.json index 822f727f54..ee4a7d651f 100644 --- a/internal/endtoend/testdata/output_file_names/stdlib/sqlc.json +++ b/internal/endtoend/testdata/output_file_names/stdlib/sqlc.json @@ -7,6 +7,7 @@ "schema": "query.sql", "queries": "query.sql", "emit_interface": true, + "output_batch_file_name": "batch_gen.go", "output_db_file_name": "db_gen.go", "output_models_file_name": "models_gen.go", "output_querier_file_name": "querier_gen.go" diff --git a/internal/plugin/codegen.pb.go b/internal/plugin/codegen.pb.go index c443da7210..414e602cc1 100644 --- a/internal/plugin/codegen.pb.go +++ b/internal/plugin/codegen.pb.go @@ -455,6 +455,7 @@ type GoCode struct { EmitAllEnumValues bool `protobuf:"varint,20,opt,name=emit_all_enum_values,json=emitAllEnumValues,proto3" json:"emit_all_enum_values,omitempty"` InflectionExcludeTableNames []string `protobuf:"bytes,21,rep,name=inflection_exclude_table_names,json=inflectionExcludeTableNames,proto3" json:"inflection_exclude_table_names,omitempty"` EmitPointersForNullTypes bool `protobuf:"varint,22,opt,name=emit_pointers_for_null_types,json=emitPointersForNullTypes,proto3" json:"emit_pointers_for_null_types,omitempty"` + OutputBatchFileName string `protobuf:"bytes,23,opt,name=output_batch_file_name,json=outputBatchFileName,proto3" json:"output_batch_file_name,omitempty"` } func (x *GoCode) Reset() { @@ -643,6 +644,13 @@ func (x *GoCode) GetEmitPointersForNullTypes() bool { return false } +func (x *GoCode) GetOutputBatchFileName() string { + if x != nil { + return x.OutputBatchFileName + } + return "" +} + type JSONCode struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache diff --git a/internal/plugin/codegen_vtproto.pb.go b/internal/plugin/codegen_vtproto.pb.go index 736d61ffd9..1d1d2d9a00 100644 --- a/internal/plugin/codegen_vtproto.pb.go +++ b/internal/plugin/codegen_vtproto.pb.go @@ -243,6 +243,9 @@ func (this *GoCode) EqualVT(that *GoCode) bool { if this.EmitPointersForNullTypes != that.EmitPointersForNullTypes { return false } + if this.OutputBatchFileName != that.OutputBatchFileName { + return false + } return string(this.unknownFields) == string(that.unknownFields) } @@ -990,6 +993,15 @@ func (m *GoCode) MarshalToSizedBufferVT(dAtA []byte) (int, error) { i -= len(m.unknownFields) copy(dAtA[i:], m.unknownFields) } + if len(m.OutputBatchFileName) > 0 { + i -= len(m.OutputBatchFileName) + copy(dAtA[i:], m.OutputBatchFileName) + i = encodeVarint(dAtA, i, uint64(len(m.OutputBatchFileName))) + i-- + dAtA[i] = 0x1 + i-- + dAtA[i] = 0xba + } if m.EmitPointersForNullTypes { i-- if m.EmitPointersForNullTypes { @@ -2296,6 +2308,10 @@ func (m *GoCode) SizeVT() (n int) { if m.EmitPointersForNullTypes { n += 3 } + l = len(m.OutputBatchFileName) + if l > 0 { + n += 2 + l + sov(uint64(l)) + } if m.unknownFields != nil { n += len(m.unknownFields) } @@ -4524,6 +4540,38 @@ func (m *GoCode) UnmarshalVT(dAtA []byte) error { } } m.EmitPointersForNullTypes = bool(v != 0) + case 23: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field OutputBatchFileName", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLength + } + postIndex := iNdEx + intStringLen + if postIndex < 0 { + return ErrInvalidLength + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.OutputBatchFileName = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex default: iNdEx = preIndex skippy, err := skip(dAtA[iNdEx:]) diff --git a/protos/plugin/codegen.proto b/protos/plugin/codegen.proto index de484587ad..8d7ca43806 100644 --- a/protos/plugin/codegen.proto +++ b/protos/plugin/codegen.proto @@ -92,6 +92,7 @@ message GoCode bool emit_all_enum_values = 20; repeated string inflection_exclude_table_names = 21; bool emit_pointers_for_null_types = 22; + string output_batch_file_name = 23; } message JSONCode From c5a661e998bdc0ee64424c89b1f6a609c186fb3f Mon Sep 17 00:00:00 2001 From: Kyle Conroy Date: Fri, 7 Apr 2023 13:26:25 -0700 Subject: [PATCH 2/2] test: Add tests for new batch filename config --- .../output_file_names/pgx/v4/go/batch_gen.go | 72 +++++++++++++++++++ .../output_file_names/pgx/v4/go/db_gen.go | 1 + .../pgx/v4/go/querier_gen.go | 1 + .../output_file_names/pgx/v4/query.sql | 4 ++ .../output_file_names/pgx/v5/go/batch_gen.go | 72 +++++++++++++++++++ .../output_file_names/pgx/v5/go/db_gen.go | 1 + .../pgx/v5/go/querier_gen.go | 1 + .../output_file_names/pgx/v5/query.sql | 4 ++ .../output_file_names/stdlib/sqlc.json | 1 - 9 files changed, 156 insertions(+), 1 deletion(-) create mode 100644 internal/endtoend/testdata/output_file_names/pgx/v4/go/batch_gen.go create mode 100644 internal/endtoend/testdata/output_file_names/pgx/v5/go/batch_gen.go diff --git a/internal/endtoend/testdata/output_file_names/pgx/v4/go/batch_gen.go b/internal/endtoend/testdata/output_file_names/pgx/v4/go/batch_gen.go new file mode 100644 index 0000000000..8ce6e2378c --- /dev/null +++ b/internal/endtoend/testdata/output_file_names/pgx/v4/go/batch_gen.go @@ -0,0 +1,72 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.17.2 +// source: batch_gen.go + +package querytest + +import ( + "context" + "errors" + + "github.com/jackc/pgx/v4" +) + +const usersB = `-- name: UsersB :batchmany +SELECT id FROM "user" +WHERE id = $1 +` + +type UsersBBatchResults struct { + br pgx.BatchResults + tot int + closed bool +} + +func (q *Queries) UsersB(ctx context.Context, id []int64) *UsersBBatchResults { + batch := &pgx.Batch{} + for _, a := range id { + vals := []interface{}{ + a, + } + batch.Queue(usersB, vals...) + } + br := q.db.SendBatch(ctx, batch) + return &UsersBBatchResults{br, len(id), false} +} + +func (b *UsersBBatchResults) Query(f func(int, []int64, error)) { + defer b.br.Close() + for t := 0; t < b.tot; t++ { + var items []int64 + if b.closed { + if f != nil { + f(t, items, errors.New("batch already closed")) + } + continue + } + err := func() error { + rows, err := b.br.Query() + defer rows.Close() + if err != nil { + return err + } + for rows.Next() { + var id int64 + if err := rows.Scan(&id); err != nil { + return err + } + items = append(items, id) + } + return rows.Err() + }() + if f != nil { + f(t, items, err) + } + } +} + +func (b *UsersBBatchResults) Close() error { + b.closed = true + return b.br.Close() +} diff --git a/internal/endtoend/testdata/output_file_names/pgx/v4/go/db_gen.go b/internal/endtoend/testdata/output_file_names/pgx/v4/go/db_gen.go index 439db75e69..f5ba3be7f8 100644 --- a/internal/endtoend/testdata/output_file_names/pgx/v4/go/db_gen.go +++ b/internal/endtoend/testdata/output_file_names/pgx/v4/go/db_gen.go @@ -15,6 +15,7 @@ type DBTX interface { Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error) Query(context.Context, string, ...interface{}) (pgx.Rows, error) QueryRow(context.Context, string, ...interface{}) pgx.Row + SendBatch(context.Context, *pgx.Batch) pgx.BatchResults } func New(db DBTX) *Queries { diff --git a/internal/endtoend/testdata/output_file_names/pgx/v4/go/querier_gen.go b/internal/endtoend/testdata/output_file_names/pgx/v4/go/querier_gen.go index b30853bfc6..5eebe5018d 100644 --- a/internal/endtoend/testdata/output_file_names/pgx/v4/go/querier_gen.go +++ b/internal/endtoend/testdata/output_file_names/pgx/v4/go/querier_gen.go @@ -10,6 +10,7 @@ import ( type Querier interface { User(ctx context.Context) ([]int64, error) + UsersB(ctx context.Context, id []int64) *UsersBBatchResults } var _ Querier = (*Queries)(nil) diff --git a/internal/endtoend/testdata/output_file_names/pgx/v4/query.sql b/internal/endtoend/testdata/output_file_names/pgx/v4/query.sql index 3191419956..ef6ee90544 100644 --- a/internal/endtoend/testdata/output_file_names/pgx/v4/query.sql +++ b/internal/endtoend/testdata/output_file_names/pgx/v4/query.sql @@ -2,3 +2,7 @@ CREATE TABLE "user" (id bigserial not null); -- name: User :many SELECT "user".* FROM "user"; + +-- name: UsersB :batchmany +SELECT * FROM "user" +WHERE id = $1; diff --git a/internal/endtoend/testdata/output_file_names/pgx/v5/go/batch_gen.go b/internal/endtoend/testdata/output_file_names/pgx/v5/go/batch_gen.go new file mode 100644 index 0000000000..f09bc675cd --- /dev/null +++ b/internal/endtoend/testdata/output_file_names/pgx/v5/go/batch_gen.go @@ -0,0 +1,72 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.17.2 +// source: batch_gen.go + +package querytest + +import ( + "context" + "errors" + + "github.com/jackc/pgx/v5" +) + +const usersB = `-- name: UsersB :batchmany +SELECT id FROM "user" +WHERE id = $1 +` + +type UsersBBatchResults struct { + br pgx.BatchResults + tot int + closed bool +} + +func (q *Queries) UsersB(ctx context.Context, id []int64) *UsersBBatchResults { + batch := &pgx.Batch{} + for _, a := range id { + vals := []interface{}{ + a, + } + batch.Queue(usersB, vals...) + } + br := q.db.SendBatch(ctx, batch) + return &UsersBBatchResults{br, len(id), false} +} + +func (b *UsersBBatchResults) Query(f func(int, []int64, error)) { + defer b.br.Close() + for t := 0; t < b.tot; t++ { + var items []int64 + if b.closed { + if f != nil { + f(t, items, errors.New("batch already closed")) + } + continue + } + err := func() error { + rows, err := b.br.Query() + defer rows.Close() + if err != nil { + return err + } + for rows.Next() { + var id int64 + if err := rows.Scan(&id); err != nil { + return err + } + items = append(items, id) + } + return rows.Err() + }() + if f != nil { + f(t, items, err) + } + } +} + +func (b *UsersBBatchResults) Close() error { + b.closed = true + return b.br.Close() +} diff --git a/internal/endtoend/testdata/output_file_names/pgx/v5/go/db_gen.go b/internal/endtoend/testdata/output_file_names/pgx/v5/go/db_gen.go index a84825e0e1..08e71d89dc 100644 --- a/internal/endtoend/testdata/output_file_names/pgx/v5/go/db_gen.go +++ b/internal/endtoend/testdata/output_file_names/pgx/v5/go/db_gen.go @@ -15,6 +15,7 @@ type DBTX interface { Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error) Query(context.Context, string, ...interface{}) (pgx.Rows, error) QueryRow(context.Context, string, ...interface{}) pgx.Row + SendBatch(context.Context, *pgx.Batch) pgx.BatchResults } func New(db DBTX) *Queries { diff --git a/internal/endtoend/testdata/output_file_names/pgx/v5/go/querier_gen.go b/internal/endtoend/testdata/output_file_names/pgx/v5/go/querier_gen.go index b30853bfc6..5eebe5018d 100644 --- a/internal/endtoend/testdata/output_file_names/pgx/v5/go/querier_gen.go +++ b/internal/endtoend/testdata/output_file_names/pgx/v5/go/querier_gen.go @@ -10,6 +10,7 @@ import ( type Querier interface { User(ctx context.Context) ([]int64, error) + UsersB(ctx context.Context, id []int64) *UsersBBatchResults } var _ Querier = (*Queries)(nil) diff --git a/internal/endtoend/testdata/output_file_names/pgx/v5/query.sql b/internal/endtoend/testdata/output_file_names/pgx/v5/query.sql index 3191419956..ef6ee90544 100644 --- a/internal/endtoend/testdata/output_file_names/pgx/v5/query.sql +++ b/internal/endtoend/testdata/output_file_names/pgx/v5/query.sql @@ -2,3 +2,7 @@ CREATE TABLE "user" (id bigserial not null); -- name: User :many SELECT "user".* FROM "user"; + +-- name: UsersB :batchmany +SELECT * FROM "user" +WHERE id = $1; diff --git a/internal/endtoend/testdata/output_file_names/stdlib/sqlc.json b/internal/endtoend/testdata/output_file_names/stdlib/sqlc.json index ee4a7d651f..822f727f54 100644 --- a/internal/endtoend/testdata/output_file_names/stdlib/sqlc.json +++ b/internal/endtoend/testdata/output_file_names/stdlib/sqlc.json @@ -7,7 +7,6 @@ "schema": "query.sql", "queries": "query.sql", "emit_interface": true, - "output_batch_file_name": "batch_gen.go", "output_db_file_name": "db_gen.go", "output_models_file_name": "models_gen.go", "output_querier_file_name": "querier_gen.go"