Skip to content

feat: add lock for hooks.dial #2460

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -838,7 +838,7 @@ type ClusterClient struct {
state *clusterStateHolder
cmdsInfoCache *cmdsInfoCache
cmdable
hooksMixin
*hooksMixin
}

// NewClusterClient returns a Redis Cluster client as described in
Expand All @@ -847,8 +847,9 @@ func NewClusterClient(opt *ClusterOptions) *ClusterClient {
opt.init()

c := &ClusterClient{
opt: opt,
nodes: newClusterNodes(opt),
opt: opt,
nodes: newClusterNodes(opt),
hooksMixin: &hooksMixin{},
}

c.state = newClusterStateHolder(c.loadState)
Expand Down
30 changes: 25 additions & 5 deletions redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"net"
"strings"
"sync"
"sync/atomic"
"time"

Expand Down Expand Up @@ -44,6 +45,8 @@ type hooksMixin struct {
slice []Hook
initial hooks
current hooks

hooksMu sync.RWMutex
}

func (hs *hooksMixin) initHooks(hooks hooks) {
Expand Down Expand Up @@ -117,6 +120,9 @@ func (hs *hooksMixin) AddHook(hook Hook) {
func (hs *hooksMixin) chain() {
hs.initial.setDefaults()

hs.hooksMu.Lock()
defer hs.hooksMu.Unlock()

hs.current.dial = hs.initial.dial
hs.current.process = hs.initial.process
hs.current.pipeline = hs.initial.pipeline
Expand All @@ -138,8 +144,15 @@ func (hs *hooksMixin) chain() {
}
}

func (hs *hooksMixin) clone() hooksMixin {
clone := *hs
func (hs *hooksMixin) clone() *hooksMixin {
hs.hooksMu.Lock()
defer hs.hooksMu.Unlock()

clone := &hooksMixin{
slice: hs.slice,
initial: hs.initial,
current: hs.current,
}
l := len(clone.slice)
clone.slice = clone.slice[:l:l]
return clone
Expand All @@ -166,7 +179,11 @@ func (hs *hooksMixin) withProcessPipelineHook(
}

func (hs *hooksMixin) dialHook(ctx context.Context, network, addr string) (net.Conn, error) {
return hs.current.dial(ctx, network, addr)
hs.hooksMu.RLock()
conn, err := hs.current.dial(ctx, network, addr)
hs.hooksMu.RUnlock()

return conn, err
}

func (hs *hooksMixin) processHook(ctx context.Context, cmd Cmder) error {
Expand Down Expand Up @@ -588,8 +605,8 @@ func (c *baseClient) context(ctx context.Context) context.Context {
// of idle connections. You can control the pool size with Config.PoolSize option.
type Client struct {
*baseClient
*hooksMixin
cmdable
hooksMixin
}

// NewClient returns a client to the Redis Server specified by Options.
Expand All @@ -600,6 +617,7 @@ func NewClient(opt *Options) *Client {
baseClient: &baseClient{
opt: opt,
},
hooksMixin: &hooksMixin{},
}
c.init()
c.connPool = newConnPool(opt, c.dialHook)
Expand All @@ -620,6 +638,7 @@ func (c *Client) init() {
func (c *Client) WithTimeout(timeout time.Duration) *Client {
clone := *c
clone.baseClient = c.baseClient.withTimeout(timeout)
clone.hooksMixin = c.hooksMixin.clone()
clone.init()
return &clone
}
Expand Down Expand Up @@ -758,7 +777,7 @@ type Conn struct {
baseClient
cmdable
statefulCmdable
hooksMixin
*hooksMixin
}

func newConn(opt *Options, connPool pool.Pooler) *Conn {
Expand All @@ -767,6 +786,7 @@ func newConn(opt *Options, connPool pool.Pooler) *Conn {
opt: opt,
connPool: connPool,
},
hooksMixin: &hooksMixin{},
}

c.cmdable = c.Process
Expand Down
3 changes: 2 additions & 1 deletion ring.go
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,7 @@ func (c *ringSharding) Close() error {
// Otherwise you should use Redis Cluster.
type Ring struct {
cmdable
hooksMixin
*hooksMixin

opt *RingOptions
sharding *ringSharding
Expand All @@ -504,6 +504,7 @@ func NewRing(opt *RingOptions) *Ring {
opt: opt,
sharding: newRingSharding(opt),
heartbeatCancelFn: hbCancel,
hooksMixin: &hooksMixin{},
}

ring.cmdsInfoCache = newCmdsInfoCache(ring.cmdsInfo)
Expand Down
4 changes: 3 additions & 1 deletion sentinel.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ func NewFailoverClient(failoverOpt *FailoverOptions) *Client {
baseClient: &baseClient{
opt: opt,
},
hooksMixin: &hooksMixin{},
}
rdb.init()

Expand Down Expand Up @@ -267,7 +268,7 @@ func masterReplicaDialer(
// SentinelClient is a client for a Redis Sentinel.
type SentinelClient struct {
*baseClient
hooksMixin
*hooksMixin
}

func NewSentinelClient(opt *Options) *SentinelClient {
Expand All @@ -276,6 +277,7 @@ func NewSentinelClient(opt *Options) *SentinelClient {
baseClient: &baseClient{
opt: opt,
},
hooksMixin: &hooksMixin{},
}

c.initHooks(hooks{
Expand Down
2 changes: 1 addition & 1 deletion tx.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ type Tx struct {
baseClient
cmdable
statefulCmdable
hooksMixin
*hooksMixin
}

func (c *Client) newTx() *Tx {
Expand Down