Skip to content

Commit e39c7ba

Browse files
Simplify the sync SocketBuffer, add type hints (#2543)
1 parent 5e258a1 commit e39c7ba

File tree

2 files changed

+45
-39
lines changed

2 files changed

+45
-39
lines changed

CHANGES

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
* Simplify synchronous SocketBuffer state management
12
* Fix string cleanse in Redis Graph
23
* Make PythonParser resumable in case of error (#2510)
34
* Add `timeout=None` in `SentinelConnectionManager.read_response`

redis/connection.py

Lines changed: 44 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@
55
import socket
66
import threading
77
import weakref
8+
from io import SEEK_END
89
from itertools import chain
910
from queue import Empty, Full, LifoQueue
1011
from time import time
11-
from typing import Optional
12+
from typing import Optional, Union
1213
from urllib.parse import parse_qs, unquote, urlparse
1314

1415
from redis.backoff import NoBackoff
@@ -163,39 +164,47 @@ def parse_error(self, response):
163164

164165

165166
class SocketBuffer:
166-
def __init__(self, socket, socket_read_size, socket_timeout):
167+
def __init__(
168+
self, socket: socket.socket, socket_read_size: int, socket_timeout: float
169+
):
167170
self._sock = socket
168171
self.socket_read_size = socket_read_size
169172
self.socket_timeout = socket_timeout
170173
self._buffer = io.BytesIO()
171-
# number of bytes written to the buffer from the socket
172-
self.bytes_written = 0
173-
# number of bytes read from the buffer
174-
self.bytes_read = 0
175174

176-
@property
177-
def length(self):
178-
return self.bytes_written - self.bytes_read
175+
def unread_bytes(self) -> int:
176+
"""
177+
Remaining unread length of buffer
178+
"""
179+
pos = self._buffer.tell()
180+
end = self._buffer.seek(0, SEEK_END)
181+
self._buffer.seek(pos)
182+
return end - pos
179183

180-
def _read_from_socket(self, length=None, timeout=SENTINEL, raise_on_timeout=True):
184+
def _read_from_socket(
185+
self,
186+
length: Optional[int] = None,
187+
timeout: Union[float, object] = SENTINEL,
188+
raise_on_timeout: Optional[bool] = True,
189+
) -> bool:
181190
sock = self._sock
182191
socket_read_size = self.socket_read_size
183-
buf = self._buffer
184-
buf.seek(self.bytes_written)
185192
marker = 0
186193
custom_timeout = timeout is not SENTINEL
187194

195+
buf = self._buffer
196+
current_pos = buf.tell()
197+
buf.seek(0, SEEK_END)
198+
if custom_timeout:
199+
sock.settimeout(timeout)
188200
try:
189-
if custom_timeout:
190-
sock.settimeout(timeout)
191201
while True:
192202
data = self._sock.recv(socket_read_size)
193203
# an empty string indicates the server shutdown the socket
194204
if isinstance(data, bytes) and len(data) == 0:
195205
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
196206
buf.write(data)
197207
data_length = len(data)
198-
self.bytes_written += data_length
199208
marker += data_length
200209

201210
if length is not None and length > marker:
@@ -215,55 +224,53 @@ def _read_from_socket(self, length=None, timeout=SENTINEL, raise_on_timeout=True
215224
return False
216225
raise ConnectionError(f"Error while reading from socket: {ex.args}")
217226
finally:
227+
buf.seek(current_pos)
218228
if custom_timeout:
219229
sock.settimeout(self.socket_timeout)
220230

221-
def can_read(self, timeout):
222-
return bool(self.length) or self._read_from_socket(
231+
def can_read(self, timeout: float) -> bool:
232+
return bool(self.unread_bytes()) or self._read_from_socket(
223233
timeout=timeout, raise_on_timeout=False
224234
)
225235

226-
def read(self, length):
236+
def read(self, length: int) -> bytes:
227237
length = length + 2 # make sure to read the \r\n terminator
228-
# make sure we've read enough data from the socket
229-
if length > self.length:
230-
self._read_from_socket(length - self.length)
231-
232-
self._buffer.seek(self.bytes_read)
238+
# BufferIO will return less than requested if buffer is short
233239
data = self._buffer.read(length)
234-
self.bytes_read += len(data)
240+
missing = length - len(data)
241+
if missing:
242+
# fill up the buffer and read the remainder
243+
self._read_from_socket(missing)
244+
data += self._buffer.read(missing)
235245
return data[:-2]
236246

237-
def readline(self):
247+
def readline(self) -> bytes:
238248
buf = self._buffer
239-
buf.seek(self.bytes_read)
240249
data = buf.readline()
241250
while not data.endswith(SYM_CRLF):
242251
# there's more data in the socket that we need
243252
self._read_from_socket()
244-
buf.seek(self.bytes_read)
245-
data = buf.readline()
253+
data += buf.readline()
246254

247-
self.bytes_read += len(data)
248255
return data[:-2]
249256

250-
def get_pos(self):
257+
def get_pos(self) -> int:
251258
"""
252259
Get current read position
253260
"""
254-
return self.bytes_read
261+
return self._buffer.tell()
255262

256-
def rewind(self, pos):
263+
def rewind(self, pos: int) -> None:
257264
"""
258265
Rewind the buffer to a specific position, to re-start reading
259266
"""
260-
self.bytes_read = pos
267+
self._buffer.seek(pos)
261268

262-
def purge(self):
269+
def purge(self) -> None:
263270
"""
264271
After a successful read, purge the read part of buffer
265272
"""
266-
unread = self.bytes_written - self.bytes_read
273+
unread = self.unread_bytes()
267274

268275
# Only if we have read all of the buffer do we truncate, to
269276
# reduce the amount of memory thrashing. This heuristic
@@ -276,13 +283,10 @@ def purge(self):
276283
view = self._buffer.getbuffer()
277284
view[:unread] = view[-unread:]
278285
self._buffer.truncate(unread)
279-
self.bytes_written = unread
280-
self.bytes_read = 0
281286
self._buffer.seek(0)
282287

283-
def close(self):
288+
def close(self) -> None:
284289
try:
285-
self.bytes_written = self.bytes_read = 0
286290
self._buffer.close()
287291
except Exception:
288292
# issue #633 suggests the purge/close somehow raised a
@@ -498,6 +502,7 @@ def read_response(self, disable_decoding=False):
498502
return response
499503

500504

505+
DefaultParser: BaseParser
501506
if HIREDIS_AVAILABLE:
502507
DefaultParser = HiredisParser
503508
else:

0 commit comments

Comments
 (0)