diff --git a/errors.go b/errors.go index 20003e086..09b3ef14e 100644 --- a/errors.go +++ b/errors.go @@ -17,6 +17,7 @@ import ( ) var ( + errInvalidConn = errors.New("Invalid Connection") errMalformPkt = errors.New("Malformed Packet") errNoTLS = errors.New("TLS encryption requested but server does not support TLS") errOldPassword = errors.New("It seems like you are using old_passwords, which is unsupported. See https://github.com/go-sql-driver/mysql/wiki/old_passwords") diff --git a/rows.go b/rows.go index fda75998b..4fef1d9d5 100644 --- a/rows.go +++ b/rows.go @@ -11,7 +11,6 @@ package mysql import ( "database/sql/driver" - "errors" "io" ) @@ -37,33 +36,30 @@ func (rows *mysqlRows) Columns() (columns []string) { } func (rows *mysqlRows) Close() (err error) { - defer func() { - rows.mc = nil - }() - // Remove unread packets from stream if !rows.eof { if rows.mc == nil { - return errors.New("Invalid Connection") + return errInvalidConn } err = rows.mc.readUntilEOF() } + rows.mc = nil + return } -func (rows *mysqlRows) Next(dest []driver.Value) error { +func (rows *mysqlRows) Next(dest []driver.Value) (err error) { if rows.eof { return io.EOF } if rows.mc == nil { - return errors.New("Invalid Connection") + return errInvalidConn } // Fetch next row from stream - var err error if rows.binary { err = rows.readBinaryRow(dest) } else {