Skip to content

packets: Allow terminating packets of length 0 #516

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 16, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 26 additions & 20 deletions packets.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,26 +25,20 @@ import (

// Read packet to buffer 'data'
func (mc *mysqlConn) readPacket() ([]byte, error) {
var payload []byte
var prevData []byte
for {
// Read packet header
// read packet header
data, err := mc.buf.readNext(4)
if err != nil {
errLog.Print(err)
mc.Close()
return nil, driver.ErrBadConn
}

// Packet Length [24 bit]
// packet length [24 bit]
pktLen := int(uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16)

if pktLen < 1 {
errLog.Print(ErrMalformPkt)
mc.Close()
return nil, driver.ErrBadConn
}

// Check Packet Sync [8 bit]
// check packet sync [8 bit]
if data[3] != mc.sequence {
if data[3] > mc.sequence {
return nil, ErrPktSyncMul
Expand All @@ -53,26 +47,38 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
}
mc.sequence++

// Read packet body [pktLen bytes]
// packets with length 0 terminate a previous packet which is a
// multiple of (2^24)−1 bytes long
if pktLen == 0 {
// there was no previous packet
if prevData == nil {
errLog.Print(ErrMalformPkt)
mc.Close()
return nil, driver.ErrBadConn
}

return prevData, nil
}

// read packet body [pktLen bytes]
data, err = mc.buf.readNext(pktLen)
if err != nil {
errLog.Print(err)
mc.Close()
return nil, driver.ErrBadConn
}

isLastPacket := (pktLen < maxPacketSize)
// return data if this was the last packet
if pktLen < maxPacketSize {
// zero allocations for non-split packets
if prevData == nil {
return data, nil
}

// Zero allocations for non-splitting packets
if isLastPacket && payload == nil {
return data, nil
return append(prevData, data...), nil
}

payload = append(payload, data...)

if isLastPacket {
return payload, nil
}
prevData = append(prevData, data...)
}
}

Expand Down
282 changes: 282 additions & 0 deletions packets_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,282 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2016 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.

package mysql

import (
"database/sql/driver"
"errors"
"net"
"testing"
"time"
)

var (
errConnClosed = errors.New("connection is closed")
errConnTooManyReads = errors.New("too many reads")
errConnTooManyWrites = errors.New("too many writes")
)

// struct to mock a net.Conn for testing purposes
type mockConn struct {
laddr net.Addr
raddr net.Addr
data []byte
closed bool
read int
written int
reads int
writes int
maxReads int
maxWrites int
}

func (m *mockConn) Read(b []byte) (n int, err error) {
if m.closed {
return 0, errConnClosed
}

m.reads++
if m.maxReads > 0 && m.reads > m.maxReads {
return 0, errConnTooManyReads
}

n = copy(b, m.data)
m.read += n
m.data = m.data[n:]
return
}
func (m *mockConn) Write(b []byte) (n int, err error) {
if m.closed {
return 0, errConnClosed
}

m.writes++
if m.maxWrites > 0 && m.writes > m.maxWrites {
return 0, errConnTooManyWrites
}

n = len(b)
m.written += n
return
}
func (m *mockConn) Close() error {
m.closed = true
return nil
}
func (m *mockConn) LocalAddr() net.Addr {
return m.laddr
}
func (m *mockConn) RemoteAddr() net.Addr {
return m.raddr
}
func (m *mockConn) SetDeadline(t time.Time) error {
return nil
}
func (m *mockConn) SetReadDeadline(t time.Time) error {
return nil
}
func (m *mockConn) SetWriteDeadline(t time.Time) error {
return nil
}

// make sure mockConn implements the net.Conn interface
var _ net.Conn = new(mockConn)

func TestReadPacketSingleByte(t *testing.T) {
conn := new(mockConn)
mc := &mysqlConn{
buf: newBuffer(conn),
}

conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff}
conn.maxReads = 1
packet, err := mc.readPacket()
if err != nil {
t.Fatal(err)
}
if len(packet) != 1 {
t.Fatalf("unexpected packet lenght: expected %d, got %d", 1, len(packet))
}
if packet[0] != 0xff {
t.Fatalf("unexpected packet content: expected %x, got %x", 0xff, packet[0])
}
}

func TestReadPacketWrongSequenceID(t *testing.T) {
conn := new(mockConn)
mc := &mysqlConn{
buf: newBuffer(conn),
}

// too low sequence id
conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff}
conn.maxReads = 1
mc.sequence = 1
_, err := mc.readPacket()
if err != ErrPktSync {
t.Errorf("expected ErrPktSync, got %v", err)
}

// reset
conn.reads = 0
mc.sequence = 0
mc.buf = newBuffer(conn)

// too high sequence id
conn.data = []byte{0x01, 0x00, 0x00, 0x42, 0xff}
_, err = mc.readPacket()
if err != ErrPktSyncMul {
t.Errorf("expected ErrPktSyncMul, got %v", err)
}
}

func TestReadPacketSplit(t *testing.T) {
conn := new(mockConn)
mc := &mysqlConn{
buf: newBuffer(conn),
}

data := make([]byte, maxPacketSize*2+4*3)
const pkt2ofs = maxPacketSize + 4
const pkt3ofs = 2 * (maxPacketSize + 4)

// case 1: payload has length maxPacketSize
data = data[:pkt2ofs+4]

// 1st packet has maxPacketSize length and sequence id 0
// ff ff ff 00 ...
data[0] = 0xff
data[1] = 0xff
data[2] = 0xff

// mark the payload start and end of 1st packet so that we can check if the
// content was correctly appended
data[4] = 0x11
data[maxPacketSize+3] = 0x22

// 2nd packet has payload length 0 and squence id 1
// 00 00 00 01
data[pkt2ofs+3] = 0x01

conn.data = data
conn.maxReads = 3
packet, err := mc.readPacket()
if err != nil {
t.Fatal(err)
}
if len(packet) != maxPacketSize {
t.Fatalf("unexpected packet lenght: expected %d, got %d", maxPacketSize, len(packet))
}
if packet[0] != 0x11 {
t.Fatalf("unexpected payload start: expected %x, got %x", 0x11, packet[0])
}
if packet[maxPacketSize-1] != 0x22 {
t.Fatalf("unexpected payload end: expected %x, got %x", 0x22, packet[maxPacketSize-1])
}

// case 2: payload has length which is a multiple of maxPacketSize
data = data[:cap(data)]

// 2nd packet now has maxPacketSize length
data[pkt2ofs] = 0xff
data[pkt2ofs+1] = 0xff
data[pkt2ofs+2] = 0xff

// mark the payload start and end of the 2nd packet
data[pkt2ofs+4] = 0x33
data[pkt2ofs+maxPacketSize+3] = 0x44

// 3rd packet has payload length 0 and squence id 2
// 00 00 00 02
data[pkt3ofs+3] = 0x02

conn.data = data
conn.reads = 0
conn.maxReads = 5
mc.sequence = 0
packet, err = mc.readPacket()
if err != nil {
t.Fatal(err)
}
if len(packet) != 2*maxPacketSize {
t.Fatalf("unexpected packet lenght: expected %d, got %d", 2*maxPacketSize, len(packet))
}
if packet[0] != 0x11 {
t.Fatalf("unexpected payload start: expected %x, got %x", 0x11, packet[0])
}
if packet[2*maxPacketSize-1] != 0x44 {
t.Fatalf("unexpected payload end: expected %x, got %x", 0x44, packet[2*maxPacketSize-1])
}

// case 3: payload has a length larger maxPacketSize, which is not an exact
// multiple of it
data = data[:pkt2ofs+4+42]
data[pkt2ofs] = 0x2a
data[pkt2ofs+1] = 0x00
data[pkt2ofs+2] = 0x00
data[pkt2ofs+4+41] = 0x44

conn.data = data
conn.reads = 0
conn.maxReads = 4
mc.sequence = 0
packet, err = mc.readPacket()
if err != nil {
t.Fatal(err)
}
if len(packet) != maxPacketSize+42 {
t.Fatalf("unexpected packet lenght: expected %d, got %d", maxPacketSize+42, len(packet))
}
if packet[0] != 0x11 {
t.Fatalf("unexpected payload start: expected %x, got %x", 0x11, packet[0])
}
if packet[maxPacketSize+41] != 0x44 {
t.Fatalf("unexpected payload end: expected %x, got %x", 0x44, packet[maxPacketSize+41])
}
}

func TestReadPacketFail(t *testing.T) {
conn := new(mockConn)
mc := &mysqlConn{
buf: newBuffer(conn),
}

// illegal empty (stand-alone) packet
conn.data = []byte{0x00, 0x00, 0x00, 0x00}
conn.maxReads = 1
_, err := mc.readPacket()
if err != driver.ErrBadConn {
t.Errorf("expected ErrBadConn, got %v", err)
}

// reset
conn.reads = 0
mc.sequence = 0
mc.buf = newBuffer(conn)

// fail to read header
conn.closed = true
_, err = mc.readPacket()
if err != driver.ErrBadConn {
t.Errorf("expected ErrBadConn, got %v", err)
}

// reset
conn.closed = false
conn.reads = 0
mc.sequence = 0
mc.buf = newBuffer(conn)

// fail to read body
conn.maxReads = 1
_, err = mc.readPacket()
if err != driver.ErrBadConn {
t.Errorf("expected ErrBadConn, got %v", err)
}
}