diff --git a/CHANGELOG.md b/CHANGELOG.md index c839c174..00049a9c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,19 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html) except to the first release. +## [Unreleased] + +### Added + +### Changed + +- Connect() now retry the connection if a failure occurs and opts.Reconnect > 0. + The number of attempts is equal to opts.MaxReconnects or unlimited if + opts.MaxReconnects == 0. Connect() blocks until a connection is established, + the context is cancelled, or the number of attempts is exhausted (#436). + +### Fixed + ## [v2.3.0] - 2025-03-11 The release extends box.info responses and ConnectionPool.GetInfo return data. diff --git a/connection.go b/connection.go index d00b52a5..97f9dbbc 100644 --- a/connection.go +++ b/connection.go @@ -92,12 +92,24 @@ func (d defaultLogger) Report(event ConnLogKind, conn *Connection, v ...interfac case LogReconnectFailed: reconnects := v[0].(uint) err := v[1].(error) - log.Printf("tarantool: reconnect (%d/%d) to %s failed: %s", - reconnects, conn.opts.MaxReconnects, conn.Addr(), err) + addr := conn.Addr() + if addr == nil { + log.Printf("tarantool: connect (%d/%d) failed: %s", + reconnects, conn.opts.MaxReconnects, err) + } else { + log.Printf("tarantool: reconnect (%d/%d) to %s failed: %s", + reconnects, conn.opts.MaxReconnects, addr, err) + } case LogLastReconnectFailed: err := v[0].(error) - log.Printf("tarantool: last reconnect to %s failed: %s, giving it up", - conn.Addr(), err) + addr := conn.Addr() + if addr == nil { + log.Printf("tarantool: last connect failed: %s, giving it up", + err) + } else { + log.Printf("tarantool: last reconnect to %s failed: %s, giving it up", + addr, err) + } case LogUnexpectedResultId: header := v[0].(Header) log.Printf("tarantool: connection %s got unexpected request ID (%d) in response "+ @@ -362,8 +374,20 @@ func Connect(ctx context.Context, dialer Dialer, opts Opts) (conn *Connection, e conn.cond = sync.NewCond(&conn.mutex) - if err = conn.createConnection(ctx); err != nil { - return nil, err + if conn.opts.Reconnect > 0 { + // We don't need these mutex.Lock()/mutex.Unlock() here, but + // runReconnects() expects mutex.Lock() to be set, so it's + // easier to add them instead of reworking runReconnects(). + conn.mutex.Lock() + err = conn.runReconnects(ctx) + conn.mutex.Unlock() + if err != nil { + return nil, err + } + } else { + if err = conn.connect(ctx); err != nil { + return nil, err + } } go conn.pinger() @@ -553,7 +577,7 @@ func pack(h *smallWBuf, enc *msgpack.Encoder, reqid uint32, return } -func (conn *Connection) createConnection(ctx context.Context) error { +func (conn *Connection) connect(ctx context.Context) error { var err error if conn.c == nil && conn.state == connDisconnected { if err = conn.dial(ctx); err == nil { @@ -616,19 +640,30 @@ func (conn *Connection) getDialTimeout() time.Duration { return dialTimeout } -func (conn *Connection) runReconnects() error { +func (conn *Connection) runReconnects(ctx context.Context) error { dialTimeout := conn.getDialTimeout() var reconnects uint var err error + t := time.NewTicker(conn.opts.Reconnect) + defer t.Stop() for conn.opts.MaxReconnects == 0 || reconnects <= conn.opts.MaxReconnects { - now := time.Now() - - ctx, cancel := context.WithTimeout(context.Background(), dialTimeout) - err = conn.createConnection(ctx) + localCtx, cancel := context.WithTimeout(ctx, dialTimeout) + err = conn.connect(localCtx) cancel() if err != nil { + // The error will most likely be the one that Dialer + // returns to us due to the context being cancelled. + // Although this is not guaranteed. For example, + // if the dialer may throw another error before checking + // the context, and the context has already been + // canceled. Or the context was not canceled after + // the error was thrown, but before the context was + // checked here. + if ctx.Err() != nil { + return err + } if clientErr, ok := err.(ClientError); ok && clientErr.Code == ErrConnectionClosed { return err @@ -642,7 +677,12 @@ func (conn *Connection) runReconnects() error { reconnects++ conn.mutex.Unlock() - time.Sleep(time.Until(now.Add(conn.opts.Reconnect))) + select { + case <-ctx.Done(): + // Since the context is cancelled, we don't need to do anything. + // Conn.connect() will return the correct error. + case <-t.C: + } conn.mutex.Lock() } @@ -656,7 +696,7 @@ func (conn *Connection) reconnectImpl(neterr error, c Conn) { if conn.opts.Reconnect > 0 { if c == conn.c { conn.closeConnection(neterr, false) - if err := conn.runReconnects(); err != nil { + if err := conn.runReconnects(context.Background()); err != nil { conn.closeConnection(err, true) } } diff --git a/tarantool_test.go b/tarantool_test.go index 2be3e793..3f1e90ef 100644 --- a/tarantool_test.go +++ b/tarantool_test.go @@ -3972,6 +3972,86 @@ func TestConnect_context_cancel(t *testing.T) { } } +// A dialer that rejects the first few connection requests. +type mockSlowDialer struct { + counter *int + original NetDialer +} + +func (m mockSlowDialer) Dial(ctx context.Context, opts DialOpts) (Conn, error) { + *m.counter++ + if *m.counter < 10 { + return nil, fmt.Errorf("Too early: %v", *m.counter) + } + return m.original.Dial(ctx, opts) +} + +func TestConnectIsBlocked(t *testing.T) { + const server = "127.0.0.1:3015" + testDialer := dialer + testDialer.Address = server + + inst, err := test_helpers.StartTarantool(test_helpers.StartOpts{ + Dialer: testDialer, + InitScript: "config.lua", + Listen: server, + WaitStart: 100 * time.Millisecond, + ConnectRetry: 10, + RetryTimeout: 500 * time.Millisecond, + }) + defer test_helpers.StopTarantoolWithCleanup(inst) + if err != nil { + t.Fatalf("Unable to start Tarantool: %s", err) + } + + var counter int + mockDialer := mockSlowDialer{original: testDialer, counter: &counter} + ctx, cancel := context.WithTimeout(context.TODO(), 10*time.Second) + defer cancel() + + reconnectOpts := opts + reconnectOpts.Reconnect = 100 * time.Millisecond + reconnectOpts.MaxReconnects = 100 + conn, err := Connect(ctx, mockDialer, reconnectOpts) + assert.Nil(t, err) + conn.Close() + assert.GreaterOrEqual(t, counter, 10) +} + +func TestConnectIsBlockedUntilContextExpires(t *testing.T) { + const server = "127.0.0.1:3015" + + testDialer := dialer + testDialer.Address = server + + ctx, cancel := test_helpers.GetConnectContext() + defer cancel() + + reconnectOpts := opts + reconnectOpts.Reconnect = 100 * time.Millisecond + reconnectOpts.MaxReconnects = 100 + _, err := Connect(ctx, testDialer, reconnectOpts) + assert.NotNil(t, err) + assert.ErrorContains(t, err, "failed to dial: dial tcp 127.0.0.1:3015: i/o timeout") +} + +func TestConnectIsUnblockedAfterMaxAttempts(t *testing.T) { + const server = "127.0.0.1:3015" + + testDialer := dialer + testDialer.Address = server + + ctx, cancel := test_helpers.GetConnectContext() + defer cancel() + + reconnectOpts := opts + reconnectOpts.Reconnect = 100 * time.Millisecond + reconnectOpts.MaxReconnects = 1 + _, err := Connect(ctx, testDialer, reconnectOpts) + assert.NotNil(t, err) + assert.ErrorContains(t, err, "last reconnect failed") +} + func buildSidecar(dir string) error { goPath, err := exec.LookPath("go") if err != nil {