Skip to content

Commit 7181d41

Browse files
committed
use ErrBadConn only if nothing is written
1 parent e2efca3 commit 7181d41

File tree

5 files changed

+38
-3
lines changed

5 files changed

+38
-3
lines changed

connection.go

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,9 @@ func (mc *mysqlConn) Begin() (driver.Tx, error) {
107107
if err == nil {
108108
return &mysqlTx{mc}, err
109109
}
110-
110+
if err == errNoWrite {
111+
return driver.ErrBadConn
112+
}
111113
return nil, err
112114
}
113115

@@ -137,6 +139,9 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
137139
// Send command
138140
err := mc.writeCommandPacketStr(comStmtPrepare, query)
139141
if err != nil {
142+
if err == errNoWrite {
143+
return nil, driver.ErrBadConn
144+
}
140145
return nil, err
141146
}
142147

@@ -177,6 +182,9 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err
177182
insertId: int64(mc.insertId),
178183
}, err
179184
}
185+
if err == errNoWrite {
186+
return nil, driver.ErrBadConn
187+
}
180188
return nil, err
181189
}
182190

@@ -231,6 +239,9 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro
231239
return rows, err
232240
}
233241
}
242+
if err == errNoWrite {
243+
return nil, driver.ErrBadConn
244+
}
234245
return nil, err
235246
}
236247

@@ -243,6 +254,9 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro
243254
func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
244255
// Send command
245256
if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil {
257+
if err == errNoWrite {
258+
return nil, driver.ErrBadConn
259+
}
246260
return nil, err
247261
}
248262

driver.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,9 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
118118
maxap, err := mc.getSystemVar("max_allowed_packet")
119119
if err != nil {
120120
mc.Close()
121+
if err == driver.ErrBadConn {
122+
err = ErrInvalidConn
123+
}
121124
return nil, err
122125
}
123126
mc.maxPacketAllowed = stringToInt(maxap) - 1

errors.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ var (
2828
ErrPktSyncMul = errors.New("Commands out of sync. Did you run multiple statements at once?")
2929
ErrPktTooLarge = errors.New("Packet for query is too large. You can change this value on the server by adjusting the 'max_allowed_packet' variable.")
3030
ErrBusyBuffer = errors.New("Busy buffer")
31+
errNoWrite = errors.New("No data written")
3132
)
3233

3334
var errLog Logger = log.New(os.Stderr, "[MySQL] ", log.Ldate|log.Ltime|log.Lshortfile)

packets.go

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,8 +114,12 @@ func (mc *mysqlConn) writePacket(data []byte) error {
114114
// Handle error
115115
if err == nil { // n != len(data)
116116
errLog.Print(ErrMalformPkt)
117-
} else {
118-
errLog.Print(err)
117+
return ErrMalformPkt
118+
}
119+
errLog.Print(err)
120+
if n == 0 && pktLen == len(data)-4 {
121+
// first loop iteration, nothing was written
122+
return errNoWrite
119123
}
120124
return ErrInvalidConn
121125
}
@@ -731,6 +735,10 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error {
731735
data = data[pktLen-dataOffset:]
732736
continue
733737
}
738+
if err == errNoWrite && argLen != len(arg) {
739+
// must not relay errNoWrite after first loop iteration
740+
return ErrInvalidConn
741+
}
734742
return err
735743

736744
}

statement.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ func (stmt *mysqlStmt) Close() error {
2626
}
2727

2828
err := stmt.mc.writeCommandPacketUint32(comStmtClose, stmt.id)
29+
if err == errNoWrite {
30+
err = ErrInvalidConn
31+
}
2932
stmt.mc = nil
3033
return err
3134
}
@@ -42,6 +45,9 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
4245
// Send command
4346
err := stmt.writeExecutePacket(args)
4447
if err != nil {
48+
if err == errNoWrite {
49+
return driver.ErrBadConn
50+
}
4551
return nil, err
4652
}
4753

@@ -82,6 +88,9 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
8288
// Send command
8389
err := stmt.writeExecutePacket(args)
8490
if err != nil {
91+
if err == errNoWrite {
92+
return driver.ErrBadConn
93+
}
8594
return nil, err
8695
}
8796

0 commit comments

Comments
 (0)