
Commit 16270e4

Remove the superfluous SocketBuffer from asyncio PythonParser (#2418)

* Remove buffering from asyncio SocketBuffer and rely on the underlying StreamReader
* Skip the use of SocketBuffer in PythonParser
* Remove SocketBuffer altogether
* Code cleanup
* Fix unittest mocking when SocketBuffer is gone
1 parent 842634e commit 16270e4
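
The premise of the change: asyncio.StreamReader already buffers incoming socket data and exposes readline() and readexactly(), so an extra io.BytesIO-backed SocketBuffer only copies the same bytes a second time. A minimal stand-alone sketch of that idea (not code from this commit; the RESP bulk string below is just a made-up example):

import asyncio


async def demo() -> None:
    # StreamReader does its own buffering, so a parser can consume RESP
    # frames directly via readline()/readexactly() instead of staging
    # bytes in a separate io.BytesIO buffer first.
    reader = asyncio.StreamReader()
    reader.feed_data(b"$5\r\nhello\r\n")  # bulk string: length header, then payload
    reader.feed_eof()

    header = await reader.readline()             # b"$5\r\n"
    length = int(header[1:-2])                   # 5
    body = await reader.readexactly(length + 2)  # payload plus trailing \r\n
    assert body[:-2] == b"hello"


asyncio.run(demo())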

2 files changed: +42, -137 lines

redis/asyncio/connection.py

Lines changed: 37 additions & 133 deletions
@@ -2,7 +2,6 @@
 import copy
 import enum
 import inspect
-import io
 import os
 import socket
 import ssl
@@ -141,7 +140,7 @@ def decode(self, value: EncodableT, force=False) -> EncodableT:
 class BaseParser:
     """Plain Python parsing class"""
 
-    __slots__ = "_stream", "_buffer", "_read_size"
+    __slots__ = "_stream", "_read_size"
 
     EXCEPTION_CLASSES: ExceptionMappingT = {
         "ERR": {
@@ -171,7 +170,6 @@ class BaseParser:
 
     def __init__(self, socket_read_size: int):
         self._stream: Optional[asyncio.StreamReader] = None
-        self._buffer: Optional[SocketBuffer] = None
         self._read_size = socket_read_size
 
     def __del__(self):
@@ -206,127 +204,6 @@ async def read_response(
         raise NotImplementedError()
 
 
-class SocketBuffer:
-    """Async-friendly re-impl of redis-py's SocketBuffer.
-
-    TODO: We're currently passing through two buffers,
-    the asyncio.StreamReader and this. I imagine we can reduce the layers here
-    while maintaining compliance with prior art.
-    """
-
-    def __init__(
-        self,
-        stream_reader: asyncio.StreamReader,
-        socket_read_size: int,
-    ):
-        self._stream: Optional[asyncio.StreamReader] = stream_reader
-        self.socket_read_size = socket_read_size
-        self._buffer: Optional[io.BytesIO] = io.BytesIO()
-        # number of bytes written to the buffer from the socket
-        self.bytes_written = 0
-        # number of bytes read from the buffer
-        self.bytes_read = 0
-
-    @property
-    def length(self):
-        return self.bytes_written - self.bytes_read
-
-    async def _read_from_socket(self, length: Optional[int] = None) -> bool:
-        buf = self._buffer
-        if buf is None or self._stream is None:
-            raise RedisError("Buffer is closed.")
-        buf.seek(self.bytes_written)
-        marker = 0
-
-        while True:
-            data = await self._stream.read(self.socket_read_size)
-            # an empty string indicates the server shutdown the socket
-            if isinstance(data, bytes) and len(data) == 0:
-                raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
-            buf.write(data)
-            data_length = len(data)
-            self.bytes_written += data_length
-            marker += data_length
-
-            if length is not None and length > marker:
-                continue
-            return True
-
-    async def can_read_destructive(self) -> bool:
-        if self.length:
-            return True
-        try:
-            async with async_timeout.timeout(0):
-                return await self._read_from_socket()
-        except asyncio.TimeoutError:
-            return False
-
-    async def read(self, length: int) -> bytes:
-        length = length + 2  # make sure to read the \r\n terminator
-        # make sure we've read enough data from the socket
-        if length > self.length:
-            await self._read_from_socket(length - self.length)
-
-        if self._buffer is None:
-            raise RedisError("Buffer is closed.")
-
-        self._buffer.seek(self.bytes_read)
-        data = self._buffer.read(length)
-        self.bytes_read += len(data)
-
-        # purge the buffer when we've consumed it all so it doesn't
-        # grow forever
-        if self.bytes_read == self.bytes_written:
-            self.purge()
-
-        return data[:-2]
-
-    async def readline(self) -> bytes:
-        buf = self._buffer
-        if buf is None:
-            raise RedisError("Buffer is closed.")
-
-        buf.seek(self.bytes_read)
-        data = buf.readline()
-        while not data.endswith(SYM_CRLF):
-            # there's more data in the socket that we need
-            await self._read_from_socket()
-            buf.seek(self.bytes_read)
-            data = buf.readline()
-
-        self.bytes_read += len(data)
-
-        # purge the buffer when we've consumed it all so it doesn't
-        # grow forever
-        if self.bytes_read == self.bytes_written:
-            self.purge()
-
-        return data[:-2]
-
-    def purge(self):
-        if self._buffer is None:
-            raise RedisError("Buffer is closed.")
-
-        self._buffer.seek(0)
-        self._buffer.truncate()
-        self.bytes_written = 0
-        self.bytes_read = 0
-
-    def close(self):
-        try:
-            self.purge()
-            self._buffer.close()
-        except Exception:
-            # issue #633 suggests the purge/close somehow raised a
-            # BadFileDescriptor error. Perhaps the client ran out of
-            # memory or something else? It's probably OK to ignore
-            # any error being raised from purge/close since we're
-            # removing the reference to the instance below.
-            pass
-        self._buffer = None
-        self._stream = None
-
-
 class PythonParser(BaseParser):
     """Plain Python parsing class"""
 
@@ -342,27 +219,29 @@ def on_connect(self, connection: "Connection"):
         if self._stream is None:
             raise RedisError("Buffer is closed.")
 
-        self._buffer = SocketBuffer(self._stream, self._read_size)
         self.encoder = connection.encoder
 
     def on_disconnect(self):
         """Called when the stream disconnects"""
         if self._stream is not None:
             self._stream = None
-        if self._buffer is not None:
-            self._buffer.close()
-            self._buffer = None
         self.encoder = None
 
-    async def can_read_destructive(self):
-        return self._buffer and bool(await self._buffer.can_read_destructive())
+    async def can_read_destructive(self) -> bool:
+        if self._stream is None:
+            raise RedisError("Buffer is closed.")
+        try:
+            async with async_timeout.timeout(0):
+                return await self._stream.read(1)
+        except asyncio.TimeoutError:
+            return False
 
     async def read_response(
         self, disable_decoding: bool = False
     ) -> Union[EncodableT, ResponseError, None]:
-        if not self._buffer or not self.encoder:
+        if not self._stream or not self.encoder:
             raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
-        raw = await self._buffer.readline()
+        raw = await self._readline()
         if not raw:
             raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
         response: Any
@@ -395,7 +274,7 @@ async def read_response(
             length = int(response)
             if length == -1:
                 return None
-            response = await self._buffer.read(length)
+            response = await self._read(length)
         # multi-bulk response
         elif byte == b"*":
             length = int(response)
@@ -408,6 +287,31 @@ async def read_response(
             response = self.encoder.decode(response)
         return response
 
+    async def _read(self, length: int) -> bytes:
+        """
+        Read `length` bytes of data. These are assumed to be followed
+        by a '\r\n' terminator which is subsequently discarded.
+        """
+        if self._stream is None:
+            raise RedisError("Buffer is closed.")
+        try:
+            data = await self._stream.readexactly(length + 2)
+        except asyncio.IncompleteReadError as error:
+            raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from error
+        return data[:-2]
+
+    async def _readline(self) -> bytes:
+        """
+        read an unknown number of bytes up to the next '\r\n'
+        line separator, which is discarded.
+        """
+        if self._stream is None:
+            raise RedisError("Buffer is closed.")
+        data = await self._stream.readline()
+        if not data.endswith(b"\r\n"):
+            raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
+        return data[:-2]
+
 
 class HiredisParser(BaseParser):
     """Parser class for connections using Hiredis"""

tests/test_asyncio/test_connection.py

Lines changed: 5 additions & 4 deletions
@@ -13,22 +13,23 @@
 from redis.asyncio.retry import Retry
 from redis.backoff import NoBackoff
 from redis.exceptions import ConnectionError, InvalidResponse, TimeoutError
-from redis.utils import HIREDIS_AVAILABLE
 from tests.conftest import skip_if_server_version_lt
 
 from .compat import mock
 
 
 @pytest.mark.onlynoncluster
-@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only")
 async def test_invalid_response(create_redis):
     r = await create_redis(single_connection_client=True)
 
     raw = b"x"
-    readline_mock = mock.AsyncMock(return_value=raw)
 
     parser: "PythonParser" = r.connection._parser
-    with mock.patch.object(parser._buffer, "readline", readline_mock):
+    if not isinstance(parser, PythonParser):
+        pytest.skip("PythonParser only")
+    stream_mock = mock.Mock(parser._stream)
+    stream_mock.readline.return_value = raw + b"\r\n"
+    with mock.patch.object(parser, "_stream", stream_mock):
         with pytest.raises(InvalidResponse) as cm:
             await parser.read_response()
         assert str(cm.value) == f"Protocol Error: {raw!r}"
