Skip to content

Commit 8966bab

Browse files
committed
test: Implement TestContextCancelQueryWhileScan
1 parent c0f6b44 commit 8966bab

File tree

1 file changed

+55
-0
lines changed

1 file changed

+55
-0
lines changed

driver_test.go

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
"io/ioutil"
2020
"log"
2121
"math"
22+
"math/rand"
2223
"net"
2324
"net/url"
2425
"os"
@@ -2938,3 +2939,57 @@ func TestValuerWithValueReceiverGivenNilValue(t *testing.T) {
29382939
// This test will panic on the INSERT if ConvertValue() does not check for typed nil before calling Value()
29392940
})
29402941
}
2942+
2943+
// TestContextCancelQueryWhileScan checks for race conditions that arise when
2944+
// a query context is canceled while a user is calling rows.Scan(). The code
2945+
// is based on database/sql TestIssue18429.
2946+
// See https://github.com/golang/go/issues/23519
2947+
func TestContextCancelQueryWhileScan(t *testing.T) {
2948+
const blob = "0123456789abcdef"
2949+
const sqlQuery = `SELECT id, value FROM test WHERE SLEEP(?) = 0`
2950+
const contextRaceIterations = 1000
2951+
const milliWait = 30
2952+
const blobSize = 64 * 1024
2953+
const insertRows = 64
2954+
2955+
largeBlob := strings.Repeat(blob, blobSize/len(blob))
2956+
2957+
runTests(t, dsn, func(dbt *DBTest) {
2958+
dbt.mustExec("CREATE TABLE test (id int, value MEDIUMBLOB) CHARACTER SET utf8")
2959+
for i := 0; i < insertRows; i++ {
2960+
dbt.mustExec("INSERT INTO test VALUES (?, ?)", i+1, largeBlob)
2961+
}
2962+
2963+
sem := make(chan bool, 20)
2964+
var wg sync.WaitGroup
2965+
for i := 0; i < contextRaceIterations; i++ {
2966+
sem <- true
2967+
wg.Add(1)
2968+
go func() {
2969+
defer func() {
2970+
<-sem
2971+
wg.Done()
2972+
}()
2973+
qwait := float64(time.Duration(rand.Intn(milliWait))*time.Millisecond) / float64(time.Second)
2974+
2975+
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(rand.Intn(milliWait))*time.Millisecond)
2976+
defer cancel()
2977+
2978+
rows, _ := dbt.db.QueryContext(ctx, sqlQuery, qwait)
2979+
if rows != nil {
2980+
var b int
2981+
var n string
2982+
for rows.Next() {
2983+
if rows.Scan(&b, &n) == nil {
2984+
if len(n) != blobSize {
2985+
t.Fatal("mismatch in read buffer")
2986+
}
2987+
}
2988+
}
2989+
rows.Close()
2990+
}
2991+
}()
2992+
}
2993+
wg.Wait()
2994+
})
2995+
}

0 commit comments

Comments
 (0)