diff --git a/examples/batch/postgresql/batch.go b/examples/batch/postgresql/batch.go index e07ad01c34..09b36019d0 100644 --- a/examples/batch/postgresql/batch.go +++ b/examples/batch/postgresql/batch.go @@ -7,6 +7,7 @@ package batch import ( "context" + "errors" "time" "github.com/jackc/pgx/v4" @@ -18,8 +19,9 @@ WHERE year = $1 ` type BooksByYearBatchResults struct { - br pgx.BatchResults - ind int + br pgx.BatchResults + tot int + closed bool } func (q *Queries) BooksByYear(ctx context.Context, year []int32) *BooksByYearBatchResults { @@ -31,42 +33,51 @@ func (q *Queries) BooksByYear(ctx context.Context, year []int32) *BooksByYearBat batch.Queue(booksByYear, vals...) } br := q.db.SendBatch(ctx, batch) - return &BooksByYearBatchResults{br, 0} + return &BooksByYearBatchResults{br, len(year), false} } func (b *BooksByYearBatchResults) Query(f func(int, []Book, error)) { - for { - rows, err := b.br.Query() - if err != nil && (err.Error() == "no result" || err.Error() == "batch already closed") { - break - } - defer rows.Close() + defer b.br.Close() + for t := 0; t < b.tot; t++ { var items []Book - for rows.Next() { - var i Book - if err := rows.Scan( - &i.BookID, - &i.AuthorID, - &i.Isbn, - &i.BookType, - &i.Title, - &i.Year, - &i.Available, - &i.Tags, - ); err != nil { - break + if b.closed { + if f != nil { + f(t, items, errors.New("batch already closed")) } - items = append(items, i) + continue } - + err := func() error { + rows, err := b.br.Query() + defer rows.Close() + if err != nil { + return err + } + for rows.Next() { + var i Book + if err := rows.Scan( + &i.BookID, + &i.AuthorID, + &i.Isbn, + &i.BookType, + &i.Title, + &i.Year, + &i.Available, + &i.Tags, + ); err != nil { + return err + } + items = append(items, i) + } + return rows.Err() + }() if f != nil { - f(b.ind, items, rows.Err()) + f(t, items, err) } - b.ind++ } } func (b *BooksByYearBatchResults) Close() error { + b.closed = true return b.br.Close() } @@ -92,8 +103,9 @@ RETURNING book_id, author_id, isbn, book_type, title, year, available, tags ` type CreateBookBatchResults struct { - br pgx.BatchResults - ind int + br pgx.BatchResults + tot int + closed bool } type CreateBookParams struct { @@ -121,13 +133,20 @@ func (q *Queries) CreateBook(ctx context.Context, arg []CreateBookParams) *Creat batch.Queue(createBook, vals...) } br := q.db.SendBatch(ctx, batch) - return &CreateBookBatchResults{br, 0} + return &CreateBookBatchResults{br, len(arg), false} } func (b *CreateBookBatchResults) QueryRow(f func(int, Book, error)) { - for { - row := b.br.QueryRow() + defer b.br.Close() + for t := 0; t < b.tot; t++ { var i Book + if b.closed { + if f != nil { + f(t, i, errors.New("batch already closed")) + } + continue + } + row := b.br.QueryRow() err := row.Scan( &i.BookID, &i.AuthorID, @@ -138,17 +157,14 @@ func (b *CreateBookBatchResults) QueryRow(f func(int, Book, error)) { &i.Available, &i.Tags, ) - if err != nil && (err.Error() == "no result" || err.Error() == "batch already closed") { - break - } if f != nil { - f(b.ind, i, err) + f(t, i, err) } - b.ind++ } } func (b *CreateBookBatchResults) Close() error { + b.closed = true return b.br.Close() } @@ -158,8 +174,9 @@ WHERE book_id = $1 ` type DeleteBookBatchResults struct { - br pgx.BatchResults - ind int + br pgx.BatchResults + tot int + closed bool } func (q *Queries) DeleteBook(ctx context.Context, bookID []int32) *DeleteBookBatchResults { @@ -171,23 +188,27 @@ func (q *Queries) DeleteBook(ctx context.Context, bookID []int32) *DeleteBookBat batch.Queue(deleteBook, vals...) } br := q.db.SendBatch(ctx, batch) - return &DeleteBookBatchResults{br, 0} + return &DeleteBookBatchResults{br, len(bookID), false} } func (b *DeleteBookBatchResults) Exec(f func(int, error)) { - for { - _, err := b.br.Exec() - if err != nil && (err.Error() == "no result" || err.Error() == "batch already closed") { - break + defer b.br.Close() + for t := 0; t < b.tot; t++ { + if b.closed { + if f != nil { + f(t, errors.New("batch already closed")) + } + continue } + _, err := b.br.Exec() if f != nil { - f(b.ind, err) + f(t, err) } - b.ind++ } } func (b *DeleteBookBatchResults) Close() error { + b.closed = true return b.br.Close() } @@ -198,8 +219,9 @@ WHERE book_id = $3 ` type UpdateBookBatchResults struct { - br pgx.BatchResults - ind int + br pgx.BatchResults + tot int + closed bool } type UpdateBookParams struct { @@ -219,22 +241,26 @@ func (q *Queries) UpdateBook(ctx context.Context, arg []UpdateBookParams) *Updat batch.Queue(updateBook, vals...) } br := q.db.SendBatch(ctx, batch) - return &UpdateBookBatchResults{br, 0} + return &UpdateBookBatchResults{br, len(arg), false} } func (b *UpdateBookBatchResults) Exec(f func(int, error)) { - for { - _, err := b.br.Exec() - if err != nil && (err.Error() == "no result" || err.Error() == "batch already closed") { - break + defer b.br.Close() + for t := 0; t < b.tot; t++ { + if b.closed { + if f != nil { + f(t, errors.New("batch already closed")) + } + continue } + _, err := b.br.Exec() if f != nil { - f(b.ind, err) + f(t, err) } - b.ind++ } } func (b *UpdateBookBatchResults) Close() error { + b.closed = true return b.br.Close() } diff --git a/examples/batch/postgresql/db_test.go b/examples/batch/postgresql/db_test.go index 6293831efe..4455cfdf33 100644 --- a/examples/batch/postgresql/db_test.go +++ b/examples/batch/postgresql/db_test.go @@ -122,19 +122,24 @@ func TestBatchBooks(t *testing.T) { } batchDelete := dq.DeleteBook(ctx, deleteBooksParams) numDeletesProcessed := 0 + wantNumDeletesProcessed := 2 batchDelete.Exec(func(i int, err error) { - numDeletesProcessed++ - if err != nil { + if err != nil && err.Error() != "batch already closed" { t.Fatalf("error deleting book %d: %s", deleteBooksParams[i], err) } - if i == len(deleteBooksParams)-3 { + + if err == nil { + numDeletesProcessed++ + } + + if i == wantNumDeletesProcessed-1 { // close batch operation before processing all errors from delete operation if err := batchDelete.Close(); err != nil { t.Fatalf("failed to close batch operation: %s", err) } } }) - if numDeletesProcessed != 2 { - t.Fatalf("expected Close to short-circuit record processing (expected 2; got %d)", numDeletesProcessed) + if numDeletesProcessed != wantNumDeletesProcessed { + t.Fatalf("expected Close to short-circuit record processing (expected %d; got %d)", wantNumDeletesProcessed, numDeletesProcessed) } } diff --git a/internal/codegen/golang/imports.go b/internal/codegen/golang/imports.go index 8da212ce99..f0e53ef6f0 100644 --- a/internal/codegen/golang/imports.go +++ b/internal/codegen/golang/imports.go @@ -445,6 +445,7 @@ func (i *importer) batchImports(filename string) fileImports { }) std["context"] = struct{}{} + std["errors"] = struct{}{} pkg[ImportSpec{Path: "github.com/jackc/pgx/v4"}] = struct{}{} return sortedImports(std, pkg) diff --git a/internal/codegen/golang/templates/pgx/batchCode.tmpl b/internal/codegen/golang/templates/pgx/batchCode.tmpl index 2bd8872304..9603696cd9 100644 --- a/internal/codegen/golang/templates/pgx/batchCode.tmpl +++ b/internal/codegen/golang/templates/pgx/batchCode.tmpl @@ -7,7 +7,8 @@ const {{.ConstantName}} = {{$.Q}}-- name: {{.MethodName}} {{.Cmd}} type {{.MethodName}}BatchResults struct { br pgx.BatchResults - ind int + tot int + closed bool } {{if .Arg.EmitStruct}} @@ -41,71 +42,86 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{ if $.EmitMethodsWithDB batch.Queue({{.ConstantName}}, vals...) } br := {{if not $.EmitMethodsWithDBArgument}}q.{{end}}db.SendBatch(ctx, batch) - return &{{.MethodName}}BatchResults{br,0} + return &{{.MethodName}}BatchResults{br,len({{.Arg.Name}}),false} } {{if eq .Cmd ":batchexec"}} func (b *{{.MethodName}}BatchResults) Exec(f func(int, error)) { - for { - _, err := b.br.Exec() - if err != nil && (err.Error() == "no result" || err.Error() == "batch already closed"){ - break - } - if f != nil { - f(b.ind, err) - } - b.ind++ - } + defer b.br.Close() + for t := 0; t < b.tot; t++ { + if b.closed { + if f != nil { + f(t, errors.New("batch already closed")) + } + continue + } + _, err := b.br.Exec() + if f != nil { + f(t, err) + } + } } {{end}} {{if eq .Cmd ":batchmany"}} func (b *{{.MethodName}}BatchResults) Query(f func(int, []{{.Ret.DefineType}}, error)) { - for { - rows, err := b.br.Query() - if err != nil && (err.Error() == "no result" || err.Error() == "batch already closed") { - break - } - defer rows.Close() - {{- if $.EmitEmptySlices}} - items := []{{.Ret.DefineType}}{} - {{else}} - var items []{{.Ret.DefineType}} - {{end -}} - for rows.Next() { - var {{.Ret.Name}} {{.Ret.Type}} - if err := rows.Scan({{.Ret.Scan}}); err != nil { - break - } - items = append(items, {{.Ret.ReturnName}}) - } - + defer b.br.Close() + for t := 0; t < b.tot; t++ { + {{- if $.EmitEmptySlices}} + items := []{{.Ret.DefineType}}{} + {{else}} + var items []{{.Ret.DefineType}} + {{end -}} + if b.closed { if f != nil { - f(b.ind, items, rows.Err()) + f(t, items, errors.New("batch already closed")) + } + continue + } + err := func() error { + rows, err := b.br.Query() + defer rows.Close() + if err != nil { + return err + } + for rows.Next() { + var {{.Ret.Name}} {{.Ret.Type}} + if err := rows.Scan({{.Ret.Scan}}); err != nil { + return err + } + items = append(items, {{.Ret.ReturnName}}) } - b.ind++ - } + return rows.Err() + }() + if f != nil { + f(t, items, err) + } + } } {{end}} {{if eq .Cmd ":batchone"}} func (b *{{.MethodName}}BatchResults) QueryRow(f func(int, {{.Ret.DefineType}}, error)) { - for { - row := b.br.QueryRow() - var {{.Ret.Name}} {{.Ret.Type}} - err := row.Scan({{.Ret.Scan}}) - if err != nil && (err.Error() == "no result" || err.Error() == "batch already closed") { - break - } + defer b.br.Close() + for t := 0; t < b.tot; t++ { + var {{.Ret.Name}} {{.Ret.Type}} + if b.closed { if f != nil { - f(b.ind, {{.Ret.ReturnName}}, err) + f(t, {{.Ret.Name}}, errors.New("batch already closed")) } - b.ind++ - } + continue + } + row := b.br.QueryRow() + err := row.Scan({{.Ret.Scan}}) + if f != nil { + f(t, {{.Ret.ReturnName}}, err) + } + } } {{end}} func (b *{{.MethodName}}BatchResults) Close() error { + b.closed = true return b.br.Close() } {{end}} diff --git a/internal/endtoend/testdata/batch/postgresql/pgx/go/batch.go b/internal/endtoend/testdata/batch/postgresql/pgx/go/batch.go index fbd0044c3c..8676ea012b 100644 --- a/internal/endtoend/testdata/batch/postgresql/pgx/go/batch.go +++ b/internal/endtoend/testdata/batch/postgresql/pgx/go/batch.go @@ -8,6 +8,7 @@ package querytest import ( "context" "database/sql" + "errors" "github.com/jackc/pgx/v4" ) @@ -19,8 +20,9 @@ WHERE b = $1 ` type GetValuesBatchResults struct { - br pgx.BatchResults - ind int + br pgx.BatchResults + tot int + closed bool } func (q *Queries) GetValues(ctx context.Context, b []sql.NullInt32) *GetValuesBatchResults { @@ -32,33 +34,42 @@ func (q *Queries) GetValues(ctx context.Context, b []sql.NullInt32) *GetValuesBa batch.Queue(getValues, vals...) } br := q.db.SendBatch(ctx, batch) - return &GetValuesBatchResults{br, 0} + return &GetValuesBatchResults{br, len(b), false} } func (b *GetValuesBatchResults) Query(f func(int, []MyschemaFoo, error)) { - for { - rows, err := b.br.Query() - if err != nil && (err.Error() == "no result" || err.Error() == "batch already closed") { - break - } - defer rows.Close() + defer b.br.Close() + for t := 0; t < b.tot; t++ { var items []MyschemaFoo - for rows.Next() { - var i MyschemaFoo - if err := rows.Scan(&i.A, &i.B); err != nil { - break + if b.closed { + if f != nil { + f(t, items, errors.New("batch already closed")) } - items = append(items, i) + continue } - + err := func() error { + rows, err := b.br.Query() + defer rows.Close() + if err != nil { + return err + } + for rows.Next() { + var i MyschemaFoo + if err := rows.Scan(&i.A, &i.B); err != nil { + return err + } + items = append(items, i) + } + return rows.Err() + }() if f != nil { - f(b.ind, items, rows.Err()) + f(t, items, err) } - b.ind++ } } func (b *GetValuesBatchResults) Close() error { + b.closed = true return b.br.Close() } @@ -69,8 +80,9 @@ RETURNING a ` type InsertValuesBatchResults struct { - br pgx.BatchResults - ind int + br pgx.BatchResults + tot int + closed bool } type InsertValuesParams struct { @@ -88,25 +100,29 @@ func (q *Queries) InsertValues(ctx context.Context, arg []InsertValuesParams) *I batch.Queue(insertValues, vals...) } br := q.db.SendBatch(ctx, batch) - return &InsertValuesBatchResults{br, 0} + return &InsertValuesBatchResults{br, len(arg), false} } func (b *InsertValuesBatchResults) QueryRow(f func(int, sql.NullString, error)) { - for { - row := b.br.QueryRow() + defer b.br.Close() + for t := 0; t < b.tot; t++ { var a sql.NullString - err := row.Scan(&a) - if err != nil && (err.Error() == "no result" || err.Error() == "batch already closed") { - break + if b.closed { + if f != nil { + f(t, a, errors.New("batch already closed")) + } + continue } + row := b.br.QueryRow() + err := row.Scan(&a) if f != nil { - f(b.ind, a, err) + f(t, a, err) } - b.ind++ } } func (b *InsertValuesBatchResults) Close() error { + b.closed = true return b.br.Close() } @@ -115,8 +131,9 @@ UPDATE myschema.foo SET a = $1, b = $2 ` type UpdateValuesBatchResults struct { - br pgx.BatchResults - ind int + br pgx.BatchResults + tot int + closed bool } type UpdateValuesParams struct { @@ -134,22 +151,26 @@ func (q *Queries) UpdateValues(ctx context.Context, arg []UpdateValuesParams) *U batch.Queue(updateValues, vals...) } br := q.db.SendBatch(ctx, batch) - return &UpdateValuesBatchResults{br, 0} + return &UpdateValuesBatchResults{br, len(arg), false} } func (b *UpdateValuesBatchResults) Exec(f func(int, error)) { - for { - _, err := b.br.Exec() - if err != nil && (err.Error() == "no result" || err.Error() == "batch already closed") { - break + defer b.br.Close() + for t := 0; t < b.tot; t++ { + if b.closed { + if f != nil { + f(t, errors.New("batch already closed")) + } + continue } + _, err := b.br.Exec() if f != nil { - f(b.ind, err) + f(t, err) } - b.ind++ } } func (b *UpdateValuesBatchResults) Close() error { + b.closed = true return b.br.Close() } diff --git a/internal/endtoend/testdata/batch_imports/postgresql/pgx/go/batch.go b/internal/endtoend/testdata/batch_imports/postgresql/pgx/go/batch.go index 797601af60..c868b91414 100644 --- a/internal/endtoend/testdata/batch_imports/postgresql/pgx/go/batch.go +++ b/internal/endtoend/testdata/batch_imports/postgresql/pgx/go/batch.go @@ -8,6 +8,7 @@ package querytest import ( "context" "database/sql" + "errors" "github.com/jackc/pgx/v4" ) @@ -19,8 +20,9 @@ WHERE b = $1 ` type GetValuesBatchResults struct { - br pgx.BatchResults - ind int + br pgx.BatchResults + tot int + closed bool } func (q *Queries) GetValues(ctx context.Context, b []sql.NullInt32) *GetValuesBatchResults { @@ -32,33 +34,42 @@ func (q *Queries) GetValues(ctx context.Context, b []sql.NullInt32) *GetValuesBa batch.Queue(getValues, vals...) } br := q.db.SendBatch(ctx, batch) - return &GetValuesBatchResults{br, 0} + return &GetValuesBatchResults{br, len(b), false} } func (b *GetValuesBatchResults) Query(f func(int, []MyschemaFoo, error)) { - for { - rows, err := b.br.Query() - if err != nil && (err.Error() == "no result" || err.Error() == "batch already closed") { - break - } - defer rows.Close() + defer b.br.Close() + for t := 0; t < b.tot; t++ { var items []MyschemaFoo - for rows.Next() { - var i MyschemaFoo - if err := rows.Scan(&i.A, &i.B); err != nil { - break + if b.closed { + if f != nil { + f(t, items, errors.New("batch already closed")) } - items = append(items, i) + continue } - + err := func() error { + rows, err := b.br.Query() + defer rows.Close() + if err != nil { + return err + } + for rows.Next() { + var i MyschemaFoo + if err := rows.Scan(&i.A, &i.B); err != nil { + return err + } + items = append(items, i) + } + return rows.Err() + }() if f != nil { - f(b.ind, items, rows.Err()) + f(t, items, err) } - b.ind++ } } func (b *GetValuesBatchResults) Close() error { + b.closed = true return b.br.Close() } @@ -69,8 +80,9 @@ RETURNING a ` type InsertValuesBatchResults struct { - br pgx.BatchResults - ind int + br pgx.BatchResults + tot int + closed bool } type InsertValuesParams struct { @@ -88,24 +100,28 @@ func (q *Queries) InsertValues(ctx context.Context, arg []InsertValuesParams) *I batch.Queue(insertValues, vals...) } br := q.db.SendBatch(ctx, batch) - return &InsertValuesBatchResults{br, 0} + return &InsertValuesBatchResults{br, len(arg), false} } func (b *InsertValuesBatchResults) QueryRow(f func(int, sql.NullString, error)) { - for { - row := b.br.QueryRow() + defer b.br.Close() + for t := 0; t < b.tot; t++ { var a sql.NullString - err := row.Scan(&a) - if err != nil && (err.Error() == "no result" || err.Error() == "batch already closed") { - break + if b.closed { + if f != nil { + f(t, a, errors.New("batch already closed")) + } + continue } + row := b.br.QueryRow() + err := row.Scan(&a) if f != nil { - f(b.ind, a, err) + f(t, a, err) } - b.ind++ } } func (b *InsertValuesBatchResults) Close() error { + b.closed = true return b.br.Close() }