Skip to content

Commit 4acde0d

Browse files
committed
Exec() now provides access to status of multiple statements.
It now reports the last inserted ID and affected row count for all statements, not just the last one. This is useful to execute batches of statements such as UPDATE with minimal roundtrips. The approach taken is to track last insert id and affected rows using []int64 instead of a int64. Both are set in `mysqlResult`, and a new `mysql.Result` interface makes them accessible to callers calling `Exec()` via `sql.Conn.Raw`. For example: ``` conn.Raw(func(conn interface{}) error { ex := conn.(driver.Execer) res, err := ex.Exec(` UPDATE point SET x = 1 WHERE y = 2; UPDATE point SET x = 2 WHERE y = 3; `, nil) // Both slices have 2 elements. log.Print(res.(mysql.Result).AllRowsAffected()) log.Print(res.(mysql.Result).AllLastInsertIds()) }) ```
1 parent 217d050 commit 4acde0d

9 files changed

+259
-50
lines changed

README.md

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,22 @@ Allow multiple statements in one query. While this allows batch queries, it also
288288

289289
When `multiStatements` is used, `?` parameters must only be used in the first statement.
290290

291+
It's possible to access the last inserted ID and number of affected rows for multiple statements by using `sql.Conn.Raw()` and the `mysql.Result`. For example:
292+
293+
```
294+
conn, _ := db.Conn(ctx)
295+
conn.Raw(func(conn interface{}) error {
296+
ex := conn.(driver.Execer)
297+
res, err := ex.Exec(`
298+
UPDATE point SET x = 1 WHERE y = 2;
299+
UPDATE point SET x = 2 WHERE y = 3;
300+
`, nil)
301+
// Both slices have 2 elements.
302+
log.Print(res.(mysql.Result).AllRowsAffected())
303+
log.Print(res.(mysql.Result).AllLastInsertIds())
304+
})
305+
```
306+
291307
##### `parseTime`
292308

293309
```

auth.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,7 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
347347
case 1:
348348
switch authData[0] {
349349
case cachingSha2PasswordFastAuthSuccess:
350-
if err = mc.readResultOK(); err == nil {
350+
if err = mc.resultUnchanged().readResultOK(); err == nil {
351351
return nil // auth successful
352352
}
353353

@@ -398,7 +398,7 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
398398
return err
399399
}
400400
}
401-
return mc.readResultOK()
401+
return mc.resultUnchanged().readResultOK()
402402

403403
default:
404404
return ErrMalformPkt
@@ -427,7 +427,7 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
427427
if err != nil {
428428
return err
429429
}
430-
return mc.readResultOK()
430+
return mc.resultUnchanged().readResultOK()
431431
}
432432

433433
default:

connection.go

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,8 @@ import (
2323
type mysqlConn struct {
2424
buf buffer
2525
netConn net.Conn
26-
rawConn net.Conn // underlying connection when netConn is TLS connection.
27-
affectedRows uint64
28-
insertId uint64
26+
rawConn net.Conn // underlying connection when netConn is TLS connection.
27+
result mysqlResult // managed by clearResult() and handleOkPacket().
2928
cfg *Config
3029
maxAllowedPacket int
3130
maxWriteSize int
@@ -149,6 +148,7 @@ func (mc *mysqlConn) cleanup() {
149148
if err := mc.netConn.Close(); err != nil {
150149
errLog.Print(err)
151150
}
151+
mc.clearResult()
152152
}
153153

154154
func (mc *mysqlConn) error() error {
@@ -310,28 +310,25 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err
310310
}
311311
query = prepared
312312
}
313-
mc.affectedRows = 0
314-
mc.insertId = 0
315313

316314
err := mc.exec(query)
317315
if err == nil {
318-
return &mysqlResult{
319-
affectedRows: int64(mc.affectedRows),
320-
insertId: int64(mc.insertId),
321-
}, err
316+
copied := mc.result
317+
return &copied, err
322318
}
323319
return nil, mc.markBadConn(err)
324320
}
325321

326322
// Internal function to execute commands
327323
func (mc *mysqlConn) exec(query string) error {
324+
handleOk := mc.clearResult()
328325
// Send command
329326
if err := mc.writeCommandPacketStr(comQuery, query); err != nil {
330327
return mc.markBadConn(err)
331328
}
332329

333330
// Read Result
334-
resLen, err := mc.readResultSetHeaderPacket()
331+
resLen, err := handleOk.readResultSetHeaderPacket()
335332
if err != nil {
336333
return err
337334
}
@@ -348,14 +345,16 @@ func (mc *mysqlConn) exec(query string) error {
348345
}
349346
}
350347

351-
return mc.discardResults()
348+
return handleOk.discardResults()
352349
}
353350

354351
func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) {
355352
return mc.query(query, args)
356353
}
357354

358355
func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) {
356+
handleOk := mc.clearResult()
357+
359358
if mc.closed.IsSet() {
360359
errLog.Print(ErrInvalidConn)
361360
return nil, driver.ErrBadConn
@@ -376,7 +375,7 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error)
376375
if err == nil {
377376
// Read Result
378377
var resLen int
379-
resLen, err = mc.readResultSetHeaderPacket()
378+
resLen, err = handleOk.readResultSetHeaderPacket()
380379
if err == nil {
381380
rows := new(textRows)
382381
rows.mc = mc
@@ -404,12 +403,13 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error)
404403
// The returned byte slice is only valid until the next read
405404
func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
406405
// Send command
406+
handleOk := mc.clearResult()
407407
if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil {
408408
return nil, err
409409
}
410410

411411
// Read Result
412-
resLen, err := mc.readResultSetHeaderPacket()
412+
resLen, err := handleOk.readResultSetHeaderPacket()
413413
if err == nil {
414414
rows := new(textRows)
415415
rows.mc = mc
@@ -460,11 +460,12 @@ func (mc *mysqlConn) Ping(ctx context.Context) (err error) {
460460
}
461461
defer mc.finish()
462462

463+
handleOk := mc.clearResult()
463464
if err = mc.writeCommandPacket(comPing); err != nil {
464465
return mc.markBadConn(err)
465466
}
466467

467-
return mc.readResultOK()
468+
return handleOk.readResultOK()
468469
}
469470

470471
// BeginTx implements driver.ConnBeginTx interface

driver_test.go

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2155,11 +2155,51 @@ func TestRejectReadOnly(t *testing.T) {
21552155
}
21562156

21572157
func TestPing(t *testing.T) {
2158+
ctx := context.Background()
21582159
runTests(t, dsn, func(dbt *DBTest) {
21592160
if err := dbt.db.Ping(); err != nil {
21602161
dbt.fail("Ping", "Ping", err)
21612162
}
21622163
})
2164+
2165+
runTests(t, dsn, func(dbt *DBTest) {
2166+
conn, err := dbt.db.Conn(ctx)
2167+
if err != nil {
2168+
dbt.fail("db", "Conn", err)
2169+
}
2170+
2171+
// Check that affectedRows and insertIds are cleared after each call.
2172+
conn.Raw(func(conn interface{}) error {
2173+
c := conn.(*mysqlConn)
2174+
2175+
// Issue a query that sets affectedRows and insertIds.
2176+
q, err := c.Query(`SELECT 1`, nil)
2177+
if err != nil {
2178+
dbt.fail("Conn", "Query", err)
2179+
}
2180+
if got, want := c.result.affectedRows, []int64{0}; !reflect.DeepEqual(got, want) {
2181+
dbt.Fatalf("bad affectedRows: got %v, want=%v", got, want)
2182+
}
2183+
if got, want := c.result.insertIds, []int64{0}; !reflect.DeepEqual(got, want) {
2184+
dbt.Fatalf("bad insertIds: got %v, want=%v", got, want)
2185+
}
2186+
q.Close()
2187+
2188+
// Verify that Ping() clears both fields.
2189+
for i := 0; i < 2; i++ {
2190+
if err := c.Ping(ctx); err != nil {
2191+
dbt.fail("Pinger", "Ping", err)
2192+
}
2193+
if got, want := c.result.affectedRows, []int64(nil); !reflect.DeepEqual(got, want) {
2194+
t.Errorf("bad affectedRows: got %v, want=%v", got, want)
2195+
}
2196+
if got, want := c.result.insertIds, []int64(nil); !reflect.DeepEqual(got, want) {
2197+
t.Errorf("bad affectedRows: got %v, want=%v", got, want)
2198+
}
2199+
}
2200+
return nil
2201+
})
2202+
})
21632203
}
21642204

21652205
// See Issue #799
@@ -2379,6 +2419,42 @@ func TestMultiResultSetNoSelect(t *testing.T) {
23792419
})
23802420
}
23812421

2422+
func TestExecMultipleResults(t *testing.T) {
2423+
ctx := context.Background()
2424+
runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) {
2425+
dbt.mustExec(`
2426+
CREATE TABLE test (
2427+
id INT NOT NULL AUTO_INCREMENT,
2428+
value VARCHAR(255),
2429+
PRIMARY KEY (id)
2430+
)`)
2431+
conn, err := dbt.db.Conn(ctx)
2432+
if err != nil {
2433+
t.Fatalf("failed to connect: %v", err)
2434+
}
2435+
conn.Raw(func(conn interface{}) error {
2436+
ex := conn.(driver.Execer)
2437+
res, err := ex.Exec(`
2438+
INSERT INTO test (value) VALUES ('a'), ('b');
2439+
INSERT INTO test (value) VALUES ('c'), ('d'), ('e');
2440+
`, nil)
2441+
if err != nil {
2442+
t.Fatalf("insert statements failed: %v", err)
2443+
}
2444+
mres := res.(Result)
2445+
if got, want := mres.AllRowsAffected(), []int64{2, 3}; !reflect.DeepEqual(got, want) {
2446+
t.Errorf("bad AllRowsAffected: got %v, want=%v", got, want)
2447+
}
2448+
// For INSERTs containing multiple rows, LAST_INSERT_ID() returns the
2449+
// first inserted ID, not the last.
2450+
if got, want := mres.AllLastInsertIds(), []int64{1, 3}; !reflect.DeepEqual(got, want) {
2451+
t.Errorf("bad AllLastInsertIds: got %v, want %v", got, want)
2452+
}
2453+
return nil
2454+
})
2455+
})
2456+
}
2457+
23822458
// tests if rows are set in a proper state if some results were ignored before
23832459
// calling rows.NextResultSet.
23842460
func TestSkipResults(t *testing.T) {
@@ -2400,6 +2476,42 @@ func TestSkipResults(t *testing.T) {
24002476
})
24012477
}
24022478

2479+
func TestQueryMultipleResults(t *testing.T) {
2480+
ctx := context.Background()
2481+
runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) {
2482+
dbt.mustExec(`
2483+
CREATE TABLE test (
2484+
id INT NOT NULL AUTO_INCREMENT,
2485+
value VARCHAR(255),
2486+
PRIMARY KEY (id)
2487+
)`)
2488+
conn, err := dbt.db.Conn(ctx)
2489+
if err != nil {
2490+
t.Fatalf("failed to connect: %v", err)
2491+
}
2492+
conn.Raw(func(conn interface{}) error {
2493+
qr := conn.(driver.Queryer)
2494+
2495+
c := conn.(*mysqlConn)
2496+
2497+
// Demonstrate that repeated queries reset the affectedRows
2498+
for i := 0; i < 2; i++ {
2499+
_, err := qr.Query(`
2500+
INSERT INTO test (value) VALUES ('a'), ('b');
2501+
INSERT INTO test (value) VALUES ('c'), ('d'), ('e');
2502+
`, nil)
2503+
if err != nil {
2504+
t.Fatalf("insert statements failed: %v", err)
2505+
}
2506+
if got, want := c.result.affectedRows, []int64{2, 3}; !reflect.DeepEqual(got, want) {
2507+
t.Errorf("bad affectedRows: got %v, want=%v", got, want)
2508+
}
2509+
}
2510+
return nil
2511+
})
2512+
})
2513+
}
2514+
24032515
func TestPingContext(t *testing.T) {
24042516
runTests(t, dsn, func(dbt *DBTest) {
24052517
ctx, cancel := context.WithCancel(context.Background())

infile.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ func deferredClose(err *error, closer io.Closer) {
9393
}
9494
}
9595

96-
func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
96+
func (mc *okHandler) handleInFileRequest(name string) (err error) {
9797
var rdr io.Reader
9898
var data []byte
9999
packetSize := 16 * 1024 // 16KB is small enough for disk readahead and large enough for TCP
@@ -154,7 +154,7 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
154154
for err == nil {
155155
n, err = rdr.Read(data[4:])
156156
if n > 0 {
157-
if ioErr := mc.writePacket(data[:4+n]); ioErr != nil {
157+
if ioErr := mc.conn().writePacket(data[:4+n]); ioErr != nil {
158158
return ioErr
159159
}
160160
}
@@ -168,7 +168,7 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
168168
if data == nil {
169169
data = make([]byte, 4)
170170
}
171-
if ioErr := mc.writePacket(data[:4]); ioErr != nil {
171+
if ioErr := mc.conn().writePacket(data[:4]); ioErr != nil {
172172
return ioErr
173173
}
174174

@@ -177,6 +177,6 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
177177
return mc.readResultOK()
178178
}
179179

180-
mc.readPacket()
180+
mc.conn().readPacket()
181181
return err
182182
}

0 commit comments

Comments
 (0)