Skip to content

Commit 7e64321

Browse files
authored
Add supports to context.Context
1 parent e3f0fdc commit 7e64321

11 files changed

+774
-29
lines changed

benchmark_go18_test.go

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
2+
//
3+
// Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved.
4+
//
5+
// This Source Code Form is subject to the terms of the Mozilla Public
6+
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
7+
// You can obtain one at http://mozilla.org/MPL/2.0/.
8+
9+
// +build go1.8
10+
11+
package mysql
12+
13+
import (
14+
"context"
15+
"database/sql"
16+
"fmt"
17+
"runtime"
18+
"testing"
19+
)
20+
21+
func benchmarkQueryContext(b *testing.B, db *sql.DB, p int) {
22+
ctx, cancel := context.WithCancel(context.Background())
23+
defer cancel()
24+
db.SetMaxIdleConns(p * runtime.GOMAXPROCS(0))
25+
26+
tb := (*TB)(b)
27+
stmt := tb.checkStmt(db.PrepareContext(ctx, "SELECT val FROM foo WHERE id=?"))
28+
defer stmt.Close()
29+
30+
b.SetParallelism(p)
31+
b.ReportAllocs()
32+
b.ResetTimer()
33+
b.RunParallel(func(pb *testing.PB) {
34+
var got string
35+
for pb.Next() {
36+
tb.check(stmt.QueryRow(1).Scan(&got))
37+
if got != "one" {
38+
b.Fatalf("query = %q; want one", got)
39+
}
40+
}
41+
})
42+
}
43+
44+
func BenchmarkQueryContext(b *testing.B) {
45+
db := initDB(b,
46+
"DROP TABLE IF EXISTS foo",
47+
"CREATE TABLE foo (id INT PRIMARY KEY, val CHAR(50))",
48+
`INSERT INTO foo VALUES (1, "one")`,
49+
`INSERT INTO foo VALUES (2, "two")`,
50+
)
51+
defer db.Close()
52+
for _, p := range []int{1, 2, 3, 4} {
53+
b.Run(fmt.Sprintf("%d", p), func(b *testing.B) {
54+
benchmarkQueryContext(b, db, p)
55+
})
56+
}
57+
}
58+
59+
func benchmarkExecContext(b *testing.B, db *sql.DB, p int) {
60+
ctx, cancel := context.WithCancel(context.Background())
61+
defer cancel()
62+
db.SetMaxIdleConns(p * runtime.GOMAXPROCS(0))
63+
64+
tb := (*TB)(b)
65+
stmt := tb.checkStmt(db.PrepareContext(ctx, "DO 1"))
66+
defer stmt.Close()
67+
68+
b.SetParallelism(p)
69+
b.ReportAllocs()
70+
b.ResetTimer()
71+
b.RunParallel(func(pb *testing.PB) {
72+
for pb.Next() {
73+
if _, err := stmt.ExecContext(ctx); err != nil {
74+
b.Fatal(err)
75+
}
76+
}
77+
})
78+
}
79+
80+
func BenchmarkExecContext(b *testing.B) {
81+
db := initDB(b,
82+
"DROP TABLE IF EXISTS foo",
83+
"CREATE TABLE foo (id INT PRIMARY KEY, val CHAR(50))",
84+
`INSERT INTO foo VALUES (1, "one")`,
85+
`INSERT INTO foo VALUES (2, "two")`,
86+
)
87+
defer db.Close()
88+
for _, p := range []int{1, 2, 3, 4} {
89+
b.Run(fmt.Sprintf("%d", p), func(b *testing.B) {
90+
benchmarkQueryContext(b, db, p)
91+
})
92+
}
93+
}

connection.go

Lines changed: 73 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,23 @@ package mysql
1010

1111
import (
1212
"database/sql/driver"
13+
"errors"
1314
"io"
1415
"net"
1516
"strconv"
1617
"strings"
18+
"sync"
1719
"time"
1820
)
1921

22+
//a copy of context.Context from Go 1.7 and later.
23+
type mysqlContext interface {
24+
Deadline() (deadline time.Time, ok bool)
25+
Done() <-chan struct{}
26+
Err() error
27+
Value(key interface{}) interface{}
28+
}
29+
2030
type mysqlConn struct {
2131
buf buffer
2232
netConn net.Conn
@@ -31,6 +41,13 @@ type mysqlConn struct {
3141
sequence uint8
3242
parseTime bool
3343
strict bool
44+
watcher chan<- mysqlContext
45+
closech chan struct{}
46+
finished chan<- struct{}
47+
48+
mu sync.Mutex // guards following fields
49+
closed error // set non-nil when conn is closed, before closech is closed
50+
canceledErr error // set non-nil if conn is canceled
3451
}
3552

3653
// Handles parameters set in DSN after the connection is established
@@ -64,7 +81,7 @@ func (mc *mysqlConn) handleParams() (err error) {
6481
}
6582

6683
func (mc *mysqlConn) Begin() (driver.Tx, error) {
67-
if mc.netConn == nil {
84+
if mc.isBroken() {
6885
errLog.Print(ErrInvalidConn)
6986
return nil, driver.ErrBadConn
7087
}
@@ -78,11 +95,11 @@ func (mc *mysqlConn) Begin() (driver.Tx, error) {
7895

7996
func (mc *mysqlConn) Close() (err error) {
8097
// Makes Close idempotent
81-
if mc.netConn != nil {
98+
if !mc.isBroken() {
8299
err = mc.writeCommandPacket(comQuit)
83100
}
84101

85-
mc.cleanup()
102+
mc.cleanup(errors.New("mysql: connection is closed"))
86103

87104
return
88105
}
@@ -91,20 +108,36 @@ func (mc *mysqlConn) Close() (err error) {
91108
// function after successfully authentication, call Close instead. This function
92109
// is called before auth or on auth failure because MySQL will have already
93110
// closed the network connection.
94-
func (mc *mysqlConn) cleanup() {
111+
func (mc *mysqlConn) cleanup(err error) {
112+
if err == nil {
113+
panic("nil error")
114+
}
115+
mc.mu.Lock()
116+
defer mc.mu.Unlock()
117+
118+
if mc.closed != nil {
119+
return
120+
}
121+
95122
// Makes cleanup idempotent
96-
if mc.netConn != nil {
97-
if err := mc.netConn.Close(); err != nil {
98-
errLog.Print(err)
99-
}
100-
mc.netConn = nil
123+
mc.closed = err
124+
close(mc.closech)
125+
if mc.netConn == nil {
126+
return
127+
}
128+
if err := mc.netConn.Close(); err != nil {
129+
errLog.Print(err)
101130
}
102-
mc.cfg = nil
103-
mc.buf.nc = nil
131+
}
132+
133+
func (mc *mysqlConn) isBroken() bool {
134+
mc.mu.Lock()
135+
defer mc.mu.Unlock()
136+
return mc.closed != nil
104137
}
105138

106139
func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
107-
if mc.netConn == nil {
140+
if mc.isBroken() {
108141
errLog.Print(ErrInvalidConn)
109142
return nil, driver.ErrBadConn
110143
}
@@ -258,7 +291,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
258291
}
259292

260293
func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
261-
if mc.netConn == nil {
294+
if mc.isBroken() {
262295
errLog.Print(ErrInvalidConn)
263296
return nil, driver.ErrBadConn
264297
}
@@ -315,7 +348,7 @@ func (mc *mysqlConn) exec(query string) error {
315348
}
316349

317350
func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) {
318-
if mc.netConn == nil {
351+
if mc.isBroken() {
319352
errLog.Print(ErrInvalidConn)
320353
return nil, driver.ErrBadConn
321354
}
@@ -387,3 +420,29 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
387420
}
388421
return nil, err
389422
}
423+
424+
// finish is called when the query has canceled.
425+
func (mc *mysqlConn) cancel(err error) {
426+
mc.mu.Lock()
427+
mc.canceledErr = err
428+
mc.mu.Unlock()
429+
mc.cleanup(errors.New("mysql: query canceled"))
430+
}
431+
432+
// canceled returns non-nil if the connection was closed due to context cancelation.
433+
func (mc *mysqlConn) canceled() error {
434+
mc.mu.Lock()
435+
defer mc.mu.Unlock()
436+
return mc.canceledErr
437+
}
438+
439+
// finish is called when the query has succeeded.
440+
func (mc *mysqlConn) finish() {
441+
if mc.finished == nil {
442+
return
443+
}
444+
select {
445+
case mc.finished <- struct{}{}:
446+
case <-mc.closech:
447+
}
448+
}

0 commit comments

Comments
 (0)