Skip to content

Commit b1497ca

Browse files
- bug fix: when context cancelation happens during row scanning
1 parent 0f2bcc4 commit b1497ca

File tree

5 files changed

+206
-21
lines changed

5 files changed

+206
-21
lines changed

conn.go

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -218,13 +218,13 @@ func (c *Conn) PrepareContext(ctx context.Context, query string) (*Stmt, error)
218218

219219
// Query executes a query that returns rows, typically a SELECT.
220220
// The args are for any placeholder parameters in the query.
221-
func (c *Conn) Query(query string, args ...interface{}) (*stdSql.Rows, error) {
221+
func (c *Conn) Query(query string, args ...interface{}) (*Rows, error) {
222222
return c.QueryContext(context.Background(), query, args...)
223223
}
224224

225225
// QueryContext executes a query that returns rows, typically a SELECT.
226226
// The args are for any placeholder parameters in the query.
227-
func (c *Conn) QueryContext(ctx context.Context, query string, args ...interface{}) (*stdSql.Rows, error) {
227+
func (c *Conn) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
228228

229229
// We can't use the same approach used in ExecContext because defer cancelFunc()
230230
// cancels rows.Scan.
@@ -234,7 +234,8 @@ func (c *Conn) QueryContext(ctx context.Context, query string, args ...interface
234234
}
235235
}()
236236

237-
return c.conn.QueryContext(ctx, query, args...)
237+
rows, err := c.conn.QueryContext(ctx, query, args...)
238+
return &Rows{ctx: ctx, rows: rows, killerPool: c.killerPool, connectionID: c.connectionID}, err
238239
}
239240

240241
// QueryRow executes a query that is expected to return at most one row.
@@ -243,7 +244,7 @@ func (c *Conn) QueryContext(ctx context.Context, query string, args ...interface
243244
// If the query selects no rows, the *Row's Scan will return ErrNoRows.
244245
// Otherwise, the *Row's Scan scans the first selected row and discards
245246
// the rest.
246-
func (c *Conn) QueryRow(query string, args ...interface{}) *stdSql.Row {
247+
func (c *Conn) QueryRow(query string, args ...interface{}) *Row {
247248
return c.QueryRowContext(context.Background(), query, args...)
248249
}
249250

@@ -253,7 +254,7 @@ func (c *Conn) QueryRow(query string, args ...interface{}) *stdSql.Row {
253254
// If the query selects no rows, the *Row's Scan will return ErrNoRows.
254255
// Otherwise, the *Row's Scan scans the first selected row and discards
255256
// the rest.
256-
func (c *Conn) QueryRowContext(ctx context.Context, query string, args ...interface{}) *stdSql.Row {
257+
func (c *Conn) QueryRowContext(ctx context.Context, query string, args ...interface{}) *Row {
257258

258259
// Since sql.Row does not export err field, this is the best we can do:
259260
defer func() {
@@ -262,5 +263,6 @@ func (c *Conn) QueryRowContext(ctx context.Context, query string, args ...interf
262263
}
263264
}()
264265

265-
return c.conn.QueryRowContext(ctx, query, args...)
266+
row := c.conn.QueryRowContext(ctx, query, args...)
267+
return &Row{ctx: ctx, row: row, killerPool: c.killerPool, connectionID: c.connectionID}
266268
}

row.go

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
package sql
2+
3+
import (
4+
"context"
5+
stdSql "database/sql"
6+
)
7+
8+
// Row is the result of calling QueryRow to select a single row.
9+
type Row struct {
10+
ctx context.Context
11+
row *stdSql.Row
12+
killerPool *stdSql.DB
13+
connectionID string
14+
}
15+
16+
// Scan copies the columns from the matched row into the values
17+
// pointed at by dest. See the documentation on Rows.Scan for details.
18+
// If more than one row matches the query,
19+
// Scan uses the first row and discards the rest. If no row matches
20+
// the query, Scan returns ErrNoRows.
21+
func (r *Row) Scan(dest ...interface{}) error {
22+
err := r.row.Scan(dest...)
23+
if r.ctx.Err() != nil {
24+
kill(r.killerPool, r.connectionID)
25+
}
26+
return err
27+
}

rows.go

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
package sql
2+
3+
import (
4+
"context"
5+
stdSql "database/sql"
6+
)
7+
8+
// Rows is the result of a query. Its cursor starts before the first row
9+
// of the result set. Use Next to advance from row to row.
10+
type Rows struct {
11+
ctx context.Context
12+
rows *stdSql.Rows
13+
killerPool *stdSql.DB
14+
connectionID string
15+
}
16+
17+
// Unleak will release the reference to the killerPool
18+
// in order to prevent a memory leak.
19+
func (rs *Rows) Unleak() {
20+
rs.killerPool = nil
21+
rs.connectionID = ""
22+
}
23+
24+
// Close closes the Rows, preventing further enumeration. If Next is called
25+
// and returns false and there are no further result sets,
26+
// the Rows are closed automatically and it will suffice to check the
27+
// result of Err. Close is idempotent and does not affect the result of Err.
28+
func (rs *Rows) Close() error {
29+
err := rs.rows.Close()
30+
if rs.ctx.Err() != nil {
31+
kill(rs.killerPool, rs.connectionID)
32+
}
33+
rs.Unleak()
34+
return err
35+
}
36+
37+
// ColumnTypes returns column information such as column type, length,
38+
// and nullable. Some information may not be available from some drivers.
39+
func (rs *Rows) ColumnTypes() ([]*stdSql.ColumnType, error) {
40+
ct, err := rs.rows.ColumnTypes()
41+
if rs.ctx.Err() != nil {
42+
kill(rs.killerPool, rs.connectionID)
43+
}
44+
return ct, err
45+
}
46+
47+
// Columns returns the column names.
48+
// Columns returns an error if the rows are closed.
49+
func (rs *Rows) Columns() ([]string, error) {
50+
cols, err := rs.rows.Columns()
51+
if rs.ctx.Err() != nil {
52+
kill(rs.killerPool, rs.connectionID)
53+
}
54+
return cols, err
55+
}
56+
57+
// Err returns the error, if any, that was encountered during iteration.
58+
// Err may be called after an explicit or implicit Close.
59+
func (rs *Rows) Err() error {
60+
err := rs.rows.Err()
61+
if rs.ctx.Err() != nil {
62+
kill(rs.killerPool, rs.connectionID)
63+
}
64+
return err
65+
}
66+
67+
// Next prepares the next result row for reading with the Scan method. It
68+
// returns true on success, or false if there is no next result row or an error
69+
// happened while preparing it. Err should be consulted to distinguish between
70+
// the two cases.
71+
//
72+
// Every call to Scan, even the first one, must be preceded by a call to Next.
73+
func (rs *Rows) Next() bool {
74+
return rs.rows.Next()
75+
}
76+
77+
// NextResultSet prepares the next result set for reading. It reports whether
78+
// there is further result sets, or false if there is no further result set
79+
// or if there is an error advancing to it. The Err method should be consulted
80+
// to distinguish between the two cases.
81+
//
82+
// After calling NextResultSet, the Next method should always be called before
83+
// scanning. If there are further result sets they may not have rows in the result
84+
// set.
85+
func (rs *Rows) NextResultSet() bool {
86+
return rs.rows.NextResultSet()
87+
}
88+
89+
// Scan copies the columns in the current row into the values pointed
90+
// at by dest. The number of values in dest must be the same as the
91+
// number of columns in Rows.
92+
//
93+
// Scan converts columns read from the database into the following
94+
// common Go types and special types provided by the sql package:
95+
//
96+
// *string
97+
// *[]byte
98+
// *int, *int8, *int16, *int32, *int64
99+
// *uint, *uint8, *uint16, *uint32, *uint64
100+
// *bool
101+
// *float32, *float64
102+
// *interface{}
103+
// *RawBytes
104+
// *Rows (cursor value)
105+
// any type implementing Scanner (see Scanner docs)
106+
//
107+
// In the most simple case, if the type of the value from the source
108+
// column is an integer, bool or string type T and dest is of type *T,
109+
// Scan simply assigns the value through the pointer.
110+
//
111+
// Scan also converts between string and numeric types, as long as no
112+
// information would be lost. While Scan stringifies all numbers
113+
// scanned from numeric database columns into *string, scans into
114+
// numeric types are checked for overflow. For example, a float64 with
115+
// value 300 or a string with value "300" can scan into a uint16, but
116+
// not into a uint8, though float64(255) or "255" can scan into a
117+
// uint8. One exception is that scans of some float64 numbers to
118+
// strings may lose information when stringifying. In general, scan
119+
// floating point columns into *float64.
120+
//
121+
// If a dest argument has type *[]byte, Scan saves in that argument a
122+
// copy of the corresponding data. The copy is owned by the caller and
123+
// can be modified and held indefinitely. The copy can be avoided by
124+
// using an argument of type *RawBytes instead; see the documentation
125+
// for RawBytes for restrictions on its use.
126+
//
127+
// If an argument has type *interface{}, Scan copies the value
128+
// provided by the underlying driver without conversion. When scanning
129+
// from a source value of type []byte to *interface{}, a copy of the
130+
// slice is made and the caller owns the result.
131+
//
132+
// Source values of type time.Time may be scanned into values of type
133+
// *time.Time, *interface{}, *string, or *[]byte. When converting to
134+
// the latter two, time.RFC3339Nano is used.
135+
//
136+
// Source values of type bool may be scanned into types *bool,
137+
// *interface{}, *string, *[]byte, or *RawBytes.
138+
//
139+
// For scanning into *bool, the source may be true, false, 1, 0, or
140+
// string inputs parseable by strconv.ParseBool.
141+
//
142+
// Scan can also convert a cursor returned from a query, such as
143+
// "select cursor(select * from my_table) from dual", into a
144+
// *Rows value that can itself be scanned from. The parent
145+
// select query will close any cursor *Rows if the parent *Rows is closed.
146+
func (rs *Rows) Scan(dest ...interface{}) error {
147+
err := rs.rows.Scan(dest...)
148+
if rs.ctx.Err() != nil {
149+
kill(rs.killerPool, rs.connectionID)
150+
}
151+
return err
152+
}

stmt.go

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ import (
1010
// Stmt is a prepared statement.
1111
// A Stmt is safe for concurrent use by multiple goroutines.
1212
type Stmt struct {
13-
*stdSql.Stmt
13+
stmt *stdSql.Stmt
1414
killerPool *stdSql.DB
1515
connectionID string
1616
}
@@ -24,7 +24,7 @@ func (s *Stmt) Unleak() {
2424

2525
// Close closes the statement.
2626
func (s *Stmt) Close() error {
27-
err := s.Close()
27+
err := s.stmt.Close()
2828
if err != nil {
2929
return err
3030
}
@@ -65,7 +65,7 @@ func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (stdSql.Res
6565
}()
6666

6767
go func() {
68-
res, err := s.ExecContext(cancelCtx, args...)
68+
res, err := s.stmt.ExecContext(cancelCtx, args...)
6969
if err != nil {
7070
errChan <- err
7171
return
@@ -83,13 +83,13 @@ func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (stdSql.Res
8383

8484
// Query executes a prepared query statement with the given arguments
8585
// and returns the query results as a *Rows.
86-
func (s *Stmt) Query(args ...interface{}) (*stdSql.Rows, error) {
86+
func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
8787
return s.QueryContext(context.Background(), args...)
8888
}
8989

9090
// QueryContext executes a prepared query statement with the given arguments
9191
// and returns the query results as a *Rows.
92-
func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*stdSql.Rows, error) {
92+
func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, error) {
9393

9494
// We can't use the same approach used in ExecContext because defer cancelFunc()
9595
// cancels rows.Scan.
@@ -99,7 +99,8 @@ func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*stdSql.R
9999
}
100100
}()
101101

102-
return s.QueryContext(ctx, args...)
102+
rows, err := s.stmt.QueryContext(ctx, args...)
103+
return &Rows{ctx: ctx, rows: rows, killerPool: s.killerPool, connectionID: s.connectionID}, err
103104
}
104105

105106
// QueryRow executes a prepared query statement with the given arguments.
@@ -113,7 +114,7 @@ func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*stdSql.R
113114
//
114115
// var name string
115116
// err := nameByUseridStmt.QueryRow(id).Scan(&name)
116-
func (s *Stmt) QueryRow(args ...interface{}) *stdSql.Row {
117+
func (s *Stmt) QueryRow(args ...interface{}) *Row {
117118
return s.QueryRowContext(context.Background(), args...)
118119
}
119120

@@ -123,7 +124,7 @@ func (s *Stmt) QueryRow(args ...interface{}) *stdSql.Row {
123124
// If the query selects no rows, the *Row's Scan will return ErrNoRows.
124125
// Otherwise, the *Row's Scan scans the first selected row and discards
125126
// the rest.
126-
func (s *Stmt) QueryRowContext(ctx context.Context, args ...interface{}) *stdSql.Row {
127+
func (s *Stmt) QueryRowContext(ctx context.Context, args ...interface{}) *Row {
127128

128129
// Since sql.Row does not export err field, this is the best we can do:
129130
defer func() {
@@ -132,5 +133,6 @@ func (s *Stmt) QueryRowContext(ctx context.Context, args ...interface{}) *stdSql
132133
}
133134
}()
134135

135-
return s.QueryRowContext(ctx, args...)
136+
row := s.stmt.QueryRowContext(ctx, args...)
137+
return &Row{ctx: ctx, row: row, killerPool: s.killerPool, connectionID: s.connectionID}
136138
}

tx.go

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -170,12 +170,12 @@ func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
170170
}
171171

172172
// Query executes a query that returns rows, typically a SELECT.
173-
func (tx *Tx) Query(query string, args ...interface{}) (*stdSql.Rows, error) {
173+
func (tx *Tx) Query(query string, args ...interface{}) (*Rows, error) {
174174
return tx.QueryContext(context.Background(), query, args...)
175175
}
176176

177177
// QueryContext executes a query that returns rows, typically a SELECT.
178-
func (tx *Tx) QueryContext(ctx context.Context, query string, args ...interface{}) (*stdSql.Rows, error) {
178+
func (tx *Tx) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
179179

180180
// We can't use the same approach used in ExecContext because defer cancelFunc()
181181
// cancels rows.Scan.
@@ -185,7 +185,8 @@ func (tx *Tx) QueryContext(ctx context.Context, query string, args ...interface{
185185
}
186186
}()
187187

188-
return tx.tx.QueryContext(ctx, query, args...)
188+
rows, err := tx.tx.QueryContext(ctx, query, args...)
189+
return &Rows{ctx: ctx, rows: rows, killerPool: tx.killerPool, connectionID: tx.connectionID}, err
189190
}
190191

191192
// QueryRow executes a query that is expected to return at most one row.
@@ -194,7 +195,7 @@ func (tx *Tx) QueryContext(ctx context.Context, query string, args ...interface{
194195
// If the query selects no rows, the *Row's Scan will return ErrNoRows.
195196
// Otherwise, the *Row's Scan scans the first selected row and discards
196197
// the rest.
197-
func (tx *Tx) QueryRow(query string, args ...interface{}) *stdSql.Row {
198+
func (tx *Tx) QueryRow(query string, args ...interface{}) *Row {
198199
return tx.QueryRowContext(context.Background(), query, args...)
199200
}
200201

@@ -204,7 +205,7 @@ func (tx *Tx) QueryRow(query string, args ...interface{}) *stdSql.Row {
204205
// If the query selects no rows, the *Row's Scan will return ErrNoRows.
205206
// Otherwise, the *Row's Scan scans the first selected row and discards
206207
// the rest.
207-
func (tx *Tx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *stdSql.Row {
208+
func (tx *Tx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *Row {
208209

209210
// Since sql.Row does not export err field, this is the best we can do:
210211
defer func() {
@@ -213,7 +214,8 @@ func (tx *Tx) QueryRowContext(ctx context.Context, query string, args ...interfa
213214
}
214215
}()
215216

216-
return tx.tx.QueryRowContext(ctx, query, args...)
217+
row := tx.tx.QueryRowContext(ctx, query, args...)
218+
return &Row{ctx: ctx, row: row, killerPool: tx.killerPool, connectionID: tx.connectionID}
217219
}
218220

219221
// Rollback aborts the transaction.

0 commit comments

Comments
 (0)