From b8161abd2540a6dfc56fa1787bcc9dec3c3ce7b4 Mon Sep 17 00:00:00 2001 From: Julien Schmidt Date: Wed, 15 May 2013 19:17:31 +0200 Subject: [PATCH 1/8] remove .gitattributes file is obsolete --- .gitattributes | 1 - 1 file changed, 1 deletion(-) delete mode 100644 .gitattributes diff --git a/.gitattributes b/.gitattributes deleted file mode 100644 index b0ff81123..000000000 --- a/.gitattributes +++ /dev/null @@ -1 +0,0 @@ -README.md merge=ours From c4f805ae6fb28e79a7efe63f80421fb02c2d249e Mon Sep 17 00:00:00 2001 From: Julien Schmidt Date: Fri, 31 May 2013 18:47:33 +0200 Subject: [PATCH 2/8] fmt fix --- driver_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/driver_test.go b/driver_test.go index b75e50d21..ec0b45d00 100644 --- a/driver_test.go +++ b/driver_test.go @@ -797,7 +797,7 @@ func TestStrict(t *testing.T) { for i := range queries { stmt, err = dbt.db.Prepare(queries[i].in) if err != nil { - dbt.Error("Error on preparing query %: ", queries[i].in, err.Error()) + dbt.Errorf("Error on preparing query %s: %s", queries[i].in, err.Error()) } _, err = stmt.Exec() @@ -805,7 +805,7 @@ func TestStrict(t *testing.T) { err = stmt.Close() if err != nil { - dbt.Error("Error on closing stmt for query %: ", queries[i].in, err.Error()) + dbt.Errorf("Error on closing stmt for query %s: %s", queries[i].in, err.Error()) } } }) From 261f309ac246741c69b7815e1882b9a0ded31cf3 Mon Sep 17 00:00:00 2001 From: Julien Schmidt Date: Fri, 31 May 2013 18:47:44 +0200 Subject: [PATCH 3/8] Go 1.1 spec changes --- errors.go | 1 - packets.go | 3 --- 2 files changed, 4 deletions(-) diff --git a/errors.go b/errors.go index 1769d40f6..20003e086 100644 --- a/errors.go +++ b/errors.go @@ -102,5 +102,4 @@ func (mc *mysqlConn) getWarnings() (err error) { return } } - return } diff --git a/packets.go b/packets.go index 64482bd6b..cefad97b8 100644 --- a/packets.go +++ b/packets.go @@ -569,8 +569,6 @@ func (mc *mysqlConn) readColumns(count int) (columns []mysqlField, err error) { i++ } - - return } // Read Packets as Field Packets until EOF-Packet or an Error appears @@ -636,7 +634,6 @@ func (mc *mysqlConn) readUntilEOF() (err error) { } return // Err or EOF } - return } /****************************************************************************** From b0d08caea20185d0410bd9c24f4c8c121e68987b Mon Sep 17 00:00:00 2001 From: Julien Schmidt Date: Sat, 1 Jun 2013 17:15:02 +0200 Subject: [PATCH 4/8] cleanup config param handling --- connection.go | 24 +++++++++++------------- infile.go | 2 +- packets.go | 2 +- utils.go | 8 ++++++++ utils_test.go | 18 +++++++++--------- 5 files changed, 30 insertions(+), 24 deletions(-) diff --git a/connection.go b/connection.go index 8436f0383..a62db429a 100644 --- a/connection.go +++ b/connection.go @@ -35,15 +35,17 @@ type mysqlConn struct { } type config struct { - user string - passwd string - net string - addr string - dbname string - params map[string]string - loc *time.Location - timeout time.Duration - tls *tls.Config + user string + passwd string + net string + addr string + dbname string + params map[string]string + loc *time.Location + timeout time.Duration + tls *tls.Config + allowAllFiles bool + clientFoundRows bool } // Handles parameters set in DSN @@ -64,10 +66,6 @@ func (mc *mysqlConn) handleParams() (err error) { return } - // handled elsewhere - case "allowAllFiles", "clientFoundRows": - continue - // time.Time parsing case "parseTime": mc.parseTime = readBool(val) diff --git a/infile.go b/infile.go index 635dbd572..1f2faf461 100644 --- a/infile.go +++ b/infile.go @@ -74,7 +74,7 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) { } } else { // File name = strings.Trim(name, `"`) - if fileRegister[name] || mc.cfg.params[`allowAllFiles`] == `true` { + if mc.cfg.allowAllFiles || fileRegister[name] { rdr, err = os.Open(name) } else { err = fmt.Errorf("Local File '%s' is not registered. Use the DSN parameter 'allowAllFiles=true' to allow all files", name) diff --git a/packets.go b/packets.go index cefad97b8..34195a083 100644 --- a/packets.go +++ b/packets.go @@ -215,7 +215,7 @@ func (mc *mysqlConn) writeAuthPacket() error { clientLocalFiles | mc.flags&clientLongFlag - if _, ok := mc.cfg.params["clientFoundRows"]; ok { + if mc.cfg.clientFoundRows { clientFlags |= clientFoundRows } diff --git a/utils.go b/utils.go index d566705ab..097ecd0aa 100644 --- a/utils.go +++ b/utils.go @@ -124,6 +124,14 @@ func parseDSN(dsn string) (cfg *config, err error) { // cfg params switch value := param[1]; param[0] { + // Disable INFILE whitelist / enable all files + case "allowAllFiles": + cfg.allowAllFiles = readBool(value) + + // Switch "rowsAffected" mode + case "clientFoundRows": + cfg.clientFoundRows = readBool(value) + // Time Location case "loc": cfg.loc, err = time.LoadLocation(value) diff --git a/utils_test.go b/utils_test.go index 8068977ac..836790061 100644 --- a/utils_test.go +++ b/utils_test.go @@ -21,15 +21,15 @@ func TestDSNParser(t *testing.T) { out string loc *time.Location }{ - {"username:password@protocol(address)/dbname?param=value", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p timeout:0 tls:}", time.UTC}, - {"user@unix(/path/to/socket)/dbname?charset=utf8", "&{user:user passwd: net:unix addr:/path/to/socket dbname:dbname params:map[charset:utf8] loc:%p timeout:0 tls:}", time.UTC}, - {"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8] loc:%p timeout:0 tls:}", time.UTC}, - {"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8mb4,utf8] loc:%p timeout:0 tls:}", time.UTC}, - {"user:password@/dbname?loc=UTC&timeout=30s", "&{user:user passwd:password net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p timeout:30000000000 tls:}", time.UTC}, - {"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", "&{user:user passwd:p@ss(word) net:tcp addr:[de:ad:be:ef::ca:fe]:80 dbname:dbname params:map[] loc:%p timeout:0 tls:}", time.Local}, - {"/dbname", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p timeout:0 tls:}", time.UTC}, - {"/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls:}", time.UTC}, - {"user:p@/ssword@/", "&{user:user passwd:p@/ssword net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls:}", time.UTC}, + {"username:password@protocol(address)/dbname?param=value", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p timeout:0 tls: allowAllFiles:false clientFoundRows:false}", time.UTC}, + {"user@unix(/path/to/socket)/dbname?charset=utf8", "&{user:user passwd: net:unix addr:/path/to/socket dbname:dbname params:map[charset:utf8] loc:%p timeout:0 tls: allowAllFiles:false clientFoundRows:false}", time.UTC}, + {"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8] loc:%p timeout:0 tls: allowAllFiles:false clientFoundRows:false}", time.UTC}, + {"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8mb4,utf8] loc:%p timeout:0 tls: allowAllFiles:false clientFoundRows:false}", time.UTC}, + {"user:password@/dbname?loc=UTC&timeout=30s&allowAllFiles=1&clientFoundRows=true", "&{user:user passwd:password net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p timeout:30000000000 tls: allowAllFiles:true clientFoundRows:true}", time.UTC}, + {"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", "&{user:user passwd:p@ss(word) net:tcp addr:[de:ad:be:ef::ca:fe]:80 dbname:dbname params:map[] loc:%p timeout:0 tls: allowAllFiles:false clientFoundRows:false}", time.Local}, + {"/dbname", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p timeout:0 tls: allowAllFiles:false clientFoundRows:false}", time.UTC}, + {"/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls: allowAllFiles:false clientFoundRows:false}", time.UTC}, + {"user:p@/ssword@/", "&{user:user passwd:p@/ssword net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls: allowAllFiles:false clientFoundRows:false}", time.UTC}, } var cfg *config From e44f1b6291f8c6c80a425c3129d92656c6e38800 Mon Sep 17 00:00:00 2001 From: Julien Schmidt Date: Sat, 1 Jun 2013 17:43:50 +0200 Subject: [PATCH 5/8] refactor driver tests - Go 1.1 API - moved TestFoundRows - Refactored charset tests - err formatting - ... --- driver_test.go | 244 ++++++++++++++++++++++--------------------------- 1 file changed, 108 insertions(+), 136 deletions(-) diff --git a/driver_test.go b/driver_test.go index ec0b45d00..7be65a476 100644 --- a/driver_test.go +++ b/driver_test.go @@ -13,7 +13,6 @@ import ( ) var ( - charset string dsn string netAddr string available bool @@ -42,7 +41,6 @@ func init() { prot := env("MYSQL_TEST_PROT", "tcp") addr := env("MYSQL_TEST_ADDR", "localhost:3306") dbname := env("MYSQL_TEST_DBNAME", "gotest") - charset = "charset=utf8" netAddr = fmt.Sprintf("%s(%s)", prot, addr) dsn = fmt.Sprintf("%s:%s@%s/%s?timeout=30s&strict=true", user, pass, netAddr, dbname) c, err := net.Dial(prot, addr) @@ -57,15 +55,14 @@ type DBTest struct { db *sql.DB } -func runTests(t *testing.T, name, dsn string, tests ...func(dbt *DBTest)) { +func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) { if !available { - t.Logf("MySQL-Server not running on %s. Skipping %s", netAddr, name) - return + t.Skipf("MySQL-Server not running on %s", netAddr) } db, err := sql.Open("mysql", dsn) if err != nil { - t.Fatalf("Error connecting: %v", err) + t.Fatalf("Error connecting: %s", err.Error()) } defer db.Close() @@ -82,7 +79,7 @@ func (dbt *DBTest) fail(method, query string, err error) { if len(query) > 300 { query = "[query too large to print]" } - dbt.Fatalf("Error on %s %s: %v", method, query, err) + dbt.Fatalf("Error on %s %s: %s", method, query, err.Error()) } func (dbt *DBTest) mustExec(query string, args ...interface{}) (res sql.Result) { @@ -102,32 +99,26 @@ func (dbt *DBTest) mustQuery(query string, args ...interface{}) (rows *sql.Rows) } func TestCharset(t *testing.T) { - mustSetCharset := func(charsetParam, expected string) { - db, err := sql.Open("mysql", dsn+"&"+charsetParam) - if err != nil { - t.Fatalf("Error on Open: %v", err) - } - defer db.Close() - - dbt := &DBTest{t, db} - rows := dbt.mustQuery("SELECT @@character_set_connection") - defer rows.Close() + if !available { + t.Skipf("MySQL-Server not running on %s", netAddr) + } - if !rows.Next() { - dbt.Fatalf("Error getting connection charset: %v", err) - } + mustSetCharset := func(charsetParam, expected string) { + runTests(t, dsn+"&"+charsetParam, func(dbt *DBTest) { + rows := dbt.mustQuery("SELECT @@character_set_connection") + defer rows.Close() - var got string - rows.Scan(&got) + if !rows.Next() { + dbt.Fatalf("Error getting connection charset: %s", rows.Err()) + } - if got != expected { - dbt.Fatalf("Expected connection charset %s but got %s", expected, got) - } - } + var got string + rows.Scan(&got) - if !available { - t.Logf("MySQL-Server not running on %s. Skipping TestCharset", netAddr) - return + if got != expected { + dbt.Fatalf("Expected connection charset %s but got %s", expected, got) + } + }) } // non utf8 test @@ -142,26 +133,18 @@ func TestCharset(t *testing.T) { } func TestFailingCharset(t *testing.T) { - if !available { - t.Logf("MySQL-Server not running on %s. Skipping TestFailingCharset", netAddr) - return - } - db, err := sql.Open("mysql", dsn+"&charset=none") - if err != nil { - t.Fatalf("Error on Open: %v", err) - } - defer db.Close() - - // run query to really establish connection... - _, err = db.Exec("SELECT 1") - if err == nil { - db.Close() - t.Fatalf("Connection must not succeed without a valid charset") - } + runTests(t, dsn+"&charset=none", func(dbt *DBTest) { + // run query to really establish connection... + _, err := dbt.db.Exec("SELECT 1") + if err == nil { + dbt.db.Close() + t.Fatalf("Connection must not succeed without a valid charset") + } + }) } func TestRawBytesResultExceedsBuffer(t *testing.T) { - runTests(t, "TestRawBytesResultExceedsBuffer", dsn, func(dbt *DBTest) { + runTests(t, dsn, func(dbt *DBTest) { // defaultBufSize from buffer.go expected := strings.Repeat("abc", defaultBufSize) rows := dbt.mustQuery("SELECT '" + expected + "'") @@ -178,7 +161,7 @@ func TestRawBytesResultExceedsBuffer(t *testing.T) { } func TestCRUD(t *testing.T) { - runTests(t, "TestCRUD", dsn, func(dbt *DBTest) { + runTests(t, dsn, func(dbt *DBTest) { // Create Table dbt.mustExec("CREATE TABLE test (value BOOL)") @@ -193,7 +176,7 @@ func TestCRUD(t *testing.T) { res := dbt.mustExec("INSERT INTO test VALUES (1)") count, err := res.RowsAffected() if err != nil { - dbt.Fatalf("res.RowsAffected() returned error: %v", err) + dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) } if count != 1 { dbt.Fatalf("Expected 1 affected row, got %d", count) @@ -201,7 +184,7 @@ func TestCRUD(t *testing.T) { id, err := res.LastInsertId() if err != nil { - dbt.Fatalf("res.LastInsertId() returned error: %v", err) + dbt.Fatalf("res.LastInsertId() returned error: %s", err.Error()) } if id != 0 { dbt.Fatalf("Expected InsertID 0, got %d", id) @@ -226,7 +209,7 @@ func TestCRUD(t *testing.T) { res = dbt.mustExec("UPDATE test SET value = ? WHERE value = ?", false, true) count, err = res.RowsAffected() if err != nil { - dbt.Fatalf("res.RowsAffected() returned error: %v", err) + dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) } if count != 1 { dbt.Fatalf("Expected 1 affected row, got %d", count) @@ -251,7 +234,7 @@ func TestCRUD(t *testing.T) { res = dbt.mustExec("DELETE FROM test WHERE value = ?", false) count, err = res.RowsAffected() if err != nil { - dbt.Fatalf("res.RowsAffected() returned error: %v", err) + dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) } if count != 1 { dbt.Fatalf("Expected 1 affected row, got %d", count) @@ -261,7 +244,7 @@ func TestCRUD(t *testing.T) { res = dbt.mustExec("DELETE FROM test") count, err = res.RowsAffected() if err != nil { - dbt.Fatalf("res.RowsAffected() returned error: %v", err) + dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) } if count != 0 { dbt.Fatalf("Expected 0 affected row, got %d", count) @@ -270,7 +253,7 @@ func TestCRUD(t *testing.T) { } func TestInt(t *testing.T) { - runTests(t, "TestInt", dsn, func(dbt *DBTest) { + runTests(t, dsn, func(dbt *DBTest) { types := [5]string{"TINYINT", "SMALLINT", "MEDIUMINT", "INT", "BIGINT"} in := int64(42) var out int64 @@ -317,7 +300,7 @@ func TestInt(t *testing.T) { } func TestFloat(t *testing.T) { - runTests(t, "TestFloat", dsn, func(dbt *DBTest) { + runTests(t, dsn, func(dbt *DBTest) { types := [2]string{"FLOAT", "DOUBLE"} in := float32(42.23) var out float32 @@ -340,7 +323,7 @@ func TestFloat(t *testing.T) { } func TestString(t *testing.T) { - runTests(t, "TestString", dsn, func(dbt *DBTest) { + runTests(t, dsn, func(dbt *DBTest) { types := [6]string{"CHAR(255)", "VARCHAR(255)", "TINYTEXT", "TEXT", "MEDIUMTEXT", "LONGTEXT"} in := "κόσμε üöäßñóùéàâÿœ'îë Árvíztűrő いろはにほへとちりぬるを イロハニホヘト דג סקרן чащах น่าฟังเอย" var out string @@ -380,7 +363,7 @@ func TestString(t *testing.T) { err := dbt.db.QueryRow("SELECT value FROM test WHERE id = ?", id).Scan(&out) if err != nil { - dbt.Fatalf("Error on BLOB-Query: %v", err) + dbt.Fatalf("Error on BLOB-Query: %s", err.Error()) } else if out != in { dbt.Errorf("BLOB: %s != %s", in, out) } @@ -429,7 +412,7 @@ func TestDateTime(t *testing.T) { dbt *DBTest, rows *sql.Rows, test *timetest, sqltype, resulttype, mode string) { var sOut string if err := rows.Scan(&sOut); err != nil { - dbt.Errorf("%s (%s %s): %v", sqltype, resulttype, mode, err) + dbt.Errorf("%s (%s %s): %s", sqltype, resulttype, mode, err.Error()) } else if test.sOut != sOut { dbt.Errorf("%s (%s %s): %s != %s", sqltype, resulttype, mode, test.sOut, sOut) } @@ -438,7 +421,7 @@ func TestDateTime(t *testing.T) { dbt *DBTest, rows *sql.Rows, test *timetest, sqltype, resulttype, mode string) { var tOut time.Time if err := rows.Scan(&tOut); err != nil { - dbt.Errorf("%s (%s %s): %v", sqltype, resulttype, mode, err) + dbt.Errorf("%s (%s %s): %s", sqltype, resulttype, mode, err.Error()) } else if test.tOut != tOut || test.tIsZero != tOut.IsZero() { dbt.Errorf("%s (%s %s): %s [%t] != %s [%t]", sqltype, resulttype, mode, test.tOut, test.tIsZero, tOut, tOut.IsZero()) } @@ -460,8 +443,8 @@ func TestDateTime(t *testing.T) { s.test(dbt, rows, test, sqltype, s.vartype, mode) } else { if err := rows.Err(); err != nil { - dbt.Errorf("%s (%s %s): %v", - sqltype, s.vartype, mode, err) + dbt.Errorf("%s (%s %s): %s", + sqltype, s.vartype, mode, err.Error()) } else { dbt.Errorf("%s (%s %s): no data", sqltype, s.vartype, mode) @@ -476,12 +459,12 @@ func TestDateTime(t *testing.T) { timeDsn := dsn + "&sql_mode=ALLOW_INVALID_DATES" for _, v := range setups { s = v - runTests(t, "TestDateTime", timeDsn+s.dsnSuffix, testTime) + runTests(t, timeDsn+s.dsnSuffix, testTime) } } func TestNULL(t *testing.T) { - runTests(t, "TestNULL", dsn, func(dbt *DBTest) { + runTests(t, dsn, func(dbt *DBTest) { nullStmt, err := dbt.db.Prepare("SELECT NULL") if err != nil { dbt.Fatal(err) @@ -597,7 +580,7 @@ func TestNULL(t *testing.T) { } func TestLongData(t *testing.T) { - runTests(t, "TestLongData", dsn, func(dbt *DBTest) { + runTests(t, dsn, func(dbt *DBTest) { var maxAllowedPacketSize int err := dbt.db.QueryRow("select @@max_allowed_packet").Scan(&maxAllowedPacketSize) if err != nil { @@ -658,7 +641,7 @@ func TestLongData(t *testing.T) { } func TestLoadData(t *testing.T) { - runTests(t, "TestLoadData", dsn, func(dbt *DBTest) { + runTests(t, dsn, func(dbt *DBTest) { verifyLoadDataResult := func() { rows, err := dbt.db.Query("SELECT * FROM test") if err != nil { @@ -744,10 +727,55 @@ func TestLoadData(t *testing.T) { }) } +func TestFoundRows(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (id INT NOT NULL ,data INT NOT NULL)") + dbt.mustExec("INSERT INTO test (id, data) VALUES (0, 0),(0, 0),(1, 0),(1, 0),(1, 1)") + + res := dbt.mustExec("UPDATE test SET data = 1 WHERE id = 0") + count, err := res.RowsAffected() + if err != nil { + dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) + } + if count != 2 { + dbt.Fatalf("Expected 2 affected rows, got %d", count) + } + res = dbt.mustExec("UPDATE test SET data = 1 WHERE id = 1") + count, err = res.RowsAffected() + if err != nil { + dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) + } + if count != 2 { + dbt.Fatalf("Expected 2 affected rows, got %d", count) + } + }) + runTests(t, dsn+"&clientFoundRows=true", func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (id INT NOT NULL ,data INT NOT NULL)") + dbt.mustExec("INSERT INTO test (id, data) VALUES (0, 0),(0, 0),(1, 0),(1, 0),(1, 1)") + + res := dbt.mustExec("UPDATE test SET data = 1 WHERE id = 0") + count, err := res.RowsAffected() + if err != nil { + dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) + } + if count != 2 { + dbt.Fatalf("Expected 2 matched rows, got %d", count) + } + res = dbt.mustExec("UPDATE test SET data = 1 WHERE id = 1") + count, err = res.RowsAffected() + if err != nil { + dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) + } + if count != 3 { + dbt.Fatalf("Expected 3 matched rows, got %d", count) + } + }) +} + func TestStrict(t *testing.T) { // ALLOW_INVALID_DATES to get rid of stricter modes - we want to test for warnings, not errors relaxedDsn := dsn + "&sql_mode=ALLOW_INVALID_DATES" - runTests(t, "TestStrict", relaxedDsn, func(dbt *DBTest) { + runTests(t, relaxedDsn, func(dbt *DBTest) { dbt.mustExec("CREATE TABLE test (a TINYINT NOT NULL, b CHAR(4))") var queries = [...]struct { @@ -812,21 +840,10 @@ func TestStrict(t *testing.T) { } func TestTLS(t *testing.T) { - runTests(t, "TestTLS", dsn+"&tls=skip-verify", func(dbt *DBTest) { - /* TODO: GO 1.1 API */ - /*if err := dbt.db.Ping(); err != nil { - if err == errNoTLS { - dbt.Skip("Server does not support TLS. Skipping TestTLS") - } else { - dbt.Fatalf("Error on Ping: %s", err.Error()) - } - }*/ - - /* GO 1.0 API */ - if _, err := dbt.db.Exec("DO 1"); err != nil { + runTests(t, dsn+"&tls=skip-verify", func(dbt *DBTest) { + if err := dbt.db.Ping(); err != nil { if err == errNoTLS { - dbt.Log("Server does not support TLS. Skipping TestTLS") - return + dbt.Skip("Server does not support TLS") } else { dbt.Fatalf("Error on Ping: %s", err.Error()) } @@ -850,7 +867,7 @@ func TestTLS(t *testing.T) { // Special cases func TestRowsClose(t *testing.T) { - runTests(t, "TestRowsClose", dsn, func(dbt *DBTest) { + runTests(t, dsn, func(dbt *DBTest) { rows, err := dbt.db.Query("SELECT 1") if err != nil { dbt.Fatal(err) @@ -875,7 +892,7 @@ func TestRowsClose(t *testing.T) { // dangling statements // http://code.google.com/p/go/issues/detail?id=3865 func TestCloseStmtBeforeRows(t *testing.T) { - runTests(t, "TestCloseStmtBeforeRows", dsn, func(dbt *DBTest) { + runTests(t, dsn, func(dbt *DBTest) { stmt, err := dbt.db.Prepare("SELECT 1") if err != nil { dbt.Fatal(err) @@ -904,7 +921,7 @@ func TestCloseStmtBeforeRows(t *testing.T) { var out bool err = rows.Scan(&out) if err != nil { - dbt.Fatalf("Error on rows.Scan(): %v", err) + dbt.Fatalf("Error on rows.Scan(): %s", err.Error()) } if out != true { dbt.Errorf("true != %t", out) @@ -916,7 +933,7 @@ func TestCloseStmtBeforeRows(t *testing.T) { // It is valid to have multiple Rows for the same Stmt // http://code.google.com/p/go/issues/detail?id=3734 func TestStmtMultiRows(t *testing.T) { - runTests(t, "TestStmtMultiRows", dsn, func(dbt *DBTest) { + runTests(t, dsn, func(dbt *DBTest) { stmt, err := dbt.db.Prepare("SELECT 1 UNION SELECT 0") if err != nil { dbt.Fatal(err) @@ -949,7 +966,7 @@ func TestStmtMultiRows(t *testing.T) { err = rows1.Scan(&out) if err != nil { - dbt.Fatalf("Error on rows.Scan(): %v", err) + dbt.Fatalf("Error on rows.Scan(): %s", err.Error()) } if out != true { dbt.Errorf("true != %t", out) @@ -966,7 +983,7 @@ func TestStmtMultiRows(t *testing.T) { err = rows2.Scan(&out) if err != nil { - dbt.Fatalf("Error on rows.Scan(): %v", err) + dbt.Fatalf("Error on rows.Scan(): %s", err.Error()) } if out != true { dbt.Errorf("true != %t", out) @@ -984,7 +1001,7 @@ func TestStmtMultiRows(t *testing.T) { err = rows1.Scan(&out) if err != nil { - dbt.Fatalf("Error on rows.Scan(): %v", err) + dbt.Fatalf("Error on rows.Scan(): %s", err.Error()) } if out != false { dbt.Errorf("false != %t", out) @@ -1009,7 +1026,7 @@ func TestStmtMultiRows(t *testing.T) { err = rows2.Scan(&out) if err != nil { - dbt.Fatalf("Error on rows.Scan(): %v", err) + dbt.Fatalf("Error on rows.Scan(): %s", err.Error()) } if out != false { dbt.Errorf("false != %t", out) @@ -1028,14 +1045,14 @@ func TestStmtMultiRows(t *testing.T) { func TestConcurrent(t *testing.T) { if readBool(os.Getenv("MYSQL_TEST_CONCURRENT")) != true { - t.Log("CONCURRENT env var not set. Skipping TestConcurrent") - return + t.Skip("CONCURRENT env var not set") } - runTests(t, "TestConcurrent", dsn, func(dbt *DBTest) { + + runTests(t, dsn, func(dbt *DBTest) { var max int err := dbt.db.QueryRow("SELECT @@max_connections").Scan(&max) if err != nil { - dbt.Fatalf("%v", err) + dbt.Fatalf("%s", err.Error()) } dbt.Logf("Testing up to %d concurrent connections \r\n", max) canStop := false @@ -1076,51 +1093,6 @@ func TestConcurrent(t *testing.T) { }) } -func TestFoundRows(t *testing.T) { - runTests(t, "TestFoundRows1", dsn, func(dbt *DBTest) { - dbt.mustExec("CREATE TABLE test (id INT NOT NULL ,data INT NOT NULL)") - dbt.mustExec("INSERT INTO test (id, data) VALUES (0, 0),(0, 0),(1, 0),(1, 0),(1, 1)") - - res := dbt.mustExec("UPDATE test SET data = 1 WHERE id = 0") - count, err := res.RowsAffected() - if err != nil { - dbt.Fatalf("res.RowsAffected() returned error: %v", err) - } - if count != 2 { - dbt.Fatalf("Expected 2 affected rows, got %d", count) - } - res = dbt.mustExec("UPDATE test SET data = 1 WHERE id = 1") - count, err = res.RowsAffected() - if err != nil { - dbt.Fatalf("res.RowsAffected() returned error: %v", err) - } - if count != 2 { - dbt.Fatalf("Expected 2 affected rows, got %d", count) - } - }) - runTests(t, "TestFoundRows2", dsn+"&clientFoundRows=true", func(dbt *DBTest) { - dbt.mustExec("CREATE TABLE test (id INT NOT NULL ,data INT NOT NULL)") - dbt.mustExec("INSERT INTO test (id, data) VALUES (0, 0),(0, 0),(1, 0),(1, 0),(1, 1)") - - res := dbt.mustExec("UPDATE test SET data = 1 WHERE id = 0") - count, err := res.RowsAffected() - if err != nil { - dbt.Fatalf("res.RowsAffected() returned error: %v", err) - } - if count != 2 { - dbt.Fatalf("Expected 2 matched rows, got %d", count) - } - res = dbt.mustExec("UPDATE test SET data = 1 WHERE id = 1") - count, err = res.RowsAffected() - if err != nil { - dbt.Fatalf("res.RowsAffected() returned error: %v", err) - } - if count != 3 { - dbt.Fatalf("Expected 3 matched rows, got %d", count) - } - }) -} - // BENCHMARKS var sample []byte From e19880b5d2589d20e211cd2ab288880416887ed2 Mon Sep 17 00:00:00 2001 From: Julien Schmidt Date: Sat, 1 Jun 2013 20:00:50 +0200 Subject: [PATCH 6/8] use net.Dialer --- driver.go | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/driver.go b/driver.go index da8ea571f..9ce6f6aaf 100644 --- a/driver.go +++ b/driver.go @@ -33,13 +33,8 @@ func (d *mysqlDriver) Open(dsn string) (driver.Conn, error) { } // Connect to Server - if mc.cfg.timeout > 0 { // with timeout - if err == nil { - mc.netConn, err = net.DialTimeout(mc.cfg.net, mc.cfg.addr, mc.cfg.timeout) - } - } else { // no timeout - mc.netConn, err = net.Dial(mc.cfg.net, mc.cfg.addr) - } + nd := net.Dialer{Timeout: mc.cfg.timeout} + mc.netConn, err = nd.Dial(mc.cfg.net, mc.cfg.addr) if err != nil { return nil, err } From 0571e1eb67cfed69161d5556364936aabd15d266 Mon Sep 17 00:00:00 2001 From: Julien Schmidt Date: Sun, 2 Jun 2013 07:51:42 +0200 Subject: [PATCH 7/8] cache dbname length --- packets.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packets.go b/packets.go index 34195a083..b26caad4a 100644 --- a/packets.go +++ b/packets.go @@ -231,9 +231,9 @@ func (mc *mysqlConn) writeAuthPacket() error { pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.user) + 1 + 1 + len(scrambleBuff) // To specify a db name - if len(mc.cfg.dbname) > 0 { + if n := len(mc.cfg.dbname); n > 0 { clientFlags |= clientConnectWithDB - pktLen += len(mc.cfg.dbname) + 1 + pktLen += n + 1 } // Calculate packet length and make buffer with that size From c385eff4b9cfbcb87779a0c4d219937e0a464295 Mon Sep 17 00:00:00 2001 From: Julien Schmidt Date: Mon, 10 Jun 2013 19:20:02 +0200 Subject: [PATCH 8/8] Require Go 1.1 --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index e5db9de9c..8cd6ac38c 100644 --- a/README.md +++ b/README.md @@ -40,7 +40,7 @@ A MySQL-Driver for Go's [database/sql](http://golang.org/pkg/database/sql) packa * Optional `time.Time` parsing ## Requirements - * Go 1.0.3 or higher + * Go 1.1 or higher (use [v1.0](https://github.com/go-sql-driver/mysql/tags) for Go 1.0.x) * MySQL (Version 4.1 or higher), MariaDB or Percona Server ---------------------------------------