Skip to content

Commit c5054e2

Browse files
committed
feat(conn): add close hook on conn
1 parent 8f05aef commit c5054e2

File tree

2 files changed

+35
-18
lines changed

2 files changed

+35
-18
lines changed

internal/pool/conn.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ type Conn struct {
2323
Inited bool
2424
pooled bool
2525
createdAt time.Time
26+
27+
onClose func() error
2628
}
2729

2830
func NewConn(netConn net.Conn) *Conn {
@@ -46,6 +48,10 @@ func (cn *Conn) SetUsedAt(tm time.Time) {
4648
atomic.StoreInt64(&cn.usedAt, tm.Unix())
4749
}
4850

51+
func (cn *Conn) SetOnClose(fn func() error) {
52+
cn.onClose = fn
53+
}
54+
4955
func (cn *Conn) SetNetConn(netConn net.Conn) {
5056
cn.netConn = netConn
5157
cn.rd.Reset(netConn)
@@ -95,6 +101,10 @@ func (cn *Conn) WithWriter(
95101
}
96102

97103
func (cn *Conn) Close() error {
104+
if cn.onClose != nil {
105+
// ignore error
106+
_ = cn.onClose()
107+
}
98108
return cn.netConn.Close()
99109
}
100110

redis.go

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -285,10 +285,14 @@ func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) {
285285
return cn, nil
286286
}
287287

288-
func (c *baseClient) newReAuthCredentialsListener(ctx context.Context, conn *Conn) auth.CredentialsListener {
288+
func (c *baseClient) newReAuthCredentialsListener(ctx context.Context, cn *pool.Conn) auth.CredentialsListener {
289+
connPool := pool.NewSingleConnPool(c.connPool, cn)
290+
// hooksMixin are intentionally empty here
291+
conn := newConn(c.opt, connPool, nil)
292+
ctx = c.context(ctx)
289293
return auth.NewReAuthCredentialsListener(
290-
c.reAuthConnection(c.context(ctx), conn),
291-
c.onAuthenticationErr(c.context(ctx), conn),
294+
c.reAuthConnection(ctx, conn),
295+
c.onAuthenticationErr(ctx, conn),
292296
)
293297
}
294298

@@ -313,11 +317,11 @@ func (c *baseClient) onAuthenticationErr(ctx context.Context, cn *Conn) func(err
313317
poolCn, getErr := cn.connPool.Get(ctx)
314318
if getErr == nil {
315319
c.connPool.Remove(ctx, poolCn, err)
316-
} else {
317-
// if we can't get the pool connection, we can only close the connection
318-
if err := cn.Close(); err != nil {
319-
log.Printf("failed to close connection: %v", err)
320-
}
320+
}
321+
322+
// if we can't get the pool connection, we can only close the connection
323+
if err := cn.Close(); err != nil {
324+
log.Printf("failed to close connection: %v", err)
321325
}
322326
}
323327
}
@@ -353,8 +357,7 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
353357
var err error
354358
cn.Inited = true
355359
connPool := pool.NewSingleConnPool(c.connPool, cn)
356-
357-
conn := newConn(c.opt, connPool, c.hooksMixin)
360+
conn := newConn(c.opt, connPool, &c.hooksMixin)
358361

359362
protocol := c.opt.Protocol
360363
// By default, use RESP3 in current version.
@@ -364,12 +367,13 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
364367

365368
username, password := "", ""
366369
if c.opt.StreamingCredentialsProvider != nil {
367-
credentials, cancelCredentialsProvider, err := c.opt.StreamingCredentialsProvider.
368-
Subscribe(c.newReAuthCredentialsListener(ctx, conn))
370+
credentials, unsubscribeFromCredentialsProvider, err := c.opt.StreamingCredentialsProvider.
371+
Subscribe(c.newReAuthCredentialsListener(ctx, cn))
369372
if err != nil {
370373
return fmt.Errorf("failed to subscribe to streaming credentials: %w", err)
371374
}
372-
c.onClose = c.wrappedOnClose(cancelCredentialsProvider)
375+
c.onClose = c.wrappedOnClose(unsubscribeFromCredentialsProvider)
376+
cn.SetOnClose(unsubscribeFromCredentialsProvider)
373377
username, password = credentials.BasicAuth()
374378
} else if c.opt.CredentialsProviderContext != nil {
375379
username, password, err = c.opt.CredentialsProviderContext(ctx)
@@ -770,7 +774,7 @@ func (c *Client) WithTimeout(timeout time.Duration) *Client {
770774
}
771775

772776
func (c *Client) Conn() *Conn {
773-
return newConn(c.opt, pool.NewStickyConnPool(c.connPool), c.hooksMixin)
777+
return newConn(c.opt, pool.NewStickyConnPool(c.connPool), &c.hooksMixin)
774778
}
775779

776780
// Do create a Cmd from the args and processes the cmd.
@@ -908,15 +912,18 @@ type Conn struct {
908912
// newConn is a helper func to create a new Conn instance.
909913
// the Conn instance is not thread-safe and should not be shared between goroutines.
910914
// the parentHooks will be cloned, no need to clone before passing it.
911-
func newConn(opt *Options, connPool pool.Pooler, parentHooks hooksMixin) *Conn {
915+
func newConn(opt *Options, connPool pool.Pooler, parentHooks *hooksMixin) *Conn {
912916
c := Conn{
913917
baseClient: baseClient{
914-
opt: opt,
915-
connPool: connPool,
916-
hooksMixin: parentHooks.clone(),
918+
opt: opt,
919+
connPool: connPool,
917920
},
918921
}
919922

923+
if parentHooks != nil {
924+
c.hooksMixin = parentHooks.clone()
925+
}
926+
920927
c.cmdable = c.Process
921928
c.statefulCmdable = c.Process
922929
c.initHooks(hooks{

0 commit comments

Comments
 (0)