Skip to content

Commit 5228611

Browse files
committed
fix(auth): streamline auth err proccess
1 parent a6a2c9d commit 5228611

File tree

1 file changed

+18
-23
lines changed

1 file changed

+18
-23
lines changed

redis.go

Lines changed: 18 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import (
44
"context"
55
"errors"
66
"fmt"
7-
"log"
87
"net"
98
"sync"
109
"sync/atomic"
@@ -285,21 +284,22 @@ func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) {
285284
return cn, nil
286285
}
287286

288-
func (c *baseClient) newReAuthCredentialsListener(ctx context.Context, cn *pool.Conn) auth.CredentialsListener {
289-
connPool := pool.NewSingleConnPool(c.connPool, cn)
290-
// hooksMixin are intentionally empty here
291-
conn := newConn(c.opt, connPool, nil)
292-
ctx = c.context(ctx)
287+
func (c *baseClient) newReAuthCredentialsListener(poolCn *pool.Conn) auth.CredentialsListener {
293288
return auth.NewReAuthCredentialsListener(
294-
c.reAuthConnection(ctx, conn),
295-
c.onAuthenticationErr(ctx, conn),
289+
c.reAuthConnection(poolCn),
290+
c.onAuthenticationErr(poolCn),
296291
)
297292
}
298293

299-
func (c *baseClient) reAuthConnection(ctx context.Context, cn *Conn) func(credentials auth.Credentials) error {
294+
func (c *baseClient) reAuthConnection(poolCn *pool.Conn) func(credentials auth.Credentials) error {
300295
return func(credentials auth.Credentials) error {
301296
var err error
302297
username, password := credentials.BasicAuth()
298+
ctx := context.Background()
299+
connPool := pool.NewSingleConnPool(c.connPool, poolCn)
300+
// hooksMixin are intentionally empty here
301+
cn := newConn(c.opt, connPool, nil)
302+
303303
if username != "" {
304304
err = cn.AuthACL(ctx, username, password).Err()
305305
} else {
@@ -308,22 +308,13 @@ func (c *baseClient) reAuthConnection(ctx context.Context, cn *Conn) func(creden
308308
return err
309309
}
310310
}
311-
func (c *baseClient) onAuthenticationErr(ctx context.Context, cn *Conn) func(err error) {
311+
func (c *baseClient) onAuthenticationErr(poolCn *pool.Conn) func(err error) {
312312
return func(err error) {
313-
// since the connection pool of the *Conn will actually return us the underlying pool.Conn,
314-
// we can get it from the *Conn and remove it from the clients pool.
315313
if err != nil {
316314
if isBadConn(err, false, c.opt.Addr) {
317-
poolCn, getErr := cn.connPool.Get(ctx)
318-
if getErr == nil {
319-
c.connPool.Remove(ctx, poolCn, err)
320-
} else {
321-
// if we can't get the pool connection, we can only close the connection
322-
if err := cn.Close(); err != nil {
323-
log.Printf("failed to close connection: %v", err)
324-
}
325-
}
315+
c.connPool.CloseConn(poolCn)
326316
}
317+
internal.Logger.Printf(context.Background(), "redis: re-authentication failed: %v", err)
327318
}
328319
}
329320
}
@@ -368,7 +359,7 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
368359
username, password := "", ""
369360
if c.opt.StreamingCredentialsProvider != nil {
370361
credentials, unsubscribeFromCredentialsProvider, err := c.opt.StreamingCredentialsProvider.
371-
Subscribe(c.newReAuthCredentialsListener(ctx, cn))
362+
Subscribe(c.newReAuthCredentialsListener(cn))
372363
if err != nil {
373364
return fmt.Errorf("failed to subscribe to streaming credentials: %w", err)
374365
}
@@ -401,7 +392,11 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
401392
return err
402393
} else if password != "" {
403394
// Try legacy AUTH command if HELLO failed
404-
err = c.reAuthConnection(ctx, conn)(auth.NewBasicCredentials(username, password))
395+
if username != "" {
396+
err = conn.AuthACL(ctx, username, password).Err()
397+
} else {
398+
err = conn.Auth(ctx, password).Err()
399+
}
405400
if err != nil {
406401
return fmt.Errorf("failed to authenticate: %w", err)
407402
}

0 commit comments

Comments
 (0)