Skip to content

Commit eb95823

Browse files
committed
Simplify the sync SocketBuffer, add type hints
1 parent a9ef0c5 commit eb95823

File tree

2 files changed

+40
-36
lines changed

2 files changed

+40
-36
lines changed

CHANGES

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
* Simplify synchronous SocketBuffer state management
12
* Make PythonParser resumable in case of error (#2510)
23
* Add `timeout=None` in `SentinelConnectionManager.read_response`
34
* Documentation fix: password protected socket connection (#2374)

redis/connection.py

Lines changed: 39 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@
55
import socket
66
import threading
77
import weakref
8+
from io import SEEK_END
89
from itertools import chain
910
from queue import Empty, Full, LifoQueue
1011
from time import time
11-
from typing import Optional
12+
from typing import Optional, Union
1213
from urllib.parse import parse_qs, unquote, urlparse
1314

1415
from redis.backoff import NoBackoff
@@ -163,39 +164,47 @@ def parse_error(self, response):
163164

164165

165166
class SocketBuffer:
166-
def __init__(self, socket, socket_read_size, socket_timeout):
167+
def __init__(
168+
self, socket: socket.socket, socket_read_size: int, socket_timeout: float
169+
):
167170
self._sock = socket
168171
self.socket_read_size = socket_read_size
169172
self.socket_timeout = socket_timeout
170173
self._buffer = io.BytesIO()
171-
# number of bytes written to the buffer from the socket
172-
self.bytes_written = 0
173-
# number of bytes read from the buffer
174-
self.bytes_read = 0
175174

176-
@property
177-
def length(self):
178-
return self.bytes_written - self.bytes_read
175+
def unread_bytes(self) -> int:
176+
"""
177+
Remaining unread length of buffer
178+
"""
179+
pos = self._buffer.tell()
180+
end = self._buffer.seek(0, SEEK_END)
181+
self._buffer.seek(pos)
182+
return end - pos
179183

180-
def _read_from_socket(self, length=None, timeout=SENTINEL, raise_on_timeout=True):
184+
def _read_from_socket(
185+
self,
186+
length: Optional[int] = None,
187+
timeout: Union[float, object] = SENTINEL,
188+
raise_on_timeout: Optional[bool] = True,
189+
) -> bool:
181190
sock = self._sock
182191
socket_read_size = self.socket_read_size
183-
buf = self._buffer
184-
buf.seek(self.bytes_written)
185192
marker = 0
186193
custom_timeout = timeout is not SENTINEL
187194

195+
buf = self._buffer
196+
current_pos = buf.tell()
197+
buf.seek(0, SEEK_END)
198+
if custom_timeout:
199+
sock.settimeout(timeout)
188200
try:
189-
if custom_timeout:
190-
sock.settimeout(timeout)
191201
while True:
192202
data = self._sock.recv(socket_read_size)
193203
# an empty string indicates the server shutdown the socket
194204
if isinstance(data, bytes) and len(data) == 0:
195205
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
196206
buf.write(data)
197207
data_length = len(data)
198-
self.bytes_written += data_length
199208
marker += data_length
200209

201210
if length is not None and length > marker:
@@ -215,55 +224,51 @@ def _read_from_socket(self, length=None, timeout=SENTINEL, raise_on_timeout=True
215224
return False
216225
raise ConnectionError(f"Error while reading from socket: {ex.args}")
217226
finally:
227+
buf.seek(current_pos)
218228
if custom_timeout:
219229
sock.settimeout(self.socket_timeout)
220230

221-
def can_read(self, timeout):
222-
return bool(self.length) or self._read_from_socket(
231+
def can_read(self, timeout: float) -> bool:
232+
return bool(self.unread_bytes()) or self._read_from_socket(
223233
timeout=timeout, raise_on_timeout=False
224234
)
225235

226-
def read(self, length):
236+
def read(self, length: int) -> bytes:
227237
length = length + 2 # make sure to read the \r\n terminator
228238
# make sure we've read enough data from the socket
229-
if length > self.length:
230-
self._read_from_socket(length - self.length)
239+
if length > self.unread_bytes:
240+
self._read_from_socket(length - self.unread_bytes)
231241

232-
self._buffer.seek(self.bytes_read)
233242
data = self._buffer.read(length)
234-
self.bytes_read += len(data)
235243
return data[:-2]
236244

237-
def readline(self):
245+
def readline(self) -> bytes:
238246
buf = self._buffer
239-
buf.seek(self.bytes_read)
240247
data = buf.readline()
241248
while not data.endswith(SYM_CRLF):
242249
# there's more data in the socket that we need
243250
self._read_from_socket()
244-
buf.seek(self.bytes_read)
245251
data = buf.readline()
246252

247-
self.bytes_read += len(data)
248253
return data[:-2]
249254

250-
def get_pos(self):
255+
def get_pos(self) -> int:
251256
"""
252257
Get current read position
253258
"""
254-
return self.bytes_read
259+
return self._buffer.tell()
255260

256-
def rewind(self, pos):
261+
def rewind(self, pos: int) -> None:
257262
"""
258263
Rewind the buffer to a specific position, to re-start reading
259264
"""
260-
self.bytes_read = pos
265+
self._buffer.seek(pos)
261266

262-
def purge(self):
267+
def purge(self) -> None:
263268
"""
264269
After a successful read, purge the read part of buffer
265270
"""
266-
unread = self.bytes_written - self.bytes_read
271+
unread = self.unread_bytes()
267272

268273
# Only if we have read all of the buffer do we truncate, to
269274
# reduce the amount of memory thrashing. This heuristic
@@ -276,13 +281,10 @@ def purge(self):
276281
view = self._buffer.getbuffer()
277282
view[:unread] = view[-unread:]
278283
self._buffer.truncate(unread)
279-
self.bytes_written = unread
280-
self.bytes_read = 0
281284
self._buffer.seek(0)
282285

283-
def close(self):
286+
def close(self) -> None:
284287
try:
285-
self.bytes_written = self.bytes_read = 0
286288
self._buffer.close()
287289
except Exception:
288290
# issue #633 suggests the purge/close somehow raised a
@@ -498,6 +500,7 @@ def read_response(self, disable_decoding=False):
498500
return response
499501

500502

503+
DefaultParser: BaseParser
501504
if HIREDIS_AVAILABLE:
502505
DefaultParser = HiredisParser
503506
else:

0 commit comments

Comments
 (0)