@@ -24,6 +24,8 @@ type testTracer struct {
24
24
traceConnectEnd func (ctx context.Context , data pgx.TraceConnectEndData )
25
25
}
26
26
27
+ type ctxKey string
28
+
27
29
func (tt * testTracer ) TraceQueryStart (ctx context.Context , conn * pgx.Conn , data pgx.TraceQueryStartData ) context.Context {
28
30
if tt .traceQueryStart != nil {
29
31
return tt .traceQueryStart (ctx , conn , data )
@@ -117,13 +119,13 @@ func TestTraceExec(t *testing.T) {
117
119
require .Equal (t , `select $1::text` , data .SQL )
118
120
require .Len (t , data .Args , 1 )
119
121
require .Equal (t , `testing` , data .Args [0 ])
120
- return context .WithValue (ctx , "fromTraceQueryStart" , "foo" )
122
+ return context .WithValue (ctx , ctxKey ( ctxKey ( "fromTraceQueryStart" )) , "foo" )
121
123
}
122
124
123
125
traceQueryEndCalled := false
124
126
tracer .traceQueryEnd = func (ctx context.Context , conn * pgx.Conn , data pgx.TraceQueryEndData ) {
125
127
traceQueryEndCalled = true
126
- require .Equal (t , "foo" , ctx .Value ("fromTraceQueryStart" ))
128
+ require .Equal (t , "foo" , ctx .Value (ctxKey ( ctxKey ( "fromTraceQueryStart" )) ))
127
129
require .Equal (t , `SELECT 1` , data .CommandTag .String ())
128
130
require .NoError (t , data .Err )
129
131
}
@@ -157,13 +159,13 @@ func TestTraceQuery(t *testing.T) {
157
159
require .Equal (t , `select $1::text` , data .SQL )
158
160
require .Len (t , data .Args , 1 )
159
161
require .Equal (t , `testing` , data .Args [0 ])
160
- return context .WithValue (ctx , "fromTraceQueryStart" , "foo" )
162
+ return context .WithValue (ctx , ctxKey ( "fromTraceQueryStart" ) , "foo" )
161
163
}
162
164
163
165
traceQueryEndCalled := false
164
166
tracer .traceQueryEnd = func (ctx context.Context , conn * pgx.Conn , data pgx.TraceQueryEndData ) {
165
167
traceQueryEndCalled = true
166
- require .Equal (t , "foo" , ctx .Value ("fromTraceQueryStart" ))
168
+ require .Equal (t , "foo" , ctx .Value (ctxKey ( "fromTraceQueryStart" ) ))
167
169
require .Equal (t , `SELECT 1` , data .CommandTag .String ())
168
170
require .NoError (t , data .Err )
169
171
}
@@ -198,20 +200,20 @@ func TestTraceBatchNormal(t *testing.T) {
198
200
traceBatchStartCalled = true
199
201
require .NotNil (t , data .Batch )
200
202
require .Equal (t , 2 , data .Batch .Len ())
201
- return context .WithValue (ctx , "fromTraceBatchStart" , "foo" )
203
+ return context .WithValue (ctx , ctxKey ( "fromTraceBatchStart" ) , "foo" )
202
204
}
203
205
204
206
traceBatchQueryCalledCount := 0
205
207
tracer .traceBatchQuery = func (ctx context.Context , conn * pgx.Conn , data pgx.TraceBatchQueryData ) {
206
208
traceBatchQueryCalledCount ++
207
- require .Equal (t , "foo" , ctx .Value ("fromTraceBatchStart" ))
209
+ require .Equal (t , "foo" , ctx .Value (ctxKey ( "fromTraceBatchStart" ) ))
208
210
require .NoError (t , data .Err )
209
211
}
210
212
211
213
traceBatchEndCalled := false
212
214
tracer .traceBatchEnd = func (ctx context.Context , conn * pgx.Conn , data pgx.TraceBatchEndData ) {
213
215
traceBatchEndCalled = true
214
- require .Equal (t , "foo" , ctx .Value ("fromTraceBatchStart" ))
216
+ require .Equal (t , "foo" , ctx .Value (ctxKey ( "fromTraceBatchStart" ) ))
215
217
require .NoError (t , data .Err )
216
218
}
217
219
@@ -261,20 +263,20 @@ func TestTraceBatchClose(t *testing.T) {
261
263
traceBatchStartCalled = true
262
264
require .NotNil (t , data .Batch )
263
265
require .Equal (t , 2 , data .Batch .Len ())
264
- return context .WithValue (ctx , "fromTraceBatchStart" , "foo" )
266
+ return context .WithValue (ctx , ctxKey ( "fromTraceBatchStart" ) , "foo" )
265
267
}
266
268
267
269
traceBatchQueryCalledCount := 0
268
270
tracer .traceBatchQuery = func (ctx context.Context , conn * pgx.Conn , data pgx.TraceBatchQueryData ) {
269
271
traceBatchQueryCalledCount ++
270
- require .Equal (t , "foo" , ctx .Value ("fromTraceBatchStart" ))
272
+ require .Equal (t , "foo" , ctx .Value (ctxKey ( "fromTraceBatchStart" ) ))
271
273
require .NoError (t , data .Err )
272
274
}
273
275
274
276
traceBatchEndCalled := false
275
277
tracer .traceBatchEnd = func (ctx context.Context , conn * pgx.Conn , data pgx.TraceBatchEndData ) {
276
278
traceBatchEndCalled = true
277
- require .Equal (t , "foo" , ctx .Value ("fromTraceBatchStart" ))
279
+ require .Equal (t , "foo" , ctx .Value (ctxKey ( "fromTraceBatchStart" ) ))
278
280
require .NoError (t , data .Err )
279
281
}
280
282
@@ -312,13 +314,13 @@ func TestTraceBatchErrorWhileReadingResults(t *testing.T) {
312
314
traceBatchStartCalled = true
313
315
require .NotNil (t , data .Batch )
314
316
require .Equal (t , 3 , data .Batch .Len ())
315
- return context .WithValue (ctx , "fromTraceBatchStart" , "foo" )
317
+ return context .WithValue (ctx , ctxKey ( "fromTraceBatchStart" ) , "foo" )
316
318
}
317
319
318
320
traceBatchQueryCalledCount := 0
319
321
tracer .traceBatchQuery = func (ctx context.Context , conn * pgx.Conn , data pgx.TraceBatchQueryData ) {
320
322
traceBatchQueryCalledCount ++
321
- require .Equal (t , "foo" , ctx .Value ("fromTraceBatchStart" ))
323
+ require .Equal (t , "foo" , ctx .Value (ctxKey ( "fromTraceBatchStart" ) ))
322
324
if traceBatchQueryCalledCount == 2 {
323
325
require .Error (t , data .Err )
324
326
} else {
@@ -329,7 +331,7 @@ func TestTraceBatchErrorWhileReadingResults(t *testing.T) {
329
331
traceBatchEndCalled := false
330
332
tracer .traceBatchEnd = func (ctx context.Context , conn * pgx.Conn , data pgx.TraceBatchEndData ) {
331
333
traceBatchEndCalled = true
332
- require .Equal (t , "foo" , ctx .Value ("fromTraceBatchStart" ))
334
+ require .Equal (t , "foo" , ctx .Value (ctxKey ( "fromTraceBatchStart" ) ))
333
335
require .Error (t , data .Err )
334
336
}
335
337
@@ -381,13 +383,13 @@ func TestTraceBatchErrorWhileReadingResultsWhileClosing(t *testing.T) {
381
383
traceBatchStartCalled = true
382
384
require .NotNil (t , data .Batch )
383
385
require .Equal (t , 3 , data .Batch .Len ())
384
- return context .WithValue (ctx , "fromTraceBatchStart" , "foo" )
386
+ return context .WithValue (ctx , ctxKey ( "fromTraceBatchStart" ) , "foo" )
385
387
}
386
388
387
389
traceBatchQueryCalledCount := 0
388
390
tracer .traceBatchQuery = func (ctx context.Context , conn * pgx.Conn , data pgx.TraceBatchQueryData ) {
389
391
traceBatchQueryCalledCount ++
390
- require .Equal (t , "foo" , ctx .Value ("fromTraceBatchStart" ))
392
+ require .Equal (t , "foo" , ctx .Value (ctxKey ( "fromTraceBatchStart" ) ))
391
393
if traceBatchQueryCalledCount == 2 {
392
394
require .Error (t , data .Err )
393
395
} else {
@@ -398,7 +400,7 @@ func TestTraceBatchErrorWhileReadingResultsWhileClosing(t *testing.T) {
398
400
traceBatchEndCalled := false
399
401
tracer .traceBatchEnd = func (ctx context.Context , conn * pgx.Conn , data pgx.TraceBatchEndData ) {
400
402
traceBatchEndCalled = true
401
- require .Equal (t , "foo" , ctx .Value ("fromTraceBatchStart" ))
403
+ require .Equal (t , "foo" , ctx .Value (ctxKey ( "fromTraceBatchStart" ) ))
402
404
require .Error (t , data .Err )
403
405
}
404
406
@@ -440,13 +442,13 @@ func TestTraceCopyFrom(t *testing.T) {
440
442
traceCopyFromStartCalled = true
441
443
require .Equal (t , pgx.Identifier {"foo" }, data .TableName )
442
444
require .Equal (t , []string {"a" }, data .ColumnNames )
443
- return context .WithValue (ctx , "fromTraceCopyFromStart" , "foo" )
445
+ return context .WithValue (ctx , ctxKey ( "fromTraceCopyFromStart" ) , "foo" )
444
446
}
445
447
446
448
traceCopyFromEndCalled := false
447
449
tracer .traceCopyFromEnd = func (ctx context.Context , conn * pgx.Conn , data pgx.TraceCopyFromEndData ) {
448
450
traceCopyFromEndCalled = true
449
- require .Equal (t , "foo" , ctx .Value ("fromTraceCopyFromStart" ))
451
+ require .Equal (t , "foo" , ctx .Value (ctxKey ( "fromTraceCopyFromStart" ) ))
450
452
require .Equal (t , `COPY 2` , data .CommandTag .String ())
451
453
require .NoError (t , data .Err )
452
454
}
@@ -488,7 +490,7 @@ func TestTracePrepare(t *testing.T) {
488
490
tracePrepareStartCalled = true
489
491
require .Equal (t , `ps` , data .Name )
490
492
require .Equal (t , `select $1::text` , data .SQL )
491
- return context .WithValue (ctx , "fromTracePrepareStart" , "foo" )
493
+ return context .WithValue (ctx , ctxKey ( "fromTracePrepareStart" ) , "foo" )
492
494
}
493
495
494
496
tracePrepareEndCalled := false
@@ -530,7 +532,7 @@ func TestTraceConnect(t *testing.T) {
530
532
tracer .traceConnectStart = func (ctx context.Context , data pgx.TraceConnectStartData ) context.Context {
531
533
traceConnectStartCalled = true
532
534
require .NotNil (t , data .ConnConfig )
533
- return context .WithValue (ctx , "fromTraceConnectStart" , "foo" )
535
+ return context .WithValue (ctx , ctxKey ( "fromTraceConnectStart" ) , "foo" )
534
536
}
535
537
536
538
traceConnectEndCalled := false
0 commit comments