diff --git a/AUTHORS b/AUTHORS index fbe4ec442..3f322e802 100644 --- a/AUTHORS +++ b/AUTHORS @@ -35,6 +35,7 @@ Hanno Braun Henri Yandell Hirotaka Yamamoto ICHINOSE Shogo +Ilia Cimpoes INADA Naoki Jacek Szwec James Harr diff --git a/connection.go b/connection.go index f74235519..12a382268 100644 --- a/connection.go +++ b/connection.go @@ -475,7 +475,7 @@ func (mc *mysqlConn) Ping(ctx context.Context) (err error) { defer mc.finish() if err = mc.writeCommandPacket(comPing); err != nil { - return + return mc.markBadConn(err) } return mc.readResultOK() diff --git a/connection_test.go b/connection_test.go index 352c54ed7..2a1c8e888 100644 --- a/connection_test.go +++ b/connection_test.go @@ -11,6 +11,8 @@ package mysql import ( "context" "database/sql/driver" + "errors" + "net" "testing" ) @@ -108,3 +110,48 @@ func TestCleanCancel(t *testing.T) { } } } + +func TestPingMarkBadConnection(t *testing.T) { + nc := badConnection{err: errors.New("boom")} + ms := &mysqlConn{ + netConn: nc, + buf: newBuffer(nc), + maxAllowedPacket: defaultMaxAllowedPacket, + } + + err := ms.Ping(context.Background()) + + if err != driver.ErrBadConn { + t.Errorf("expected driver.ErrBadConn, got %#v", err) + } +} + +func TestPingErrInvalidConn(t *testing.T) { + nc := badConnection{err: errors.New("failed to write"), n: 10} + ms := &mysqlConn{ + netConn: nc, + buf: newBuffer(nc), + maxAllowedPacket: defaultMaxAllowedPacket, + closech: make(chan struct{}), + } + + err := ms.Ping(context.Background()) + + if err != ErrInvalidConn { + t.Errorf("expected ErrInvalidConn, got %#v", err) + } +} + +type badConnection struct { + n int + err error + net.Conn +} + +func (bc badConnection) Write(b []byte) (n int, err error) { + return bc.n, bc.err +} + +func (bc badConnection) Close() error { + return nil +}