diff --git a/connection.go b/connection.go index e57061412..4e779089a 100644 --- a/connection.go +++ b/connection.go @@ -48,6 +48,27 @@ type mysqlConn struct { finished chan<- struct{} canceled atomicError // set non-nil if conn is canceled closed atomicBool // set when conn is closed, before closech is closed + + // for killing query after timeout + id int + d MySQLDriver +} + +func (mc *mysqlConn) kill() error { + t := 50 * time.Millisecond + killCfg := *mc.cfg + killCfg.Timeout = t + killCfg.ReadTimeout = t + killCfg.WriteTimeout = t + + conn, err := mc.d.Open(killCfg.FormatDSN()) + if err != nil { + return err + } + defer conn.Close() + query := "KILL QUERY " + strconv.Itoa(mc.id) + _, err = conn.(*mysqlConn).Exec(query, []driver.Value{}) + return err } // Handles parameters set in DSN after the connection is established @@ -445,6 +466,10 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) { // finish is called when the query has canceled. func (mc *mysqlConn) cancel(err error) { mc.canceled.Set(err) + if mc.cfg.KillQueryOnTimeout { + // do not put kill to cleanup to prevent cyclic kills + mc.kill() + } mc.cleanup() } diff --git a/driver.go b/driver.go index 27cf5ad4e..f2035edb6 100644 --- a/driver.go +++ b/driver.go @@ -150,6 +150,7 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { return nil, err } + mc.d = d return mc, nil } diff --git a/driver_go18_test.go b/driver_go18_test.go index afd5694ec..df29ab5fe 100644 --- a/driver_go18_test.go +++ b/driver_go18_test.go @@ -17,10 +17,18 @@ import ( "fmt" "math" "reflect" + "strings" "testing" "time" ) +// timeouts for cancelling queries +var ( + killTimeout = 5 * time.Second + getIdTimeout = 5 * time.Second + pollTimeout = 100 * time.Millisecond +) + // static interface implementation checks of mysqlConn var ( _ driver.ConnBeginTx = &mysqlConn{} @@ -52,6 +60,199 @@ var ( _ driver.RowsNextResultSet = &textRows{} ) +// dbProcess represents information about process from 'show processlist' command +type dbProcess struct { + ID int + User string + Host string + DB string + Info string +} + +func getProcesslist(dbName string, db *sql.DB) ([]dbProcess, error) { + rows, err := db.Query(`SELECT id, user, host, db, IFNULL(info, '') FROM information_schema.processlist WHERE db = ?`, dbName) + if err != nil { + return nil, err + } + defer rows.Close() + + var result []dbProcess + for rows.Next() { + p := dbProcess{} + if err := rows.Scan(&p.ID, &p.User, &p.Host, &p.DB, &p.Info); err != nil { + return nil, err + } + result = append(result, p) + } + if err := rows.Err(); err != nil { + return nil, err + } + return result, nil +} + +func checkProcessExists(dbName string, id int, db *sql.DB) bool { + exists := false + db.QueryRow("SELECT count(*) > 0 FROM information_schema.processlist WHERE db = ? AND id = ?", dbName, id).Scan(&exists) + return exists +} + +func getQueryProcess(db *sql.DB, dbName, query string) (*dbProcess, error) { + end := time.Now().Add(getIdTimeout) + var ( + err error + processList []dbProcess + longProcess *dbProcess + ) + for time.Now().Before(end) { + processList, err = getProcesslist(dbName, db) + + longProcess = nil + for _, p := range processList { + if len(p.Info) > 0 && strings.HasPrefix(query, p.Info) { + longProcess = &p + break + } + } + + if err != nil || longProcess == nil { + time.Sleep(pollTimeout) + } else { + break + } + } + if longProcess == nil { + return nil, fmt.Errorf("process for query \"%s\" not found", query) + } + return longProcess, err +} + +var expectedKilledErr = fmt.Errorf("process expected to be killed") + +func killQuery(db *sql.DB, dbName, query string, timeout time.Duration, cancel context.CancelFunc) error { + process, err := getQueryProcess(db, dbName, query) + if err != nil { + return fmt.Errorf("failed to get mysql process: %v", err) + } + cancel() + + end := time.Now().Add(timeout) + for time.Now().Before(end) { + if checkProcessExists(dbName, process.ID, db) { + err = expectedKilledErr + time.Sleep(pollTimeout) + } else { + err = nil + break + } + } + return err +} + +func testCancel(dbt *DBTest, ctx context.Context, cancel context.CancelFunc, query string, queryFunc func() error) { + tx, err := dbt.db.BeginTx(context.Background(), nil) + if err != nil { + dbt.Fatal(err) + return + } + + _, err = tx.Exec("LOCK TABLES test WRITE") + if err != nil { + tx.Rollback() + dbt.Fatal(err) + } + + errChan := make(chan error) + go func() { + // This query will be canceled. + err = queryFunc() + if err != nil && err != context.Canceled { + errLog.Print(err) + } + if err != context.Canceled && ctx.Err() != context.Canceled { + errChan <- fmt.Errorf("expected context.Canceled, got %v", err) + } + errChan <- nil + }() + + // it is safe to not use timeouts here since they are inside the killQuery function + err = killQuery(dbt.db, dbname, query, killTimeout, cancel) + if err != nil { + dbt.Error(err) + return + } + + // it is safe to block here since if reached this line then + // query has already been killed + err = <-errChan + if err != nil { + dbt.Error(err) + return + } + + _, err = tx.Exec("UNLOCK TABLES") + if err != nil { + tx.Rollback() + dbt.Fatal(err) + } + tx.Commit() +} + +func testCancelNoKill(dbt *DBTest, ctx context.Context, cancel context.CancelFunc, query string, queryFunc func() error) { + tx, err := dbt.db.BeginTx(context.Background(), nil) + if err != nil { + dbt.Fatal(err) + return + } + + _, err = tx.Exec("LOCK TABLES test WRITE") + if err != nil { + tx.Rollback() + dbt.Fatal(err) + } + + errChan := make(chan error) + go func() { + // This query will be canceled. + err = queryFunc() + if err != nil && err != context.Canceled { + errLog.Print(err) + } + if err != context.Canceled && ctx.Err() != context.Canceled { + errChan <- fmt.Errorf("expected context.Canceled, got %v", err) + return + } + errChan <- nil + }() + + // it is safe to not use timeouts here since they are inside the killQuery function + err = killQuery(dbt.db, dbname, query, 500*time.Millisecond, cancel) + if err != expectedKilledErr { + if err == nil { + dbt.Errorf("query kill expected to fail") + } else { + dbt.Errorf(fmt.Sprintf("unexpected error %s", err)) + } + } + + _, err = tx.Exec("UNLOCK TABLES") + if err != nil { + tx.Rollback() + dbt.Fatal(err) + } + tx.Commit() + + <-errChan +} + +func getKillDSN() string { + cfg, err := ParseDSN(dsn) + if err != nil { + panic(err) + } + cfg.KillQueryOnTimeout = true + return cfg.FormatDSN() +} + func TestMultiResultSet(t *testing.T) { type result struct { values [][]int @@ -242,33 +443,74 @@ func TestPingContext(t *testing.T) { }) } -func TestContextCancelExec(t *testing.T) { +func TestContextCancelNoKill(t *testing.T) { runTests(t, dsn, func(dbt *DBTest) { dbt.mustExec("CREATE TABLE test (v INTEGER)") ctx, cancel := context.WithCancel(context.Background()) + exec := "INSERT INTO test VALUES(1)" + + testCancelNoKill(dbt, ctx, cancel, exec, func() error { + _, err := dbt.db.ExecContext(ctx, exec) + return err + }) - // Delay execution for just a bit until db.ExecContext has begun. - defer time.AfterFunc(100*time.Millisecond, cancel).Stop() + // Check how many times the query is executed. + var v int + var err error + for i := 0; i != 3; i++ { + err = nil + if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil { + dbt.Fatalf("%s", err.Error()) + return + } + if v != 1 { + err = fmt.Errorf("expected val to be 1, got %d", v) + } - // This query will be canceled. - startTime := time.Now() - if _, err := dbt.db.ExecContext(ctx, "INSERT INTO test VALUES (SLEEP(1))"); err != context.Canceled { - dbt.Errorf("expected context.Canceled, got %v", err) + if err != nil { + time.Sleep(100 * time.Millisecond) // wait while insert is executed after table lock released + } } - if d := time.Since(startTime); d > 500*time.Millisecond { - dbt.Errorf("too long execution time: %s", d) + if err != nil { + dbt.Error(err) + return } - // Wait for the INSERT query has done. - time.Sleep(time.Second) + // Context is already canceled, so error should come before execution. + if _, err := dbt.db.ExecContext(ctx, "INSERT INTO test VALUES (1)"); err == nil { + dbt.Error("expected error") + } else if err.Error() != "context canceled" { + dbt.Fatalf("unexpected error: %s", err) + } + + // The second insert query will fail, so the table has no changes. + if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil { + dbt.Fatalf("%s", err.Error()) + } + if v != 1 { + dbt.Errorf("expected val to be 1, got %d", v) + } + }) +} + +func TestContextCancelExec(t *testing.T) { + runTests(t, getKillDSN(), func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (v INTEGER)") + ctx, cancel := context.WithCancel(context.Background()) + exec := "INSERT INTO test VALUES(1)" + + testCancel(dbt, ctx, cancel, exec, func() error { + _, err := dbt.db.ExecContext(ctx, exec) + return err + }) // Check how many times the query is executed. var v int if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil { dbt.Fatalf("%s", err.Error()) } - if v != 1 { // TODO: need to kill the query, and v should be 0. - dbt.Errorf("expected val to be 1, got %d", v) + if v != 0 { + dbt.Errorf("expected val to be 0, got %d", v) } // Context is already canceled, so error should come before execution. @@ -282,39 +524,30 @@ func TestContextCancelExec(t *testing.T) { if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil { dbt.Fatalf("%s", err.Error()) } - if v != 1 { - dbt.Errorf("expected val to be 1, got %d", v) + if v != 0 { + dbt.Errorf("expected val to be 0, got %d", v) } }) } func TestContextCancelQuery(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTests(t, getKillDSN(), func(dbt *DBTest) { dbt.mustExec("CREATE TABLE test (v INTEGER)") ctx, cancel := context.WithCancel(context.Background()) + query := "SELECT 1 FROM test" - // Delay execution for just a bit until db.ExecContext has begun. - defer time.AfterFunc(100*time.Millisecond, cancel).Stop() - - // This query will be canceled. - startTime := time.Now() - if _, err := dbt.db.QueryContext(ctx, "INSERT INTO test VALUES (SLEEP(1))"); err != context.Canceled { - dbt.Errorf("expected context.Canceled, got %v", err) - } - if d := time.Since(startTime); d > 500*time.Millisecond { - dbt.Errorf("too long execution time: %s", d) - } - - // Wait for the INSERT query has done. - time.Sleep(time.Second) + testCancel(dbt, ctx, cancel, query, func() error { + _, err := dbt.db.QueryContext(ctx, query) + return err + }) // Check how many times the query is executed. var v int if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil { dbt.Fatalf("%s", err.Error()) } - if v != 1 { // TODO: need to kill the query, and v should be 0. - dbt.Errorf("expected val to be 1, got %d", v) + if v != 0 { + dbt.Errorf("expected val to be 0, got %d", v) } // Context is already canceled, so error should come before execution. @@ -326,8 +559,8 @@ func TestContextCancelQuery(t *testing.T) { if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil { dbt.Fatalf("%s", err.Error()) } - if v != 1 { - dbt.Errorf("expected val to be 1, got %d", v) + if v != 0 { + dbt.Errorf("expected val to be 0, got %d", v) } }) } @@ -376,95 +609,74 @@ func TestContextCancelPrepare(t *testing.T) { } func TestContextCancelStmtExec(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTests(t, getKillDSN(), func(dbt *DBTest) { dbt.mustExec("CREATE TABLE test (v INTEGER)") ctx, cancel := context.WithCancel(context.Background()) - stmt, err := dbt.db.PrepareContext(ctx, "INSERT INTO test VALUES (SLEEP(1))") + exec := "INSERT INTO test VALUES(1)" + + stmt, err := dbt.db.PrepareContext(ctx, exec) if err != nil { dbt.Fatalf("unexpected error: %v", err) } - // Delay execution for just a bit until db.ExecContext has begun. - defer time.AfterFunc(100*time.Millisecond, cancel).Stop() - - // This query will be canceled. - startTime := time.Now() - if _, err := stmt.ExecContext(ctx); err != context.Canceled { - dbt.Errorf("expected context.Canceled, got %v", err) - } - if d := time.Since(startTime); d > 500*time.Millisecond { - dbt.Errorf("too long execution time: %s", d) - } - - // Wait for the INSERT query has done. - time.Sleep(time.Second) + testCancel(dbt, ctx, cancel, exec, func() error { + _, err := stmt.ExecContext(ctx) + return err + }) // Check how many times the query is executed. var v int if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil { dbt.Fatalf("%s", err.Error()) } - if v != 1 { // TODO: need to kill the query, and v should be 0. + if v != 0 { dbt.Errorf("expected val to be 1, got %d", v) } }) } func TestContextCancelStmtQuery(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTests(t, getKillDSN(), func(dbt *DBTest) { dbt.mustExec("CREATE TABLE test (v INTEGER)") ctx, cancel := context.WithCancel(context.Background()) - stmt, err := dbt.db.PrepareContext(ctx, "INSERT INTO test VALUES (SLEEP(1))") + query := "SELECT 1 FROM test" + + stmt, err := dbt.db.PrepareContext(ctx, query) if err != nil { dbt.Fatalf("unexpected error: %v", err) } - // Delay execution for just a bit until db.ExecContext has begun. - defer time.AfterFunc(100*time.Millisecond, cancel).Stop() - - // This query will be canceled. - startTime := time.Now() - if _, err := stmt.QueryContext(ctx); err != context.Canceled { - dbt.Errorf("expected context.Canceled, got %v", err) - } - if d := time.Since(startTime); d > 500*time.Millisecond { - dbt.Errorf("too long execution time: %s", d) - } - - // Wait for the INSERT query has done. - time.Sleep(time.Second) + testCancel(dbt, ctx, cancel, query, func() error { + _, err := stmt.QueryContext(ctx) + return err + }) // Check how many times the query is executed. var v int if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil { dbt.Fatalf("%s", err.Error()) } - if v != 1 { // TODO: need to kill the query, and v should be 0. + if v != 0 { dbt.Errorf("expected val to be 1, got %d", v) } }) } func TestContextCancelBegin(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTests(t, getKillDSN(), func(dbt *DBTest) { dbt.mustExec("CREATE TABLE test (v INTEGER)") ctx, cancel := context.WithCancel(context.Background()) + query := "SELECT 1 FROM test" + tx, err := dbt.db.BeginTx(ctx, nil) if err != nil { dbt.Fatal(err) } - // Delay execution for just a bit until db.ExecContext has begun. - defer time.AfterFunc(100*time.Millisecond, cancel).Stop() - - // This query will be canceled. - startTime := time.Now() - if _, err := tx.ExecContext(ctx, "INSERT INTO test VALUES (SLEEP(1))"); err != context.Canceled { - dbt.Errorf("expected context.Canceled, got %v", err) - } - if d := time.Since(startTime); d > 500*time.Millisecond { - dbt.Errorf("too long execution time: %s", d) - } + testCancel(dbt, ctx, cancel, query, func() error { + _, err := tx.ExecContext(ctx, query) + return err + }) // Transaction is canceled, so expect an error. switch err := tx.Commit(); err { diff --git a/dsn.go b/dsn.go index 47eab6945..ad512b87e 100644 --- a/dsn.go +++ b/dsn.go @@ -57,6 +57,7 @@ type Config struct { MultiStatements bool // Allow multiple statements in one query ParseTime bool // Parse time values to time.Time RejectReadOnly bool // Reject read-only connections + KillQueryOnTimeout bool // kill query on the server side if context timed out } // NewConfig creates a new Config and sets default values. @@ -254,6 +255,15 @@ func (cfg *Config) FormatDSN() string { } } + if cfg.KillQueryOnTimeout { + if hasParam { + buf.WriteString("&killQueryOnTimeout=true") + } else { + hasParam = true + buf.WriteString("?killQueryOnTimeout=true") + } + } + if cfg.Timeout > 0 { if hasParam { buf.WriteString("&timeout=") @@ -512,6 +522,14 @@ func parseDSNParams(cfg *Config, params string) (err error) { return errors.New("invalid bool value: " + value) } + // Kill queries on context timeout + case "killQueryOnTimeout": + var isBool bool + cfg.KillQueryOnTimeout, isBool = readBool(value) + if !isBool { + return errors.New("invalid bool value: " + value) + } + // Strict mode case "strict": panic("strict mode has been removed. See https://github.com/go-sql-driver/mysql/wiki/strict-mode") diff --git a/packets.go b/packets.go index afc3fcc46..01173b202 100644 --- a/packets.go +++ b/packets.go @@ -180,7 +180,9 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) { // server version [null terminated string] // connection id [4 bytes] - pos := 1 + bytes.IndexByte(data[1:], 0x00) + 1 + 4 + idPos := 1 + bytes.IndexByte(data[1:], 0x00) + 1 + mc.id = int(binary.LittleEndian.Uint32(data[idPos : idPos+4])) + pos := idPos + 4 // first part of the password cipher [8 bytes] cipher := data[pos : pos+8]