From 74882733cbb889aa562d27d4ff0768a6900d13ee Mon Sep 17 00:00:00 2001 From: Chris Lee Date: Thu, 11 Aug 2022 13:58:23 -0700 Subject: [PATCH 1/3] fix: prevent batch infinite loop with arg length --- examples/batch/postgresql/batch.go | 24 +++++++++++++++---- .../golang/templates/pgx/batchCode.tmpl | 16 ++++++++++--- .../testdata/batch/postgresql/pgx/go/batch.go | 18 +++++++++++--- .../batch_imports/postgresql/pgx/go/batch.go | 12 ++++++++-- 4 files changed, 58 insertions(+), 12 deletions(-) diff --git a/examples/batch/postgresql/batch.go b/examples/batch/postgresql/batch.go index e07ad01c34..73002d7390 100644 --- a/examples/batch/postgresql/batch.go +++ b/examples/batch/postgresql/batch.go @@ -20,6 +20,7 @@ WHERE year = $1 type BooksByYearBatchResults struct { br pgx.BatchResults ind int + tot int } func (q *Queries) BooksByYear(ctx context.Context, year []int32) *BooksByYearBatchResults { @@ -31,11 +32,14 @@ func (q *Queries) BooksByYear(ctx context.Context, year []int32) *BooksByYearBat batch.Queue(booksByYear, vals...) } br := q.db.SendBatch(ctx, batch) - return &BooksByYearBatchResults{br, 0} + return &BooksByYearBatchResults{br, 0, len(year)} } func (b *BooksByYearBatchResults) Query(f func(int, []Book, error)) { for { + if b.ind >= b.tot { + break + } rows, err := b.br.Query() if err != nil && (err.Error() == "no result" || err.Error() == "batch already closed") { break @@ -94,6 +98,7 @@ RETURNING book_id, author_id, isbn, book_type, title, year, available, tags type CreateBookBatchResults struct { br pgx.BatchResults ind int + tot int } type CreateBookParams struct { @@ -121,11 +126,14 @@ func (q *Queries) CreateBook(ctx context.Context, arg []CreateBookParams) *Creat batch.Queue(createBook, vals...) } br := q.db.SendBatch(ctx, batch) - return &CreateBookBatchResults{br, 0} + return &CreateBookBatchResults{br, 0, len(arg)} } func (b *CreateBookBatchResults) QueryRow(f func(int, Book, error)) { for { + if b.ind >= b.tot { + break + } row := b.br.QueryRow() var i Book err := row.Scan( @@ -160,6 +168,7 @@ WHERE book_id = $1 type DeleteBookBatchResults struct { br pgx.BatchResults ind int + tot int } func (q *Queries) DeleteBook(ctx context.Context, bookID []int32) *DeleteBookBatchResults { @@ -171,11 +180,14 @@ func (q *Queries) DeleteBook(ctx context.Context, bookID []int32) *DeleteBookBat batch.Queue(deleteBook, vals...) } br := q.db.SendBatch(ctx, batch) - return &DeleteBookBatchResults{br, 0} + return &DeleteBookBatchResults{br, 0, len(bookID)} } func (b *DeleteBookBatchResults) Exec(f func(int, error)) { for { + if b.ind >= b.tot { + break + } _, err := b.br.Exec() if err != nil && (err.Error() == "no result" || err.Error() == "batch already closed") { break @@ -200,6 +212,7 @@ WHERE book_id = $3 type UpdateBookBatchResults struct { br pgx.BatchResults ind int + tot int } type UpdateBookParams struct { @@ -219,11 +232,14 @@ func (q *Queries) UpdateBook(ctx context.Context, arg []UpdateBookParams) *Updat batch.Queue(updateBook, vals...) } br := q.db.SendBatch(ctx, batch) - return &UpdateBookBatchResults{br, 0} + return &UpdateBookBatchResults{br, 0, len(arg)} } func (b *UpdateBookBatchResults) Exec(f func(int, error)) { for { + if b.ind >= b.tot { + break + } _, err := b.br.Exec() if err != nil && (err.Error() == "no result" || err.Error() == "batch already closed") { break diff --git a/internal/codegen/golang/templates/pgx/batchCode.tmpl b/internal/codegen/golang/templates/pgx/batchCode.tmpl index 2bd8872304..7b3002e83e 100644 --- a/internal/codegen/golang/templates/pgx/batchCode.tmpl +++ b/internal/codegen/golang/templates/pgx/batchCode.tmpl @@ -8,6 +8,7 @@ const {{.ConstantName}} = {{$.Q}}-- name: {{.MethodName}} {{.Cmd}} type {{.MethodName}}BatchResults struct { br pgx.BatchResults ind int + tot int } {{if .Arg.EmitStruct}} @@ -41,16 +42,19 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{ if $.EmitMethodsWithDB batch.Queue({{.ConstantName}}, vals...) } br := {{if not $.EmitMethodsWithDBArgument}}q.{{end}}db.SendBatch(ctx, batch) - return &{{.MethodName}}BatchResults{br,0} + return &{{.MethodName}}BatchResults{br,0,len({{.Arg.Name}})} } {{if eq .Cmd ":batchexec"}} func (b *{{.MethodName}}BatchResults) Exec(f func(int, error)) { for { - _, err := b.br.Exec() + if b.ind >= b.tot { + break + } + _, err := b.br.Exec() if err != nil && (err.Error() == "no result" || err.Error() == "batch already closed"){ break - } + } if f != nil { f(b.ind, err) } @@ -62,6 +66,9 @@ func (b *{{.MethodName}}BatchResults) Exec(f func(int, error)) { {{if eq .Cmd ":batchmany"}} func (b *{{.MethodName}}BatchResults) Query(f func(int, []{{.Ret.DefineType}}, error)) { for { + if b.ind >= b.tot { + break + } rows, err := b.br.Query() if err != nil && (err.Error() == "no result" || err.Error() == "batch already closed") { break @@ -91,6 +98,9 @@ func (b *{{.MethodName}}BatchResults) Query(f func(int, []{{.Ret.DefineType}}, e {{if eq .Cmd ":batchone"}} func (b *{{.MethodName}}BatchResults) QueryRow(f func(int, {{.Ret.DefineType}}, error)) { for { + if b.ind >= b.tot { + break + } row := b.br.QueryRow() var {{.Ret.Name}} {{.Ret.Type}} err := row.Scan({{.Ret.Scan}}) diff --git a/internal/endtoend/testdata/batch/postgresql/pgx/go/batch.go b/internal/endtoend/testdata/batch/postgresql/pgx/go/batch.go index fbd0044c3c..9fae53b570 100644 --- a/internal/endtoend/testdata/batch/postgresql/pgx/go/batch.go +++ b/internal/endtoend/testdata/batch/postgresql/pgx/go/batch.go @@ -21,6 +21,7 @@ WHERE b = $1 type GetValuesBatchResults struct { br pgx.BatchResults ind int + tot int } func (q *Queries) GetValues(ctx context.Context, b []sql.NullInt32) *GetValuesBatchResults { @@ -32,11 +33,14 @@ func (q *Queries) GetValues(ctx context.Context, b []sql.NullInt32) *GetValuesBa batch.Queue(getValues, vals...) } br := q.db.SendBatch(ctx, batch) - return &GetValuesBatchResults{br, 0} + return &GetValuesBatchResults{br, 0, len(b)} } func (b *GetValuesBatchResults) Query(f func(int, []MyschemaFoo, error)) { for { + if b.ind >= b.tot { + break + } rows, err := b.br.Query() if err != nil && (err.Error() == "no result" || err.Error() == "batch already closed") { break @@ -71,6 +75,7 @@ RETURNING a type InsertValuesBatchResults struct { br pgx.BatchResults ind int + tot int } type InsertValuesParams struct { @@ -88,11 +93,14 @@ func (q *Queries) InsertValues(ctx context.Context, arg []InsertValuesParams) *I batch.Queue(insertValues, vals...) } br := q.db.SendBatch(ctx, batch) - return &InsertValuesBatchResults{br, 0} + return &InsertValuesBatchResults{br, 0, len(arg)} } func (b *InsertValuesBatchResults) QueryRow(f func(int, sql.NullString, error)) { for { + if b.ind >= b.tot { + break + } row := b.br.QueryRow() var a sql.NullString err := row.Scan(&a) @@ -117,6 +125,7 @@ UPDATE myschema.foo SET a = $1, b = $2 type UpdateValuesBatchResults struct { br pgx.BatchResults ind int + tot int } type UpdateValuesParams struct { @@ -134,11 +143,14 @@ func (q *Queries) UpdateValues(ctx context.Context, arg []UpdateValuesParams) *U batch.Queue(updateValues, vals...) } br := q.db.SendBatch(ctx, batch) - return &UpdateValuesBatchResults{br, 0} + return &UpdateValuesBatchResults{br, 0, len(arg)} } func (b *UpdateValuesBatchResults) Exec(f func(int, error)) { for { + if b.ind >= b.tot { + break + } _, err := b.br.Exec() if err != nil && (err.Error() == "no result" || err.Error() == "batch already closed") { break diff --git a/internal/endtoend/testdata/batch_imports/postgresql/pgx/go/batch.go b/internal/endtoend/testdata/batch_imports/postgresql/pgx/go/batch.go index 797601af60..4754574ea0 100644 --- a/internal/endtoend/testdata/batch_imports/postgresql/pgx/go/batch.go +++ b/internal/endtoend/testdata/batch_imports/postgresql/pgx/go/batch.go @@ -21,6 +21,7 @@ WHERE b = $1 type GetValuesBatchResults struct { br pgx.BatchResults ind int + tot int } func (q *Queries) GetValues(ctx context.Context, b []sql.NullInt32) *GetValuesBatchResults { @@ -32,11 +33,14 @@ func (q *Queries) GetValues(ctx context.Context, b []sql.NullInt32) *GetValuesBa batch.Queue(getValues, vals...) } br := q.db.SendBatch(ctx, batch) - return &GetValuesBatchResults{br, 0} + return &GetValuesBatchResults{br, 0, len(b)} } func (b *GetValuesBatchResults) Query(f func(int, []MyschemaFoo, error)) { for { + if b.ind >= b.tot { + break + } rows, err := b.br.Query() if err != nil && (err.Error() == "no result" || err.Error() == "batch already closed") { break @@ -71,6 +75,7 @@ RETURNING a type InsertValuesBatchResults struct { br pgx.BatchResults ind int + tot int } type InsertValuesParams struct { @@ -88,11 +93,14 @@ func (q *Queries) InsertValues(ctx context.Context, arg []InsertValuesParams) *I batch.Queue(insertValues, vals...) } br := q.db.SendBatch(ctx, batch) - return &InsertValuesBatchResults{br, 0} + return &InsertValuesBatchResults{br, 0, len(arg)} } func (b *InsertValuesBatchResults) QueryRow(f func(int, sql.NullString, error)) { for { + if b.ind >= b.tot { + break + } row := b.br.QueryRow() var a sql.NullString err := row.Scan(&a) From 7c6ac1074531b2caa485ed22e059da0f0315647b Mon Sep 17 00:00:00 2001 From: Chris Lee Date: Fri, 2 Sep 2022 19:18:31 -0700 Subject: [PATCH 2/3] rewrite batch generated code -- closed bool, remove struct index --- examples/batch/postgresql/batch.go | 144 ++++++++++-------- internal/codegen/golang/imports.go | 1 + .../golang/templates/pgx/batchCode.tmpl | 112 +++++++------- .../testdata/batch/postgresql/pgx/go/batch.go | 101 ++++++------ .../batch_imports/postgresql/pgx/go/batch.go | 76 ++++----- 5 files changed, 234 insertions(+), 200 deletions(-) diff --git a/examples/batch/postgresql/batch.go b/examples/batch/postgresql/batch.go index 73002d7390..09b36019d0 100644 --- a/examples/batch/postgresql/batch.go +++ b/examples/batch/postgresql/batch.go @@ -7,6 +7,7 @@ package batch import ( "context" + "errors" "time" "github.com/jackc/pgx/v4" @@ -18,9 +19,9 @@ WHERE year = $1 ` type BooksByYearBatchResults struct { - br pgx.BatchResults - ind int - tot int + br pgx.BatchResults + tot int + closed bool } func (q *Queries) BooksByYear(ctx context.Context, year []int32) *BooksByYearBatchResults { @@ -32,45 +33,51 @@ func (q *Queries) BooksByYear(ctx context.Context, year []int32) *BooksByYearBat batch.Queue(booksByYear, vals...) } br := q.db.SendBatch(ctx, batch) - return &BooksByYearBatchResults{br, 0, len(year)} + return &BooksByYearBatchResults{br, len(year), false} } func (b *BooksByYearBatchResults) Query(f func(int, []Book, error)) { - for { - if b.ind >= b.tot { - break - } - rows, err := b.br.Query() - if err != nil && (err.Error() == "no result" || err.Error() == "batch already closed") { - break - } - defer rows.Close() + defer b.br.Close() + for t := 0; t < b.tot; t++ { var items []Book - for rows.Next() { - var i Book - if err := rows.Scan( - &i.BookID, - &i.AuthorID, - &i.Isbn, - &i.BookType, - &i.Title, - &i.Year, - &i.Available, - &i.Tags, - ); err != nil { - break + if b.closed { + if f != nil { + f(t, items, errors.New("batch already closed")) } - items = append(items, i) + continue } - + err := func() error { + rows, err := b.br.Query() + defer rows.Close() + if err != nil { + return err + } + for rows.Next() { + var i Book + if err := rows.Scan( + &i.BookID, + &i.AuthorID, + &i.Isbn, + &i.BookType, + &i.Title, + &i.Year, + &i.Available, + &i.Tags, + ); err != nil { + return err + } + items = append(items, i) + } + return rows.Err() + }() if f != nil { - f(b.ind, items, rows.Err()) + f(t, items, err) } - b.ind++ } } func (b *BooksByYearBatchResults) Close() error { + b.closed = true return b.br.Close() } @@ -96,9 +103,9 @@ RETURNING book_id, author_id, isbn, book_type, title, year, available, tags ` type CreateBookBatchResults struct { - br pgx.BatchResults - ind int - tot int + br pgx.BatchResults + tot int + closed bool } type CreateBookParams struct { @@ -126,16 +133,20 @@ func (q *Queries) CreateBook(ctx context.Context, arg []CreateBookParams) *Creat batch.Queue(createBook, vals...) } br := q.db.SendBatch(ctx, batch) - return &CreateBookBatchResults{br, 0, len(arg)} + return &CreateBookBatchResults{br, len(arg), false} } func (b *CreateBookBatchResults) QueryRow(f func(int, Book, error)) { - for { - if b.ind >= b.tot { - break + defer b.br.Close() + for t := 0; t < b.tot; t++ { + var i Book + if b.closed { + if f != nil { + f(t, i, errors.New("batch already closed")) + } + continue } row := b.br.QueryRow() - var i Book err := row.Scan( &i.BookID, &i.AuthorID, @@ -146,17 +157,14 @@ func (b *CreateBookBatchResults) QueryRow(f func(int, Book, error)) { &i.Available, &i.Tags, ) - if err != nil && (err.Error() == "no result" || err.Error() == "batch already closed") { - break - } if f != nil { - f(b.ind, i, err) + f(t, i, err) } - b.ind++ } } func (b *CreateBookBatchResults) Close() error { + b.closed = true return b.br.Close() } @@ -166,9 +174,9 @@ WHERE book_id = $1 ` type DeleteBookBatchResults struct { - br pgx.BatchResults - ind int - tot int + br pgx.BatchResults + tot int + closed bool } func (q *Queries) DeleteBook(ctx context.Context, bookID []int32) *DeleteBookBatchResults { @@ -180,26 +188,27 @@ func (q *Queries) DeleteBook(ctx context.Context, bookID []int32) *DeleteBookBat batch.Queue(deleteBook, vals...) } br := q.db.SendBatch(ctx, batch) - return &DeleteBookBatchResults{br, 0, len(bookID)} + return &DeleteBookBatchResults{br, len(bookID), false} } func (b *DeleteBookBatchResults) Exec(f func(int, error)) { - for { - if b.ind >= b.tot { - break + defer b.br.Close() + for t := 0; t < b.tot; t++ { + if b.closed { + if f != nil { + f(t, errors.New("batch already closed")) + } + continue } _, err := b.br.Exec() - if err != nil && (err.Error() == "no result" || err.Error() == "batch already closed") { - break - } if f != nil { - f(b.ind, err) + f(t, err) } - b.ind++ } } func (b *DeleteBookBatchResults) Close() error { + b.closed = true return b.br.Close() } @@ -210,9 +219,9 @@ WHERE book_id = $3 ` type UpdateBookBatchResults struct { - br pgx.BatchResults - ind int - tot int + br pgx.BatchResults + tot int + closed bool } type UpdateBookParams struct { @@ -232,25 +241,26 @@ func (q *Queries) UpdateBook(ctx context.Context, arg []UpdateBookParams) *Updat batch.Queue(updateBook, vals...) } br := q.db.SendBatch(ctx, batch) - return &UpdateBookBatchResults{br, 0, len(arg)} + return &UpdateBookBatchResults{br, len(arg), false} } func (b *UpdateBookBatchResults) Exec(f func(int, error)) { - for { - if b.ind >= b.tot { - break + defer b.br.Close() + for t := 0; t < b.tot; t++ { + if b.closed { + if f != nil { + f(t, errors.New("batch already closed")) + } + continue } _, err := b.br.Exec() - if err != nil && (err.Error() == "no result" || err.Error() == "batch already closed") { - break - } if f != nil { - f(b.ind, err) + f(t, err) } - b.ind++ } } func (b *UpdateBookBatchResults) Close() error { + b.closed = true return b.br.Close() } diff --git a/internal/codegen/golang/imports.go b/internal/codegen/golang/imports.go index 8da212ce99..f0e53ef6f0 100644 --- a/internal/codegen/golang/imports.go +++ b/internal/codegen/golang/imports.go @@ -445,6 +445,7 @@ func (i *importer) batchImports(filename string) fileImports { }) std["context"] = struct{}{} + std["errors"] = struct{}{} pkg[ImportSpec{Path: "github.com/jackc/pgx/v4"}] = struct{}{} return sortedImports(std, pkg) diff --git a/internal/codegen/golang/templates/pgx/batchCode.tmpl b/internal/codegen/golang/templates/pgx/batchCode.tmpl index 7b3002e83e..9603696cd9 100644 --- a/internal/codegen/golang/templates/pgx/batchCode.tmpl +++ b/internal/codegen/golang/templates/pgx/batchCode.tmpl @@ -7,8 +7,8 @@ const {{.ConstantName}} = {{$.Q}}-- name: {{.MethodName}} {{.Cmd}} type {{.MethodName}}BatchResults struct { br pgx.BatchResults - ind int tot int + closed bool } {{if .Arg.EmitStruct}} @@ -42,80 +42,86 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{ if $.EmitMethodsWithDB batch.Queue({{.ConstantName}}, vals...) } br := {{if not $.EmitMethodsWithDBArgument}}q.{{end}}db.SendBatch(ctx, batch) - return &{{.MethodName}}BatchResults{br,0,len({{.Arg.Name}})} + return &{{.MethodName}}BatchResults{br,len({{.Arg.Name}}),false} } {{if eq .Cmd ":batchexec"}} func (b *{{.MethodName}}BatchResults) Exec(f func(int, error)) { - for { - if b.ind >= b.tot { - break - } - _, err := b.br.Exec() - if err != nil && (err.Error() == "no result" || err.Error() == "batch already closed"){ - break - } - if f != nil { - f(b.ind, err) - } - b.ind++ - } + defer b.br.Close() + for t := 0; t < b.tot; t++ { + if b.closed { + if f != nil { + f(t, errors.New("batch already closed")) + } + continue + } + _, err := b.br.Exec() + if f != nil { + f(t, err) + } + } } {{end}} {{if eq .Cmd ":batchmany"}} func (b *{{.MethodName}}BatchResults) Query(f func(int, []{{.Ret.DefineType}}, error)) { - for { - if b.ind >= b.tot { - break - } - rows, err := b.br.Query() - if err != nil && (err.Error() == "no result" || err.Error() == "batch already closed") { - break - } - defer rows.Close() - {{- if $.EmitEmptySlices}} - items := []{{.Ret.DefineType}}{} - {{else}} - var items []{{.Ret.DefineType}} - {{end -}} - for rows.Next() { - var {{.Ret.Name}} {{.Ret.Type}} - if err := rows.Scan({{.Ret.Scan}}); err != nil { - break - } - items = append(items, {{.Ret.ReturnName}}) - } - + defer b.br.Close() + for t := 0; t < b.tot; t++ { + {{- if $.EmitEmptySlices}} + items := []{{.Ret.DefineType}}{} + {{else}} + var items []{{.Ret.DefineType}} + {{end -}} + if b.closed { if f != nil { - f(b.ind, items, rows.Err()) + f(t, items, errors.New("batch already closed")) } - b.ind++ - } + continue + } + err := func() error { + rows, err := b.br.Query() + defer rows.Close() + if err != nil { + return err + } + for rows.Next() { + var {{.Ret.Name}} {{.Ret.Type}} + if err := rows.Scan({{.Ret.Scan}}); err != nil { + return err + } + items = append(items, {{.Ret.ReturnName}}) + } + return rows.Err() + }() + if f != nil { + f(t, items, err) + } + } } {{end}} {{if eq .Cmd ":batchone"}} func (b *{{.MethodName}}BatchResults) QueryRow(f func(int, {{.Ret.DefineType}}, error)) { - for { - if b.ind >= b.tot { - break - } - row := b.br.QueryRow() - var {{.Ret.Name}} {{.Ret.Type}} - err := row.Scan({{.Ret.Scan}}) - if err != nil && (err.Error() == "no result" || err.Error() == "batch already closed") { - break - } + defer b.br.Close() + for t := 0; t < b.tot; t++ { + var {{.Ret.Name}} {{.Ret.Type}} + if b.closed { if f != nil { - f(b.ind, {{.Ret.ReturnName}}, err) + f(t, {{.Ret.Name}}, errors.New("batch already closed")) } - b.ind++ - } + continue + } + row := b.br.QueryRow() + err := row.Scan({{.Ret.Scan}}) + if f != nil { + f(t, {{.Ret.ReturnName}}, err) + } + } } {{end}} func (b *{{.MethodName}}BatchResults) Close() error { + b.closed = true return b.br.Close() } {{end}} diff --git a/internal/endtoend/testdata/batch/postgresql/pgx/go/batch.go b/internal/endtoend/testdata/batch/postgresql/pgx/go/batch.go index 9fae53b570..8676ea012b 100644 --- a/internal/endtoend/testdata/batch/postgresql/pgx/go/batch.go +++ b/internal/endtoend/testdata/batch/postgresql/pgx/go/batch.go @@ -8,6 +8,7 @@ package querytest import ( "context" "database/sql" + "errors" "github.com/jackc/pgx/v4" ) @@ -19,9 +20,9 @@ WHERE b = $1 ` type GetValuesBatchResults struct { - br pgx.BatchResults - ind int - tot int + br pgx.BatchResults + tot int + closed bool } func (q *Queries) GetValues(ctx context.Context, b []sql.NullInt32) *GetValuesBatchResults { @@ -33,36 +34,42 @@ func (q *Queries) GetValues(ctx context.Context, b []sql.NullInt32) *GetValuesBa batch.Queue(getValues, vals...) } br := q.db.SendBatch(ctx, batch) - return &GetValuesBatchResults{br, 0, len(b)} + return &GetValuesBatchResults{br, len(b), false} } func (b *GetValuesBatchResults) Query(f func(int, []MyschemaFoo, error)) { - for { - if b.ind >= b.tot { - break - } - rows, err := b.br.Query() - if err != nil && (err.Error() == "no result" || err.Error() == "batch already closed") { - break - } - defer rows.Close() + defer b.br.Close() + for t := 0; t < b.tot; t++ { var items []MyschemaFoo - for rows.Next() { - var i MyschemaFoo - if err := rows.Scan(&i.A, &i.B); err != nil { - break + if b.closed { + if f != nil { + f(t, items, errors.New("batch already closed")) } - items = append(items, i) + continue } - + err := func() error { + rows, err := b.br.Query() + defer rows.Close() + if err != nil { + return err + } + for rows.Next() { + var i MyschemaFoo + if err := rows.Scan(&i.A, &i.B); err != nil { + return err + } + items = append(items, i) + } + return rows.Err() + }() if f != nil { - f(b.ind, items, rows.Err()) + f(t, items, err) } - b.ind++ } } func (b *GetValuesBatchResults) Close() error { + b.closed = true return b.br.Close() } @@ -73,9 +80,9 @@ RETURNING a ` type InsertValuesBatchResults struct { - br pgx.BatchResults - ind int - tot int + br pgx.BatchResults + tot int + closed bool } type InsertValuesParams struct { @@ -93,28 +100,29 @@ func (q *Queries) InsertValues(ctx context.Context, arg []InsertValuesParams) *I batch.Queue(insertValues, vals...) } br := q.db.SendBatch(ctx, batch) - return &InsertValuesBatchResults{br, 0, len(arg)} + return &InsertValuesBatchResults{br, len(arg), false} } func (b *InsertValuesBatchResults) QueryRow(f func(int, sql.NullString, error)) { - for { - if b.ind >= b.tot { - break + defer b.br.Close() + for t := 0; t < b.tot; t++ { + var a sql.NullString + if b.closed { + if f != nil { + f(t, a, errors.New("batch already closed")) + } + continue } row := b.br.QueryRow() - var a sql.NullString err := row.Scan(&a) - if err != nil && (err.Error() == "no result" || err.Error() == "batch already closed") { - break - } if f != nil { - f(b.ind, a, err) + f(t, a, err) } - b.ind++ } } func (b *InsertValuesBatchResults) Close() error { + b.closed = true return b.br.Close() } @@ -123,9 +131,9 @@ UPDATE myschema.foo SET a = $1, b = $2 ` type UpdateValuesBatchResults struct { - br pgx.BatchResults - ind int - tot int + br pgx.BatchResults + tot int + closed bool } type UpdateValuesParams struct { @@ -143,25 +151,26 @@ func (q *Queries) UpdateValues(ctx context.Context, arg []UpdateValuesParams) *U batch.Queue(updateValues, vals...) } br := q.db.SendBatch(ctx, batch) - return &UpdateValuesBatchResults{br, 0, len(arg)} + return &UpdateValuesBatchResults{br, len(arg), false} } func (b *UpdateValuesBatchResults) Exec(f func(int, error)) { - for { - if b.ind >= b.tot { - break + defer b.br.Close() + for t := 0; t < b.tot; t++ { + if b.closed { + if f != nil { + f(t, errors.New("batch already closed")) + } + continue } _, err := b.br.Exec() - if err != nil && (err.Error() == "no result" || err.Error() == "batch already closed") { - break - } if f != nil { - f(b.ind, err) + f(t, err) } - b.ind++ } } func (b *UpdateValuesBatchResults) Close() error { + b.closed = true return b.br.Close() } diff --git a/internal/endtoend/testdata/batch_imports/postgresql/pgx/go/batch.go b/internal/endtoend/testdata/batch_imports/postgresql/pgx/go/batch.go index 4754574ea0..c868b91414 100644 --- a/internal/endtoend/testdata/batch_imports/postgresql/pgx/go/batch.go +++ b/internal/endtoend/testdata/batch_imports/postgresql/pgx/go/batch.go @@ -8,6 +8,7 @@ package querytest import ( "context" "database/sql" + "errors" "github.com/jackc/pgx/v4" ) @@ -19,9 +20,9 @@ WHERE b = $1 ` type GetValuesBatchResults struct { - br pgx.BatchResults - ind int - tot int + br pgx.BatchResults + tot int + closed bool } func (q *Queries) GetValues(ctx context.Context, b []sql.NullInt32) *GetValuesBatchResults { @@ -33,36 +34,42 @@ func (q *Queries) GetValues(ctx context.Context, b []sql.NullInt32) *GetValuesBa batch.Queue(getValues, vals...) } br := q.db.SendBatch(ctx, batch) - return &GetValuesBatchResults{br, 0, len(b)} + return &GetValuesBatchResults{br, len(b), false} } func (b *GetValuesBatchResults) Query(f func(int, []MyschemaFoo, error)) { - for { - if b.ind >= b.tot { - break - } - rows, err := b.br.Query() - if err != nil && (err.Error() == "no result" || err.Error() == "batch already closed") { - break - } - defer rows.Close() + defer b.br.Close() + for t := 0; t < b.tot; t++ { var items []MyschemaFoo - for rows.Next() { - var i MyschemaFoo - if err := rows.Scan(&i.A, &i.B); err != nil { - break + if b.closed { + if f != nil { + f(t, items, errors.New("batch already closed")) } - items = append(items, i) + continue } - + err := func() error { + rows, err := b.br.Query() + defer rows.Close() + if err != nil { + return err + } + for rows.Next() { + var i MyschemaFoo + if err := rows.Scan(&i.A, &i.B); err != nil { + return err + } + items = append(items, i) + } + return rows.Err() + }() if f != nil { - f(b.ind, items, rows.Err()) + f(t, items, err) } - b.ind++ } } func (b *GetValuesBatchResults) Close() error { + b.closed = true return b.br.Close() } @@ -73,9 +80,9 @@ RETURNING a ` type InsertValuesBatchResults struct { - br pgx.BatchResults - ind int - tot int + br pgx.BatchResults + tot int + closed bool } type InsertValuesParams struct { @@ -93,27 +100,28 @@ func (q *Queries) InsertValues(ctx context.Context, arg []InsertValuesParams) *I batch.Queue(insertValues, vals...) } br := q.db.SendBatch(ctx, batch) - return &InsertValuesBatchResults{br, 0, len(arg)} + return &InsertValuesBatchResults{br, len(arg), false} } func (b *InsertValuesBatchResults) QueryRow(f func(int, sql.NullString, error)) { - for { - if b.ind >= b.tot { - break + defer b.br.Close() + for t := 0; t < b.tot; t++ { + var a sql.NullString + if b.closed { + if f != nil { + f(t, a, errors.New("batch already closed")) + } + continue } row := b.br.QueryRow() - var a sql.NullString err := row.Scan(&a) - if err != nil && (err.Error() == "no result" || err.Error() == "batch already closed") { - break - } if f != nil { - f(b.ind, a, err) + f(t, a, err) } - b.ind++ } } func (b *InsertValuesBatchResults) Close() error { + b.closed = true return b.br.Close() } From 6d02a0b4e06ae328ef45e8a249b05c6268072e37 Mon Sep 17 00:00:00 2001 From: Chris Lee Date: Wed, 7 Sep 2022 14:24:13 -0700 Subject: [PATCH 3/3] fix db_test.go in examples --- examples/batch/postgresql/db_test.go | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/examples/batch/postgresql/db_test.go b/examples/batch/postgresql/db_test.go index 6293831efe..4455cfdf33 100644 --- a/examples/batch/postgresql/db_test.go +++ b/examples/batch/postgresql/db_test.go @@ -122,19 +122,24 @@ func TestBatchBooks(t *testing.T) { } batchDelete := dq.DeleteBook(ctx, deleteBooksParams) numDeletesProcessed := 0 + wantNumDeletesProcessed := 2 batchDelete.Exec(func(i int, err error) { - numDeletesProcessed++ - if err != nil { + if err != nil && err.Error() != "batch already closed" { t.Fatalf("error deleting book %d: %s", deleteBooksParams[i], err) } - if i == len(deleteBooksParams)-3 { + + if err == nil { + numDeletesProcessed++ + } + + if i == wantNumDeletesProcessed-1 { // close batch operation before processing all errors from delete operation if err := batchDelete.Close(); err != nil { t.Fatalf("failed to close batch operation: %s", err) } } }) - if numDeletesProcessed != 2 { - t.Fatalf("expected Close to short-circuit record processing (expected 2; got %d)", numDeletesProcessed) + if numDeletesProcessed != wantNumDeletesProcessed { + t.Fatalf("expected Close to short-circuit record processing (expected %d; got %d)", wantNumDeletesProcessed, numDeletesProcessed) } }