diff --git a/benchmark_test.go b/benchmark_test.go index b246f4ac..a8fa53f6 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -479,34 +479,32 @@ func BenchmarkReceiveMetadata(b *testing.B) { ) defer db.Close() - b.Run("query", func(b *testing.B) { - db.SetMaxIdleConns(0) - db.SetMaxIdleConns(1) - - // Create a slice to scan all columns - values := make([]any, 1000) - valuePtrs := make([]any, 1000) - for j := range values { - valuePtrs[j] = &values[j] - } + db.SetMaxIdleConns(0) + db.SetMaxIdleConns(1) + + // Create a slice to scan all columns + values := make([]any, 1000) + valuePtrs := make([]any, 1000) + for j := range values { + valuePtrs[j] = &values[j] + } - b.ReportAllocs() - b.ResetTimer() + // Prepare a SELECT query to retrieve metadata + stmt := tb.checkStmt(db.Prepare("SELECT * FROM large_integer_table LIMIT 1")) + defer stmt.Close() - // Prepare a SELECT query to retrieve metadata - stmt := tb.checkStmt(db.Prepare("SELECT * FROM large_integer_table LIMIT 1")) - defer stmt.Close() + b.ReportAllocs() + b.ResetTimer() - // Benchmark metadata retrieval - for range b.N { - rows := tb.checkRows(stmt.Query()) + // Benchmark metadata retrieval + for b.Loop() { + rows := tb.checkRows(stmt.Query()) - rows.Next() - // Scan the row - err := rows.Scan(valuePtrs...) - tb.check(err) + rows.Next() + // Scan the row + err := rows.Scan(valuePtrs...) + tb.check(err) - rows.Close() - } - }) + rows.Close() + } } diff --git a/connection.go b/connection.go index 58c763fa..5648e47d 100644 --- a/connection.go +++ b/connection.go @@ -231,7 +231,7 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { if columnCount > 0 { if mc.extCapabilities&clientCacheMetadata != 0 { - if stmt.columns, err = mc.readColumns(int(columnCount)); err != nil { + if stmt.columns, err = mc.readColumns(int(columnCount), nil); err != nil { return nil, err } } else { @@ -448,7 +448,7 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) } // Columns - rows.rs.columns, err = mc.readColumns(resLen) + rows.rs.columns, err = mc.readColumns(resLen, nil) return rows, err } diff --git a/packets.go b/packets.go index 1319f9e6..b8f06126 100644 --- a/packets.go +++ b/packets.go @@ -702,8 +702,11 @@ func (mc *okHandler) handleOkPacket(data []byte) error { // Read Packets as Field Packets until EOF-Packet or an Error appears // http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnDefinition41 -func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { +func (mc *mysqlConn) readColumns(count int, old []mysqlField) ([]mysqlField, error) { columns := make([]mysqlField, count) + if len(old) != count { + old = nil + } for i := range count { data, err := mc.readPacket() @@ -731,7 +734,12 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { return nil, err } pos += n - columns[i].tableName = string(tableName) + if old != nil && old[i].tableName == string(tableName) { + // avoid allocating new string + columns[i].tableName = old[i].tableName + } else { + columns[i].tableName = string(tableName) + } } else { n, err = skipLengthEncodedString(data[pos:]) if err != nil { @@ -752,7 +760,12 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { if err != nil { return nil, err } - columns[i].name = string(name) + if old != nil && old[i].name == string(name) { + // avoid allocating new string + columns[i].name = old[i].name + } else { + columns[i].name = string(name) + } pos += n // Original name [len coded string] diff --git a/rows.go b/rows.go index e41fda6f..190e75f9 100644 --- a/rows.go +++ b/rows.go @@ -186,7 +186,7 @@ func (rows *binaryRows) NextResultSet() error { return err } - rows.rs.columns, err = rows.mc.readColumns(resLen) + rows.rs.columns, err = rows.mc.readColumns(resLen, nil) return err } @@ -208,7 +208,7 @@ func (rows *textRows) NextResultSet() (err error) { return err } - rows.rs.columns, err = rows.mc.readColumns(resLen) + rows.rs.columns, err = rows.mc.readColumns(resLen, nil) return err } diff --git a/statement.go b/statement.go index 0f6c65a3..2db8960e 100644 --- a/statement.go +++ b/statement.go @@ -74,7 +74,7 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { // Columns if metadataFollows && stmt.mc.extCapabilities&clientCacheMetadata != 0 { // we can not skip column metadata because next stmt.Query() may use it. - if stmt.columns, err = mc.readColumns(resLen); err != nil { + if stmt.columns, err = mc.readColumns(resLen, stmt.columns); err != nil { return nil, err } } else { @@ -125,7 +125,7 @@ func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) { if resLen > 0 { rows.mc = mc if metadataFollows { - if rows.rs.columns, err = mc.readColumns(resLen); err != nil { + if rows.rs.columns, err = mc.readColumns(resLen, stmt.columns); err != nil { return nil, err } stmt.columns = rows.rs.columns