diff --git a/CHANGELOG.md b/CHANGELOG.md index ad693e0f0..00c2ded38 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ Versioning](http://semver.org/spec/v2.0.0.html) except to the first release. - Support decimal type in msgpack (#96) - Support datetime type in msgpack (#118) - Prepared SQL statements (#117) +- Context support for request objects (#48) ### Changed diff --git a/connection.go b/connection.go index 6de1e9d01..6a1829837 100644 --- a/connection.go +++ b/connection.go @@ -5,6 +5,7 @@ package tarantool import ( "bufio" "bytes" + "context" "errors" "fmt" "io" @@ -125,8 +126,11 @@ type Connection struct { c net.Conn mutex sync.Mutex // Schema contains schema loaded on connection. - Schema *Schema + Schema *Schema + // requestId contains the last request ID for requests with nil context. requestId uint32 + // contextRequestId contains the last request ID for requests with context. + contextRequestId uint32 // Greeting contains first message sent by Tarantool. Greeting *Greeting @@ -143,16 +147,56 @@ type Connection struct { var _ = Connector(&Connection{}) // Check compatibility with connector interface. +type futureList struct { + first *Future + last **Future +} + +func (list *futureList) findFuture(reqid uint32, fetch bool) *Future { + root := &list.first + for { + fut := *root + if fut == nil { + return nil + } + if fut.requestId == reqid { + if fetch { + *root = fut.next + if fut.next == nil { + list.last = root + } else { + fut.next = nil + } + } + return fut + } + root = &fut.next + } +} + +func (list *futureList) addFuture(fut *Future) { + *list.last = fut + list.last = &fut.next +} + +func (list *futureList) clear(err error, conn *Connection) { + fut := list.first + list.first = nil + list.last = &list.first + for fut != nil { + fut.SetError(err) + conn.markDone(fut) + fut, fut.next = fut.next, nil + } +} + type connShard struct { - rmut sync.Mutex - requests [requestsMap]struct { - first *Future - last **Future - } - bufmut sync.Mutex - buf smallWBuf - enc *msgpack.Encoder - _pad [16]uint64 //nolint: unused,structcheck + rmut sync.Mutex + requests [requestsMap]futureList + requestsWithCtx [requestsMap]futureList + bufmut sync.Mutex + buf smallWBuf + enc *msgpack.Encoder } // Greeting is a message sent by Tarantool on connect. @@ -167,6 +211,11 @@ type Opts struct { // push messages are received. If Timeout is zero, any request can be // blocked infinitely. // Also used to setup net.TCPConn.Set(Read|Write)Deadline. + // + // Pay attention, when using contexts with request objects, + // the timeout option for Connection does not affect the lifetime + // of the request. For those purposes use context.WithTimeout() as + // the root context. Timeout time.Duration // Timeout between reconnect attempts. If Reconnect is zero, no // reconnect attempts will be made. @@ -262,12 +311,13 @@ type SslOpts struct { // and will not finish to make attempts on authorization failures. func Connect(addr string, opts Opts) (conn *Connection, err error) { conn = &Connection{ - addr: addr, - requestId: 0, - Greeting: &Greeting{}, - control: make(chan struct{}), - opts: opts, - dec: msgpack.NewDecoder(&smallBuf{}), + addr: addr, + requestId: 0, + contextRequestId: 1, + Greeting: &Greeting{}, + control: make(chan struct{}), + opts: opts, + dec: msgpack.NewDecoder(&smallBuf{}), } maxprocs := uint32(runtime.GOMAXPROCS(-1)) if conn.opts.Concurrency == 0 || conn.opts.Concurrency > maxprocs*128 { @@ -283,8 +333,11 @@ func Connect(addr string, opts Opts) (conn *Connection, err error) { conn.shard = make([]connShard, conn.opts.Concurrency) for i := range conn.shard { shard := &conn.shard[i] - for j := range shard.requests { - shard.requests[j].last = &shard.requests[j].first + requestsLists := []*[requestsMap]futureList{&shard.requests, &shard.requestsWithCtx} + for _, requests := range requestsLists { + for j := range requests { + requests[j].last = &requests[j].first + } } } @@ -387,6 +440,13 @@ func (conn *Connection) Handle() interface{} { return conn.opts.Handle } +func (conn *Connection) cancelFuture(fut *Future, err error) { + if fut = conn.fetchFuture(fut.requestId); fut != nil { + fut.SetError(err) + conn.markDone(fut) + } +} + func (conn *Connection) dial() (err error) { var connection net.Conn network := "tcp" @@ -580,15 +640,10 @@ func (conn *Connection) closeConnection(neterr error, forever bool) (err error) } for i := range conn.shard { conn.shard[i].buf.Reset() - requests := &conn.shard[i].requests - for pos := range requests { - fut := requests[pos].first - requests[pos].first = nil - requests[pos].last = &requests[pos].first - for fut != nil { - fut.SetError(neterr) - conn.markDone(fut) - fut, fut.next = fut.next, nil + requestsLists := []*[requestsMap]futureList{&conn.shard[i].requests, &conn.shard[i].requestsWithCtx} + for _, requests := range requestsLists { + for pos := range requests { + requests[pos].clear(neterr, conn) } } } @@ -721,7 +776,7 @@ func (conn *Connection) reader(r *bufio.Reader, c net.Conn) { } } -func (conn *Connection) newFuture() (fut *Future) { +func (conn *Connection) newFuture(ctx context.Context) (fut *Future) { fut = NewFuture() if conn.rlimit != nil && conn.opts.RLimitAction == RLimitDrop { select { @@ -736,7 +791,7 @@ func (conn *Connection) newFuture() (fut *Future) { return } } - fut.requestId = conn.nextRequestId() + fut.requestId = conn.nextRequestId(ctx != nil) shardn := fut.requestId & (conn.opts.Concurrency - 1) shard := &conn.shard[shardn] shard.rmut.Lock() @@ -761,11 +816,20 @@ func (conn *Connection) newFuture() (fut *Future) { return } pos := (fut.requestId / conn.opts.Concurrency) & (requestsMap - 1) - pair := &shard.requests[pos] - *pair.last = fut - pair.last = &fut.next - if conn.opts.Timeout > 0 { - fut.timeout = time.Since(epoch) + conn.opts.Timeout + if ctx != nil { + select { + case <-ctx.Done(): + fut.SetError(fmt.Errorf("context is done")) + shard.rmut.Unlock() + return + default: + } + shard.requestsWithCtx[pos].addFuture(fut) + } else { + shard.requests[pos].addFuture(fut) + if conn.opts.Timeout > 0 { + fut.timeout = time.Since(epoch) + conn.opts.Timeout + } } shard.rmut.Unlock() if conn.rlimit != nil && conn.opts.RLimitAction == RLimitWait { @@ -785,12 +849,43 @@ func (conn *Connection) newFuture() (fut *Future) { return } +// This method removes a future from the internal queue if the context +// is "done" before the response is come. Such select logic is inspired +// from this thread: https://groups.google.com/g/golang-dev/c/jX4oQEls3uk +func (conn *Connection) contextWatchdog(fut *Future, ctx context.Context) { + select { + case <-fut.done: + default: + select { + case <-ctx.Done(): + conn.cancelFuture(fut, fmt.Errorf("context is done")) + default: + select { + case <-fut.done: + case <-ctx.Done(): + conn.cancelFuture(fut, fmt.Errorf("context is done")) + } + } + } +} + func (conn *Connection) send(req Request) *Future { - fut := conn.newFuture() + fut := conn.newFuture(req.Ctx()) if fut.ready == nil { return fut } + if req.Ctx() != nil { + select { + case <-req.Ctx().Done(): + conn.cancelFuture(fut, fmt.Errorf("context is done")) + return fut + default: + } + } conn.putFuture(fut, req) + if req.Ctx() != nil { + go conn.contextWatchdog(fut, req.Ctx()) + } return fut } @@ -877,25 +972,11 @@ func (conn *Connection) fetchFuture(reqid uint32) (fut *Future) { func (conn *Connection) getFutureImp(reqid uint32, fetch bool) *Future { shard := &conn.shard[reqid&(conn.opts.Concurrency-1)] pos := (reqid / conn.opts.Concurrency) & (requestsMap - 1) - pair := &shard.requests[pos] - root := &pair.first - for { - fut := *root - if fut == nil { - return nil - } - if fut.requestId == reqid { - if fetch { - *root = fut.next - if fut.next == nil { - pair.last = root - } else { - fut.next = nil - } - } - return fut - } - root = &fut.next + // futures with even requests id belong to requests list with nil context + if reqid%2 == 0 { + return shard.requests[pos].findFuture(reqid, fetch) + } else { + return shard.requestsWithCtx[pos].findFuture(reqid, fetch) } } @@ -984,8 +1065,12 @@ func (conn *Connection) read(r io.Reader) (response []byte, err error) { return } -func (conn *Connection) nextRequestId() (requestId uint32) { - return atomic.AddUint32(&conn.requestId, 1) +func (conn *Connection) nextRequestId(context bool) (requestId uint32) { + if context { + return atomic.AddUint32(&conn.contextRequestId, 2) + } else { + return atomic.AddUint32(&conn.requestId, 2) + } } // Do performs a request asynchronously on the connection. @@ -1000,6 +1085,15 @@ func (conn *Connection) Do(req Request) *Future { return fut } } + if req.Ctx() != nil { + select { + case <-req.Ctx().Done(): + fut := NewFuture() + fut.SetError(fmt.Errorf("context is done")) + return fut + default: + } + } return conn.send(req) } diff --git a/example_test.go b/example_test.go index 65dc971a0..cd4c7874c 100644 --- a/example_test.go +++ b/example_test.go @@ -1,6 +1,7 @@ package tarantool_test import ( + "context" "fmt" "time" @@ -691,3 +692,33 @@ func ExampleConnection_NewPrepared() { fmt.Printf("Failed to prepare") } } + +// To pass contexts to request objects, use the Context() method. +// Pay attention that when using context with request objects, +// the timeout option for Connection will not affect the lifetime +// of the request. For those purposes use context.WithTimeout() as +// the root context. +func ExamplePingRequest_Context() { + conn := example_connect() + defer conn.Close() + + timeout := time.Nanosecond + + // this way you may set the common timeout for requests with context + rootCtx, cancelRoot := context.WithTimeout(context.Background(), timeout) + defer cancelRoot() + + // this context will be canceled with the root after commonTimeout + ctx, cancel := context.WithCancel(rootCtx) + defer cancel() + + req := tarantool.NewPingRequest().Context(ctx) + + // Ping a Tarantool instance to check connection. + resp, err := conn.Do(req).Get() + fmt.Println("Ping Resp", resp) + fmt.Println("Ping Error", err) + // Output: + // Ping Resp + // Ping Error context is done +} diff --git a/prepared.go b/prepared.go index 9508f0546..6a41538ed 100644 --- a/prepared.go +++ b/prepared.go @@ -1,6 +1,7 @@ package tarantool import ( + "context" "fmt" "gopkg.in/vmihailenco/msgpack.v2" @@ -58,6 +59,17 @@ func (req *PrepareRequest) Body(res SchemaResolver, enc *msgpack.Encoder) error return fillPrepare(enc, req.expr) } +// Context sets a passed context to the request. +// +// Pay attention that when using context with request objects, +// the timeout option for Connection does not affect the lifetime +// of the request. For those purposes use context.WithTimeout() as +// the root context. +func (req *PrepareRequest) Context(ctx context.Context) *PrepareRequest { + req.ctx = ctx + return req +} + // UnprepareRequest helps you to create an unprepare request object for // execution by a Connection. type UnprepareRequest struct { @@ -83,6 +95,17 @@ func (req *UnprepareRequest) Body(res SchemaResolver, enc *msgpack.Encoder) erro return fillUnprepare(enc, *req.stmt) } +// Context sets a passed context to the request. +// +// Pay attention that when using context with request objects, +// the timeout option for Connection does not affect the lifetime +// of the request. For those purposes use context.WithTimeout() as +// the root context. +func (req *UnprepareRequest) Context(ctx context.Context) *UnprepareRequest { + req.ctx = ctx + return req +} + // ExecutePreparedRequest helps you to create an execute prepared request // object for execution by a Connection. type ExecutePreparedRequest struct { @@ -117,6 +140,17 @@ func (req *ExecutePreparedRequest) Body(res SchemaResolver, enc *msgpack.Encoder return fillExecutePrepared(enc, *req.stmt, req.args) } +// Context sets a passed context to the request. +// +// Pay attention that when using context with request objects, +// the timeout option for Connection does not affect the lifetime +// of the request. For those purposes use context.WithTimeout() as +// the root context. +func (req *ExecutePreparedRequest) Context(ctx context.Context) *ExecutePreparedRequest { + req.ctx = ctx + return req +} + func fillPrepare(enc *msgpack.Encoder, expr string) error { enc.EncodeMapLen(1) enc.EncodeUint64(KeySQLText) diff --git a/request.go b/request.go index a83094145..c708b79b4 100644 --- a/request.go +++ b/request.go @@ -1,6 +1,7 @@ package tarantool import ( + "context" "errors" "reflect" "strings" @@ -537,6 +538,8 @@ type Request interface { Code() int32 // Body fills an encoder with a request body. Body(resolver SchemaResolver, enc *msgpack.Encoder) error + // Ctx returns a context of the request. + Ctx() context.Context } // ConnectedRequest is an interface that provides the info about a Connection @@ -549,6 +552,7 @@ type ConnectedRequest interface { type baseRequest struct { requestCode int32 + ctx context.Context } // Code returns a IPROTO code for the request. @@ -556,6 +560,11 @@ func (req *baseRequest) Code() int32 { return req.requestCode } +// Ctx returns a context of the request. +func (req *baseRequest) Ctx() context.Context { + return req.ctx +} + type spaceRequest struct { baseRequest space interface{} @@ -613,6 +622,17 @@ func (req *PingRequest) Body(res SchemaResolver, enc *msgpack.Encoder) error { return fillPing(enc) } +// Context sets a passed context to the request. +// +// Pay attention that when using context with request objects, +// the timeout option for Connection does not affect the lifetime +// of the request. For those purposes use context.WithTimeout() as +// the root context. +func (req *PingRequest) Context(ctx context.Context) *PingRequest { + req.ctx = ctx + return req +} + // SelectRequest allows you to create a select request object for execution // by a Connection. type SelectRequest struct { @@ -683,6 +703,17 @@ func (req *SelectRequest) Body(res SchemaResolver, enc *msgpack.Encoder) error { return fillSelect(enc, spaceNo, indexNo, req.offset, req.limit, req.iterator, req.key) } +// Context sets a passed context to the request. +// +// Pay attention that when using context with request objects, +// the timeout option for Connection does not affect the lifetime +// of the request. For those purposes use context.WithTimeout() as +// the root context. +func (req *SelectRequest) Context(ctx context.Context) *SelectRequest { + req.ctx = ctx + return req +} + // InsertRequest helps you to create an insert request object for execution // by a Connection. type InsertRequest struct { @@ -716,6 +747,17 @@ func (req *InsertRequest) Body(res SchemaResolver, enc *msgpack.Encoder) error { return fillInsert(enc, spaceNo, req.tuple) } +// Context sets a passed context to the request. +// +// Pay attention that when using context with request objects, +// the timeout option for Connection does not affect the lifetime +// of the request. For those purposes use context.WithTimeout() as +// the root context. +func (req *InsertRequest) Context(ctx context.Context) *InsertRequest { + req.ctx = ctx + return req +} + // ReplaceRequest helps you to create a replace request object for execution // by a Connection. type ReplaceRequest struct { @@ -749,6 +791,17 @@ func (req *ReplaceRequest) Body(res SchemaResolver, enc *msgpack.Encoder) error return fillInsert(enc, spaceNo, req.tuple) } +// Context sets a passed context to the request. +// +// Pay attention that when using context with request objects, +// the timeout option for Connection does not affect the lifetime +// of the request. For those purposes use context.WithTimeout() as +// the root context. +func (req *ReplaceRequest) Context(ctx context.Context) *ReplaceRequest { + req.ctx = ctx + return req +} + // DeleteRequest helps you to create a delete request object for execution // by a Connection. type DeleteRequest struct { @@ -789,6 +842,17 @@ func (req *DeleteRequest) Body(res SchemaResolver, enc *msgpack.Encoder) error { return fillDelete(enc, spaceNo, indexNo, req.key) } +// Context sets a passed context to the request. +// +// Pay attention that when using context with request objects, +// the timeout option for Connection does not affect the lifetime +// of the request. For those purposes use context.WithTimeout() as +// the root context. +func (req *DeleteRequest) Context(ctx context.Context) *DeleteRequest { + req.ctx = ctx + return req +} + // UpdateRequest helps you to create an update request object for execution // by a Connection. type UpdateRequest struct { @@ -840,6 +904,17 @@ func (req *UpdateRequest) Body(res SchemaResolver, enc *msgpack.Encoder) error { return fillUpdate(enc, spaceNo, indexNo, req.key, req.ops) } +// Context sets a passed context to the request. +// +// Pay attention that when using context with request objects, +// the timeout option for Connection does not affect the lifetime +// of the request. For those purposes use context.WithTimeout() as +// the root context. +func (req *UpdateRequest) Context(ctx context.Context) *UpdateRequest { + req.ctx = ctx + return req +} + // UpsertRequest helps you to create an upsert request object for execution // by a Connection. type UpsertRequest struct { @@ -884,6 +959,17 @@ func (req *UpsertRequest) Body(res SchemaResolver, enc *msgpack.Encoder) error { return fillUpsert(enc, spaceNo, req.tuple, req.ops) } +// Context sets a passed context to the request. +// +// Pay attention that when using context with request objects, +// the timeout option for Connection does not affect the lifetime +// of the request. For those purposes use context.WithTimeout() as +// the root context. +func (req *UpsertRequest) Context(ctx context.Context) *UpsertRequest { + req.ctx = ctx + return req +} + // CallRequest helps you to create a call request object for execution // by a Connection. type CallRequest struct { @@ -915,6 +1001,17 @@ func (req *CallRequest) Body(res SchemaResolver, enc *msgpack.Encoder) error { return fillCall(enc, req.function, req.args) } +// Context sets a passed context to the request. +// +// Pay attention that when using context with request objects, +// the timeout option for Connection does not affect the lifetime +// of the request. For those purposes use context.WithTimeout() as +// the root context. +func (req *CallRequest) Context(ctx context.Context) *CallRequest { + req.ctx = ctx + return req +} + // NewCall16Request returns a new empty Call16Request. It uses request code for // Tarantool 1.6. // Deprecated since Tarantool 1.7.2. @@ -961,6 +1058,17 @@ func (req *EvalRequest) Body(res SchemaResolver, enc *msgpack.Encoder) error { return fillEval(enc, req.expr, req.args) } +// Context sets a passed context to the request. +// +// Pay attention that when using context with request objects, +// the timeout option for Connection does not affect the lifetime +// of the request. For those purposes use context.WithTimeout() as +// the root context. +func (req *EvalRequest) Context(ctx context.Context) *EvalRequest { + req.ctx = ctx + return req +} + // ExecuteRequest helps you to create an execute request object for execution // by a Connection. type ExecuteRequest struct { @@ -989,3 +1097,14 @@ func (req *ExecuteRequest) Args(args interface{}) *ExecuteRequest { func (req *ExecuteRequest) Body(res SchemaResolver, enc *msgpack.Encoder) error { return fillExecute(enc, req.expr, req.args) } + +// Context sets a passed context to the request. +// +// Pay attention that when using context with request objects, +// the timeout option for Connection does not affect the lifetime +// of the request. For those purposes use context.WithTimeout() as +// the root context. +func (req *ExecuteRequest) Context(ctx context.Context) *ExecuteRequest { + req.ctx = ctx + return req +} diff --git a/tarantool_test.go b/tarantool_test.go index 06771338c..f5360ba6b 100644 --- a/tarantool_test.go +++ b/tarantool_test.go @@ -1,10 +1,12 @@ package tarantool_test import ( + "context" "fmt" "log" "os" "reflect" + "runtime" "strings" "sync" "testing" @@ -100,16 +102,45 @@ func BenchmarkClientSerialRequestObject(b *testing.B) { if err != nil { b.Error(err) } + req := NewSelectRequest(spaceNo). + Index(indexNo). + Offset(0). + Limit(1). + Iterator(IterEq). + Key([]interface{}{uint(1111)}) b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := conn.Do(req).Get() + if err != nil { + b.Error(err) + } + } +} + +func BenchmarkClientSerialRequestObjectWithContext(b *testing.B) { + var err error + + conn := test_helpers.ConnectWithValidation(b, server, opts) + defer conn.Close() + + _, err = conn.Replace(spaceNo, []interface{}{uint(1111), "hello", "world"}) + if err != nil { + b.Error(err) + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + b.ResetTimer() + for i := 0; i < b.N; i++ { req := NewSelectRequest(spaceNo). Index(indexNo). - Offset(0). Limit(1). Iterator(IterEq). - Key([]interface{}{uint(1111)}) + Key([]interface{}{uint(1111)}). + Context(ctx) _, err := conn.Do(req).Get() if err != nil { b.Error(err) @@ -342,6 +373,131 @@ func BenchmarkClientParallel(b *testing.B) { }) } +func benchmarkClientParallelRequestObject(multiplier int, b *testing.B) { + conn := test_helpers.ConnectWithValidation(b, server, opts) + defer conn.Close() + + _, err := conn.Replace(spaceNo, []interface{}{uint(1111), "hello", "world"}) + if err != nil { + b.Fatal("No connection available") + } + + req := NewSelectRequest(spaceNo). + Index(indexNo). + Limit(1). + Iterator(IterEq). + Key([]interface{}{uint(1111)}) + + b.SetParallelism(multiplier) + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _ = conn.Do(req) + _, err := conn.Do(req).Get() + if err != nil { + b.Error(err) + } + } + }) +} + +func benchmarkClientParallelRequestObjectWithContext(multiplier int, b *testing.B) { + conn := test_helpers.ConnectWithValidation(b, server, opts) + defer conn.Close() + + _, err := conn.Replace(spaceNo, []interface{}{uint(1111), "hello", "world"}) + if err != nil { + b.Fatal("No connection available") + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + req := NewSelectRequest(spaceNo). + Index(indexNo). + Limit(1). + Iterator(IterEq). + Key([]interface{}{uint(1111)}). + Context(ctx) + + b.SetParallelism(multiplier) + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _ = conn.Do(req) + _, err := conn.Do(req).Get() + if err != nil { + b.Error(err) + } + } + }) +} + +func benchmarkClientParallelRequestObjectMixed(multiplier int, b *testing.B) { + conn := test_helpers.ConnectWithValidation(b, server, opts) + defer conn.Close() + + _, err := conn.Replace(spaceNo, []interface{}{uint(1111), "hello", "world"}) + if err != nil { + b.Fatal("No connection available") + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + req := NewSelectRequest(spaceNo). + Index(indexNo). + Limit(1). + Iterator(IterEq). + Key([]interface{}{uint(1111)}) + + reqWithCtx := NewSelectRequest(spaceNo). + Index(indexNo). + Limit(1). + Iterator(IterEq). + Key([]interface{}{uint(1111)}). + Context(ctx) + + b.SetParallelism(multiplier) + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _ = conn.Do(req) + _, err := conn.Do(reqWithCtx).Get() + if err != nil { + b.Error(err) + } + } + }) +} + +func BenchmarkClientParallelRequestObject(b *testing.B) { + multipliers := []int{10, 50, 500, 1000} + conn := test_helpers.ConnectWithValidation(b, server, opts) + defer conn.Close() + + _, err := conn.Replace(spaceNo, []interface{}{uint(1111), "hello", "world"}) + if err != nil { + b.Fatal("No connection available") + } + + for _, m := range multipliers { + goroutinesNum := runtime.GOMAXPROCS(0) * m + + b.Run(fmt.Sprintf("Plain %d goroutines", goroutinesNum), func(b *testing.B) { + benchmarkClientParallelRequestObject(m, b) + }) + + b.Run(fmt.Sprintf("With Context %d goroutines", goroutinesNum), func(b *testing.B) { + benchmarkClientParallelRequestObjectWithContext(m, b) + }) + + b.Run(fmt.Sprintf("Mixed %d goroutines", goroutinesNum), func(b *testing.B) { + benchmarkClientParallelRequestObjectMixed(m, b) + }) + } +} + func BenchmarkClientParallelMassive(b *testing.B) { conn := test_helpers.ConnectWithValidation(b, server, opts) defer conn.Close() @@ -2081,6 +2237,59 @@ func TestClientRequestObjects(t *testing.T) { } } +func TestClientRequestObjectsWithNilContext(t *testing.T) { + conn := test_helpers.ConnectWithValidation(t, server, opts) + defer conn.Close() + req := NewPingRequest().Context(nil) //nolint + resp, err := conn.Do(req).Get() + if err != nil { + t.Fatalf("Failed to Ping: %s", err.Error()) + } + if resp == nil { + t.Fatalf("Response is nil after Ping") + } + if len(resp.Data) != 0 { + t.Errorf("Response Body len != 0") + } +} + +func TestClientRequestObjectsWithPassedCanceledContext(t *testing.T) { + conn := test_helpers.ConnectWithValidation(t, server, opts) + defer conn.Close() + + ctx, cancel := context.WithCancel(context.Background()) + req := NewPingRequest().Context(ctx) + cancel() + resp, err := conn.Do(req).Get() + if err.Error() != "context is done" { + t.Fatalf("Failed to catch an error from done context") + } + if resp != nil { + t.Fatalf("Response is not nil after the occured error") + } +} + +func TestClientRequestObjectsWithContext(t *testing.T) { + var err error + conn := test_helpers.ConnectWithValidation(t, server, opts) + defer conn.Close() + + ctx, cancel := context.WithCancel(context.Background()) + req := NewPingRequest().Context(ctx) + fut := conn.Do(req) + cancel() + resp, err := fut.Get() + if resp != nil { + t.Fatalf("response must be nil") + } + if err == nil { + t.Fatalf("catched nil error") + } + if err.Error() != "context is done" { + t.Fatalf("wrong error catched: %v", err) + } +} + func TestComplexStructs(t *testing.T) { var err error diff --git a/test_helpers/request_mock.go b/test_helpers/request_mock.go index 00674a3a7..630d57e66 100644 --- a/test_helpers/request_mock.go +++ b/test_helpers/request_mock.go @@ -1,6 +1,8 @@ package test_helpers import ( + "context" + "github.com/tarantool/go-tarantool" "gopkg.in/vmihailenco/msgpack.v2" ) @@ -23,3 +25,7 @@ func (sr *StrangerRequest) Body(resolver tarantool.SchemaResolver, enc *msgpack. func (sr *StrangerRequest) Conn() *tarantool.Connection { return &tarantool.Connection{} } + +func (sr *StrangerRequest) Ctx() context.Context { + return nil +}