diff --git a/internal/pool/export_test.go b/internal/pool/export_test.go
index de7a644ea..d1103a31e 100644
--- a/internal/pool/export_test.go
+++ b/internal/pool/export_test.go
@@ -1,7 +1,18 @@
 package pool
 
-import "time"
+import (
+	"net"
+	"time"
+)
 
 func (cn *Conn) SetCreatedAt(tm time.Time) {
 	cn.createdAt = tm
 }
+
+func (cn *Conn) NetConn() net.Conn {
+	return cn.netConn
+}
+
+func MaxBadConnRetries() int {
+	return maxBadConnRetries
+}
diff --git a/internal/pool/pool.go b/internal/pool/pool.go
index 577923a79..19df4f44b 100644
--- a/internal/pool/pool.go
+++ b/internal/pool/pool.go
@@ -11,6 +11,10 @@ import (
 	"github.com/go-redis/redis/v8/internal"
 )
 
+// maxBadConnRetries is the maximum number of attempts to obtain a usable connection
+// from the pool's idle list. Once it is exceeded, a new connection is dialed directly.
+const maxBadConnRetries = 2
+
 var (
 	// ErrClosed performs any operation on the closed client will return this error.
 	ErrClosed = errors.New("redis: client is closed")
@@ -235,7 +239,7 @@ func (p *ConnPool) Get(ctx context.Context) (*Conn, error) {
 		return nil, err
 	}
 
-	for {
+	for i := 0; i < maxBadConnRetries; i++ {
 		p.connsMu.Lock()
 		cn := p.popIdle()
 		p.connsMu.Unlock()
@@ -253,6 +257,10 @@ func (p *ConnPool) Get(ctx context.Context) (*Conn, error) {
 		return cn, nil
 	}
 
+	// We get here when the idle list is empty, or when the last maxBadConnRetries
+	// idle connections all turned out to be broken.
+	// In either case, create a new connection directly.
+
 	atomic.AddUint32(&p.stats.Misses, 1)
 
 	newcn, err := p.newConn(ctx, true)
diff --git a/internal/pool/pool_test.go b/internal/pool/pool_test.go
index 6c94fc27a..0e33781c9 100644
--- a/internal/pool/pool_test.go
+++ b/internal/pool/pool_test.go
@@ -2,6 +2,7 @@ package pool_test
 
 import (
 	"context"
+	"net"
 	"sync"
 	"testing"
 	"time"
@@ -27,7 +28,7 @@ var _ = Describe("ConnPool", func() {
 	})
 
 	AfterEach(func() {
-		connPool.Close()
+		_ = connPool.Close()
 	})
 
 	It("should unblock client when conn is removed", func() {
@@ -81,6 +82,56 @@ var _ = Describe("ConnPool", func() {
 	})
 })
 
+var _ = Describe("bad conn", func() {
+	ctx := context.Background()
+	var opt *pool.Options
+	var connPool *pool.ConnPool
+
+	BeforeEach(func() {
+		opt = &pool.Options{
+			Dialer:             dummyDialer,
+			PoolSize:           10,
+			PoolTimeout:        time.Hour,
+			IdleTimeout:        time.Minute,
+			IdleCheckFrequency: 10 * time.Second,
+		}
+		connPool = pool.NewConnPool(opt)
+	})
+
+	AfterEach(func() {
+		_ = connPool.Close()
+	})
+
+	It("should dial a new conn after maxBadConnRetries", func() {
+		var err error
+		conns := make([]*pool.Conn, opt.PoolSize)
+		for i := 0; i < opt.PoolSize; i++ {
+			conns[i], err = connPool.Get(ctx)
+			Expect(err).NotTo(HaveOccurred())
+		}
+		for i := 0; i < opt.PoolSize; i++ {
+			// Damage it: close the underlying net.Conn, then return it to the pool.
+			_ = conns[i].Close()
+			connPool.Put(ctx, conns[i])
+		}
+
+		var newConn *net.TCPConn
+		opt.Dialer = func(ctx context.Context) (net.Conn, error) {
+			newConn = &net.TCPConn{}
+			return newConn, nil
+		}
+
+		conn, err := connPool.Get(ctx)
+		Expect(err).NotTo(HaveOccurred())
+		Expect(conn.NetConn()).To(Equal(newConn))
+
+		stats := connPool.Stats()
+		Expect(stats.IdleConns).To(Equal(uint32(opt.PoolSize - pool.MaxBadConnRetries())))
+		Expect(stats.Misses).To(Equal(uint32(opt.PoolSize + 1)))
+		Expect(stats.Hits).To(Equal(uint32(0)))
+	})
+})
+
 var _ = Describe("MinIdleConns", func() {
 	const poolSize = 100
 	ctx := context.Background()
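
For review context, below is a minimal standalone sketch of the retry-then-dial control flow this patch gives `(*ConnPool).Get`. The `toyPool`/`toyConn` types, the `broken` flag, and the `dial` helper are hypothetical, invented only to illustrate the loop shape; they are not go-redis types, and the real code detects bad connections and tracks stats as shown in the diff above.

```go
package main

import "fmt"

// maxBadConnRetries mirrors the constant added in internal/pool/pool.go.
const maxBadConnRetries = 2

// toyConn and toyPool are illustrative stand-ins, not go-redis types.
type toyConn struct{ broken bool }

type toyPool struct{ idle []*toyConn }

// popIdle pops the most recently returned idle connection, or nil if empty.
func (p *toyPool) popIdle() *toyConn {
	if len(p.idle) == 0 {
		return nil
	}
	cn := p.idle[len(p.idle)-1]
	p.idle = p.idle[:len(p.idle)-1]
	return cn
}

// dial stands in for ConnPool.newConn: it always produces a healthy connection.
func (p *toyPool) dial() (*toyConn, error) { return &toyConn{}, nil }

// get shows the shape of the patched Get: a loop over the idle list bounded
// by maxBadConnRetries, falling through to a direct dial once the loop gives
// up (idle list empty, or only broken connections seen).
func (p *toyPool) get() (*toyConn, error) {
	for i := 0; i < maxBadConnRetries; i++ {
		cn := p.popIdle()
		if cn == nil {
			break // idle list exhausted
		}
		if cn.broken {
			continue // drop it and retry, at most maxBadConnRetries times
		}
		return cn, nil // pool hit
	}
	// Pool miss: dial a brand-new connection instead of draining the idle
	// list any further.
	return p.dial()
}

func main() {
	// Three broken idle conns, maxBadConnRetries == 2: two are popped and
	// discarded, one stays idle, and the returned conn is freshly dialed.
	p := &toyPool{idle: []*toyConn{{broken: true}, {broken: true}, {broken: true}}}
	cn, err := p.get()
	fmt.Println(cn.broken, len(p.idle), err) // false 1 <nil>
}
```

The final state of the sketch mirrors the assertions in the new "bad conn" test: with PoolSize broken conns put back, the fallback Get leaves PoolSize - maxBadConnRetries conns idle, records no hits, and adds one more miss on top of the PoolSize misses from the initial fills.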