diff --git a/sentinel.go b/sentinel.go index a132af2fe..063866353 100644 --- a/sentinel.go +++ b/sentinel.go @@ -566,29 +566,60 @@ func (c *sentinelFailover) MasterAddr(ctx context.Context) (string, error) { } } - for i, sentinelAddr := range c.sentinelAddrs { - sentinel := NewSentinelClient(c.opt.sentinelOptions(sentinelAddr)) + var ( + masterAddr string + wg sync.WaitGroup + once sync.Once + errCh = make(chan error, len(c.sentinelAddrs)) + ) - masterAddr, err := sentinel.GetMasterAddrByName(ctx, c.opt.MasterName).Result() - if err != nil { - _ = sentinel.Close() - if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { - return "", err - } - internal.Logger.Printf(ctx, "sentinel: GetMasterAddrByName master=%q failed: %s", - c.opt.MasterName, err) - continue - } + ctx, cancel := context.WithCancel(ctx) + defer cancel() - // Push working sentinel to the top. - c.sentinelAddrs[0], c.sentinelAddrs[i] = c.sentinelAddrs[i], c.sentinelAddrs[0] - c.setSentinel(ctx, sentinel) + for i, sentinelAddr := range c.sentinelAddrs { + wg.Add(1) + go func(i int, addr string) { + defer wg.Done() + sentinelCli := NewSentinelClient(c.opt.sentinelOptions(addr)) + addrVal, err := sentinelCli.GetMasterAddrByName(ctx, c.opt.MasterName).Result() + if err != nil { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + // Report immediately and return + errCh <- err + return + } + internal.Logger.Printf(ctx, "sentinel: GetMasterAddrByName addr=%s, master=%q failed: %s", + addr, c.opt.MasterName, err) + _ = sentinelCli.Close() + return + } - addr := net.JoinHostPort(masterAddr[0], masterAddr[1]) - return addr, nil + once.Do(func() { + masterAddr = net.JoinHostPort(addrVal[0], addrVal[1]) + // Push working sentinel to the top + c.sentinelAddrs[0], c.sentinelAddrs[i] = c.sentinelAddrs[i], c.sentinelAddrs[0] + c.setSentinel(ctx, sentinelCli) + internal.Logger.Printf(ctx, "sentinel: selected addr=%s masterAddr=%s", addr, masterAddr) + cancel() + }) + }(i, sentinelAddr) } - return "", errors.New("redis: all sentinels specified in configuration are unreachable") + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + select { + case <-done: + if masterAddr != "" { + return masterAddr, nil + } + return "", errors.New("redis: all sentinels specified in configuration are unreachable") + case err := <-errCh: + return "", err + } } func (c *sentinelFailover) replicaAddrs(ctx context.Context, useDisconnected bool) ([]string, error) { diff --git a/sentinel_test.go b/sentinel_test.go index 07c7628a0..cde7f956d 100644 --- a/sentinel_test.go +++ b/sentinel_test.go @@ -3,6 +3,7 @@ package redis_test import ( "context" "net" + "time" . "github.com/bsm/ginkgo/v2" . "github.com/bsm/gomega" @@ -32,6 +33,24 @@ var _ = Describe("Sentinel PROTO 2", func() { }) }) +var _ = Describe("Sentinel resolution", func() { + It("should resolve master without context exhaustion", func() { + shortCtx, cancel := context.WithTimeout(ctx, 500*time.Millisecond) + defer cancel() + + client := redis.NewFailoverClient(&redis.FailoverOptions{ + MasterName: sentinelName, + SentinelAddrs: sentinelAddrs, + MaxRetries: -1, + }) + + err := client.Ping(shortCtx).Err() + Expect(err).NotTo(HaveOccurred(), "expected master to resolve without context exhaustion") + + _ = client.Close() + }) +}) + var _ = Describe("Sentinel", func() { var client *redis.Client var master *redis.Client