Skip to content

Commit 37fd831

Browse files
committed
buffer: Improve cap consistency.
1 parent 369b5d6 commit 37fd831

File tree

2 files changed

+15
-7
lines changed

2 files changed

+15
-7
lines changed

buffer.go

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,8 @@ type buffer struct {
3030
}
3131

3232
func newBuffer(nc net.Conn) buffer {
33-
var b [defaultBufSize]byte
3433
return buffer{
35-
buf: b[:],
34+
buf: make([]byte, defaultBufSize, defaultBufSize),
3635
nc: nc,
3736
}
3837
}
@@ -51,7 +50,8 @@ func (b *buffer) fill(need int) error {
5150
// Maybe keep the org buf slice and swap back?
5251
if need > len(b.buf) {
5352
// Round up to the next multiple of the default size
54-
newBuf := make([]byte, ((need/defaultBufSize)+1)*defaultBufSize)
53+
newSize := ((need / defaultBufSize) + 1) * defaultBufSize
54+
newBuf := make([]byte, newSize, newSize)
5555
copy(newBuf, b.buf)
5656
b.buf = newBuf
5757
}
@@ -114,13 +114,12 @@ func (b *buffer) takeBuffer(length int) []byte {
114114
return nil
115115
}
116116

117-
// test (cheap) general case first
118-
if length <= defaultBufSize || length <= cap(b.buf) {
117+
if length <= len(b.buf) {
119118
return b.buf[:length]
120119
}
121120

122121
if length < maxPacketSize {
123-
b.buf = make([]byte, length)
122+
b.buf = make([]byte, length, length)
124123
return b.buf
125124
}
126125
return make([]byte, length)
@@ -145,3 +144,12 @@ func (b *buffer) takeCompleteBuffer() []byte {
145144
}
146145
return b.buf
147146
}
147+
148+
// setGrownBuffer set buf as internal buffer if cap(buf) is larger
149+
// than len(b.buf). It can be used when you took buffer by
150+
// takeCompleteBuffer and append some data to it.
151+
func (b *buffer) setGrownBuffer(buf []byte) {
152+
if cap(buf) >= len(b.buf) {
153+
b.buf = buf[:cap(buf)]
154+
}
155+
}

packets.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1076,7 +1076,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
10761076
// In that case we must build the data packet with the new values buffer
10771077
if valuesCap != cap(paramValues) {
10781078
data = append(data[:pos], paramValues...)
1079-
mc.buf.buf = data
1079+
mc.buf.setGrownBuffer(data)
10801080
}
10811081

10821082
pos += len(paramValues)

0 commit comments

Comments
 (0)