From debd47ce701a639dacfbb69c6737b2f63120751c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Tue, 13 Dec 2022 17:45:43 +0000 Subject: [PATCH 01/12] PythonParser is now resumable if _stream IO is interrupted --- redis/asyncio/connection.py | 66 ++++++++++++++++++++++++++++--------- 1 file changed, 50 insertions(+), 16 deletions(-) diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 4f19153318..780482c58d 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -208,11 +208,14 @@ async def read_response( class PythonParser(BaseParser): """Plain Python parsing class""" - __slots__ = BaseParser.__slots__ + ("encoder",) + __slots__ = BaseParser.__slots__ + ("encoder", "_buffer", "_pos", "_chunks") def __init__(self, socket_read_size: int): super().__init__(socket_read_size) self.encoder: Optional[Encoder] = None + self._buffer = b"" + self._chunks = [] + self._pos = 0 def on_connect(self, connection: "Connection"): """Called when the stream connects""" @@ -229,6 +232,8 @@ def on_disconnect(self): self.encoder = None async def can_read_destructive(self) -> bool: + if self._buffer: + return True if self._stream is None: raise RedisError("Buffer is closed.") try: @@ -237,7 +242,19 @@ async def can_read_destructive(self) -> bool: except asyncio.TimeoutError: return False - async def read_response( + async def read_response(self, disable_decoding: bool = False): + if self._chunks: + # augment parsing buffer with previously read data + self._buffer += b"".join(self._chunks) + self._chunks.clear() + self._pos = 0 + response = await self._read_response(disable_decoding=disable_decoding) + # Successfully parsing a response allows us to clear our parsing buffer + self._buffer = b"" + self._chunks.clear() + return response + + async def _read_response( self, disable_decoding: bool = False ) -> Union[EncodableT, ResponseError, None]: if not self._stream or not self.encoder: @@ -282,7 +299,7 @@ async def read_response( if length == -1: return None response = [ - (await self.read_response(disable_decoding)) for _ in range(length) + (await self._read_response(disable_decoding)) for _ in range(length) ] if isinstance(response, bytes) and disable_decoding is False: response = self.encoder.decode(response) @@ -293,25 +310,42 @@ async def _read(self, length: int) -> bytes: Read `length` bytes of data. These are assumed to be followed by a '\r\n' terminator which is subsequently discarded. """ - if self._stream is None: - raise RedisError("Buffer is closed.") - try: - data = await self._stream.readexactly(length + 2) - except asyncio.IncompleteReadError as error: - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from error - return data[:-2] + want = length + 2 + end = self._pos + want + if len(self._buffer) >= end: + result = self._buffer[self._pos : end - 2] + else: + if self._stream is None: + raise RedisError("Buffer is closed.") + tail = self._buffer[self._pos :] + try: + data = await self._stream.readexactly(want - len(tail)) + except asyncio.IncompleteReadError as error: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from error + result = (tail + data)[:-2] + self._chunks.append(data) + self._pos += want + return result async def _readline(self) -> bytes: """ read an unknown number of bytes up to the next '\r\n' line separator, which is discarded. 
""" - if self._stream is None: - raise RedisError("Buffer is closed.") - data = await self._stream.readline() - if not data.endswith(b"\r\n"): - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) - return data[:-2] + found = self._buffer.find(b"\r\n", self._pos) + if found >= 0: + result = self._buffer[self._pos : found] + else: + if self._stream is None: + raise RedisError("Buffer is closed.") + tail = self._buffer[self._pos :] + data = await self._stream.readline() + if not data.endswith(b"\r\n"): + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + result = (tail + data)[:-2] + self._chunks.append(data) + self._pos += len(result) + 2 + return result class HiredisParser(BaseParser): From e60b6328709c35059a8955fba94d761f39bffc8a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Tue, 13 Dec 2022 22:11:09 +0000 Subject: [PATCH 02/12] Add test for parse resumability --- tests/test_asyncio/test_connection.py | 75 +++++++++++++++++++++++++++ 1 file changed, 75 insertions(+) diff --git a/tests/test_asyncio/test_connection.py b/tests/test_asyncio/test_connection.py index 6bf0034146..ec725799d9 100644 --- a/tests/test_asyncio/test_connection.py +++ b/tests/test_asyncio/test_connection.py @@ -5,6 +5,7 @@ import pytest +import redis from redis.asyncio.connection import ( Connection, PythonParser, @@ -112,3 +113,77 @@ async def test_connect_timeout_error_without_retry(): await conn.connect() assert conn._connect.call_count == 1 assert str(e.value) == "Timeout connecting to server" + + +class TestError(BaseException): + pass + + +class InterruptingReader: + """ + A class simulating an asyncio input buffer, but raising a + special exception every other read. + """ + + def __init__(self, data): + self.data = data + self.counter = 0 + self.pos = 0 + + def tick(self): + self.counter += 1 + # return + if (self.counter % 2) == 0: + raise TestError() + + async def read(self, want): + self.tick() + want = 5 + result = self.data[self.pos : self.pos + want] + self.pos += len(result) + return result + + async def readline(self): + self.tick() + find = self.data.find(b"\n", self.pos) + if find >= 0: + result = self.data[self.pos : find + 1] + else: + result = self.data[self.pos :] + self.pos += len(result) + return result + + async def readexactly(self, length): + self.tick() + result = self.data[self.pos : self.pos + length] + if len(result) < length: + raise asyncio.IncompleteReadError(result, None) + self.pos += len(result) + return result + + +async def test_connection_parse_response_resume(r: redis.Redis): + """ + This test verifies that the Connection parser, + be that PythonParser or HiredisParser, + can be interrupted at IO time and then resume parsing. + """ + conn = Connection(**r.connection_pool.connection_kwargs) + await conn.connect() + message = ( + b"*3\r\n$7\r\nmessage\r\n$8\r\nchannel1\r\n" + b"$25\r\nhi\r\nthere\r\n+how\r\nare\r\nyou\r\n" + ) + + conn._parser._stream = InterruptingReader(message) + for i in range(100): + try: + response = await conn.read_response() + break + except TestError: + pass + + else: + pytest.fail("didn't receive a response") + assert response + assert i > 0 From b9a97207e71fa6026081e3e0353686a3112334c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Tue, 13 Dec 2022 23:31:37 +0000 Subject: [PATCH 03/12] Clear PythonParser state when connection or parsing errors occur. 
--- redis/asyncio/connection.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 780482c58d..ed8208c5d5 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -217,6 +217,10 @@ def __init__(self, socket_read_size: int): self._chunks = [] self._pos = 0 + def _clear(self): + self._buffer = b"" + self._chunks.clear() + def on_connect(self, connection: "Connection"): """Called when the stream connects""" self._stream = connection._reader @@ -247,12 +251,17 @@ async def read_response(self, disable_decoding: bool = False): # augment parsing buffer with previously read data self._buffer += b"".join(self._chunks) self._chunks.clear() - self._pos = 0 - response = await self._read_response(disable_decoding=disable_decoding) - # Successfully parsing a response allows us to clear our parsing buffer - self._buffer = b"" - self._chunks.clear() - return response + try: + self._pos = 0 + response = await self._read_response(disable_decoding=disable_decoding) + except (ConnectionError, InvalidResponse): + # We don't want these errors to be resumable + self._clear() + raise + else: + # Successfully parsing a response allows us to clear our parsing buffer + self._clear() + return response async def _read_response( self, disable_decoding: bool = False @@ -275,6 +284,7 @@ async def _read_response( # if the error is a ConnectionError, raise immediately so the user # is notified if isinstance(error, ConnectionError): + self._clear() # Successful parse raise error # otherwise, we're dealing with a ResponseError that might belong # inside a pipeline response. the connection's read_response() From d3c0cbe70ecd08180a900cfa1ad2d7d9003f0b2e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Tue, 13 Dec 2022 23:39:59 +0000 Subject: [PATCH 04/12] disable test for cluster mode. --- tests/test_asyncio/test_connection.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_asyncio/test_connection.py b/tests/test_asyncio/test_connection.py index ec725799d9..8222cfb21e 100644 --- a/tests/test_asyncio/test_connection.py +++ b/tests/test_asyncio/test_connection.py @@ -162,6 +162,7 @@ async def readexactly(self, length): return result +@pytest.mark.onlynoncluster async def test_connection_parse_response_resume(r: redis.Redis): """ This test verifies that the Connection parser, From 9c207ade05840bb9b51623ad8d765c7790c92089 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Wed, 14 Dec 2022 00:07:09 +0000 Subject: [PATCH 05/12] Perform "closed" check in a single place. 
--- redis/asyncio/connection.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index ed8208c5d5..aeb4f210ad 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -247,6 +247,8 @@ async def can_read_destructive(self) -> bool: return False async def read_response(self, disable_decoding: bool = False): + if self._stream is None: + raise RedisError("Buffer is closed.") if self._chunks: # augment parsing buffer with previously read data self._buffer += b"".join(self._chunks) @@ -325,8 +327,6 @@ async def _read(self, length: int) -> bytes: if len(self._buffer) >= end: result = self._buffer[self._pos : end - 2] else: - if self._stream is None: - raise RedisError("Buffer is closed.") tail = self._buffer[self._pos :] try: data = await self._stream.readexactly(want - len(tail)) @@ -346,8 +346,6 @@ async def _readline(self) -> bytes: if found >= 0: result = self._buffer[self._pos : found] else: - if self._stream is None: - raise RedisError("Buffer is closed.") tail = self._buffer[self._pos :] data = await self._stream.readline() if not data.endswith(b"\r\n"): From 74440f142f5541dacfb3af535b7c5642cb4efffa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Wed, 14 Dec 2022 10:37:33 +0000 Subject: [PATCH 06/12] Update tests --- tests/test_asyncio/test_connection.py | 41 +++++++++++++++------------ 1 file changed, 23 insertions(+), 18 deletions(-) diff --git a/tests/test_asyncio/test_connection.py b/tests/test_asyncio/test_connection.py index 8222cfb21e..fdbaf7fdbf 100644 --- a/tests/test_asyncio/test_connection.py +++ b/tests/test_asyncio/test_connection.py @@ -7,6 +7,7 @@ import redis from redis.asyncio.connection import ( + BaseParser, Connection, PythonParser, UnixDomainSocketConnection, @@ -24,16 +25,19 @@ async def test_invalid_response(create_redis): r = await create_redis(single_connection_client=True) raw = b"x" + fake_stream = FakeStream(raw + b"\r\n") - parser: "PythonParser" = r.connection._parser - if not isinstance(parser, PythonParser): - pytest.skip("PythonParser only") - stream_mock = mock.Mock(parser._stream) - stream_mock.readline.return_value = raw + b"\r\n" - with mock.patch.object(parser, "_stream", stream_mock): + parser: BaseParser = r.connection._parser + with mock.patch.object(parser, "_stream", fake_stream): with pytest.raises(InvalidResponse) as cm: await parser.read_response() - assert str(cm.value) == f"Protocol Error: {raw!r}" + if isinstance(parser, PythonParser): + assert str(cm.value) == f"Protocol Error: {raw!r}" + else: + assert ( + str(cm.value) == f'Protocol error, got "{raw.decode()}" as reply type byte' + ) + await r.connection.disconnect() @skip_if_server_version_lt("4.0.0") @@ -115,26 +119,27 @@ async def test_connect_timeout_error_without_retry(): assert str(e.value) == "Timeout connecting to server" -class TestError(BaseException): - pass - - -class InterruptingReader: +class FakeStream: """ A class simulating an asyncio input buffer, but raising a special exception every other read. 
""" - def __init__(self, data): + class TestError(BaseException): + pass + + def __init__(self, data, interrupt_every=0): self.data = data self.counter = 0 self.pos = 0 + self.interrupt_every = interrupt_every def tick(self): self.counter += 1 - # return - if (self.counter % 2) == 0: - raise TestError() + if not self.interrupt_every: + return + if (self.counter % self.interrupt_every) == 0: + raise self.TestError() async def read(self, want): self.tick() @@ -176,12 +181,12 @@ async def test_connection_parse_response_resume(r: redis.Redis): b"$25\r\nhi\r\nthere\r\n+how\r\nare\r\nyou\r\n" ) - conn._parser._stream = InterruptingReader(message) + conn._parser._stream = FakeStream(message, interrupt_every=2) for i in range(100): try: response = await conn.read_response() break - except TestError: + except FakeStream.TestError: pass else: From 9b4af7c63a3537317922cb88a28f298cf46dc158 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Wed, 14 Dec 2022 10:37:58 +0000 Subject: [PATCH 07/12] PythonParser uses a separate method for filling the IO buffer --- redis/asyncio/connection.py | 99 ++++++++++++++++--------------------- 1 file changed, 43 insertions(+), 56 deletions(-) diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index aeb4f210ad..6655ec006b 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -208,19 +208,14 @@ async def read_response( class PythonParser(BaseParser): """Plain Python parsing class""" - __slots__ = BaseParser.__slots__ + ("encoder", "_buffer", "_pos", "_chunks") + __slots__ = BaseParser.__slots__ + ("encoder", "_buffer", "_pos") def __init__(self, socket_read_size: int): super().__init__(socket_read_size) self.encoder: Optional[Encoder] = None self._buffer = b"" - self._chunks = [] self._pos = 0 - def _clear(self): - self._buffer = b"" - self._chunks.clear() - def on_connect(self, connection: "Connection"): """Called when the stream connects""" self._stream = connection._reader @@ -234,6 +229,7 @@ def on_disconnect(self): if self._stream is not None: self._stream = None self.encoder = None + self._buffer = b"" async def can_read_destructive(self) -> bool: if self._buffer: @@ -247,30 +243,37 @@ async def can_read_destructive(self) -> bool: return False async def read_response(self, disable_decoding: bool = False): - if self._stream is None: - raise RedisError("Buffer is closed.") - if self._chunks: - # augment parsing buffer with previously read data - self._buffer += b"".join(self._chunks) - self._chunks.clear() - try: + if not self._stream or not self.encoder: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + + if not self._buffer: + await self._fill_buffer() + while True: self._pos = 0 - response = await self._read_response(disable_decoding=disable_decoding) - except (ConnectionError, InvalidResponse): - # We don't want these errors to be resumable - self._clear() - raise - else: - # Successfully parsing a response allows us to clear our parsing buffer - self._clear() - return response + try: + response = self._read_response(disable_decoding=disable_decoding) - async def _read_response( + except EOFError: + await self._fill_buffer() + else: + break + # Successfully parsing a response allows us to clear our parsing buffer + self._buffer = self._buffer[self._pos :] + return response + + async def _fill_buffer(self): + """ + IO is performed here + """ + buffer = await self._stream.read(self._read_size) + if not buffer or not isinstance(buffer, bytes): + raise 
ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None + self._buffer += buffer + + def _read_response( self, disable_decoding: bool = False ) -> Union[EncodableT, ResponseError, None]: - if not self._stream or not self.encoder: - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) - raw = await self._readline() + raw = self._readline() if not raw: raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) response: Any @@ -286,7 +289,7 @@ async def _read_response( # if the error is a ConnectionError, raise immediately so the user # is notified if isinstance(error, ConnectionError): - self._clear() # Successful parse + self._buffer = self._buffer[self._pos :] # Successful parse raise error # otherwise, we're dealing with a ResponseError that might belong # inside a pipeline response. the connection's read_response() @@ -304,55 +307,39 @@ async def _read_response( length = int(response) if length == -1: return None - response = await self._read(length) + response = self._read(length) # multi-bulk response elif byte == b"*": length = int(response) if length == -1: return None - response = [ - (await self._read_response(disable_decoding)) for _ in range(length) - ] + response = [(self._read_response(disable_decoding)) for _ in range(length)] if isinstance(response, bytes) and disable_decoding is False: response = self.encoder.decode(response) return response - async def _read(self, length: int) -> bytes: + def _read(self, length: int) -> bytes: """ Read `length` bytes of data. These are assumed to be followed by a '\r\n' terminator which is subsequently discarded. """ - want = length + 2 - end = self._pos + want - if len(self._buffer) >= end: - result = self._buffer[self._pos : end - 2] - else: - tail = self._buffer[self._pos :] - try: - data = await self._stream.readexactly(want - len(tail)) - except asyncio.IncompleteReadError as error: - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from error - result = (tail + data)[:-2] - self._chunks.append(data) - self._pos += want + end = self._pos + length + 2 + if len(self._buffer) < end: + raise EOFError() # Signal that we need more data + result = self._buffer[self._pos : end - 2] + self._pos = end return result - async def _readline(self) -> bytes: + def _readline(self) -> bytes: """ read an unknown number of bytes up to the next '\r\n' line separator, which is discarded. """ found = self._buffer.find(b"\r\n", self._pos) - if found >= 0: - result = self._buffer[self._pos : found] - else: - tail = self._buffer[self._pos :] - data = await self._stream.readline() - if not data.endswith(b"\r\n"): - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) - result = (tail + data)[:-2] - self._chunks.append(data) - self._pos += len(result) + 2 + if found < 0: + raise EOFError() # signal that we need more data + result = self._buffer[self._pos : found] + self._pos = found + 2 return result From a5d3ccbda8051be5a24ef03085bba68bc477d632 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Wed, 14 Dec 2022 12:46:52 +0000 Subject: [PATCH 08/12] Remove unnecessary check, EOF is detected in the reader. 
--- redis/asyncio/connection.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 6655ec006b..8c1bc774d0 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -274,8 +274,6 @@ def _read_response( self, disable_decoding: bool = False ) -> Union[EncodableT, ResponseError, None]: raw = self._readline() - if not raw: - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) response: Any byte, response = raw[:1], raw[1:] From e88c30aa80abdda1b8fc98bdaab80f2fd91a238d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Thu, 15 Dec 2022 12:45:10 +0000 Subject: [PATCH 09/12] Make syncronous PythonParser restartable on error, same as HiredisParser --- redis/connection.py | 58 ++++++++++++++++++++-------- tests/test_connection.py | 82 +++++++++++++++++++++++++++++++++++++++- 2 files changed, 123 insertions(+), 17 deletions(-) diff --git a/redis/connection.py b/redis/connection.py index b810fc5714..126ea5db32 100755 --- a/redis/connection.py +++ b/redis/connection.py @@ -232,12 +232,6 @@ def read(self, length): self._buffer.seek(self.bytes_read) data = self._buffer.read(length) self.bytes_read += len(data) - - # purge the buffer when we've consumed it all so it doesn't - # grow forever - if self.bytes_read == self.bytes_written: - self.purge() - return data[:-2] def readline(self): @@ -251,23 +245,44 @@ def readline(self): data = buf.readline() self.bytes_read += len(data) + return data[:-2] - # purge the buffer when we've consumed it all so it doesn't - # grow forever - if self.bytes_read == self.bytes_written: - self.purge() + def get_pos(self): + """ + Get current read position + """ + return self.bytes_read - return data[:-2] + def rewind(self, pos): + """ + Rewind the buffer to a specific position, to re-start reading + """ + self.bytes_read = pos def purge(self): - self._buffer.seek(0) - self._buffer.truncate() - self.bytes_written = 0 + """ + After a successful read, purge the read part of buffer + """ + unread = self.bytes_written - self.bytes_read + + # Only if we have read all of the buffer do we truncate, to + # reduce the amount of memory thrashing. This heuristic + # can be changed or removed later. 
+ if unread > 0: + return + + if unread > 0: + # move unread data to the front + view = self._buffer.getbuffer() + view[:unread] = view[-unread:] + self._buffer.truncate(unread) + self.bytes_written = unread self.bytes_read = 0 + self._buffer.seek(0) def close(self): try: - self.purge() + self.bytes_written = self.bytes_read = 0 self._buffer.close() except Exception: # issue #633 suggests the purge/close somehow raised a @@ -315,6 +330,17 @@ def can_read(self, timeout): return self._buffer and self._buffer.can_read(timeout) def read_response(self, disable_decoding=False): + pos = self._buffer.get_pos() + try: + result = self._read_response(disable_decoding=disable_decoding) + except BaseException: + self._buffer.rewind(pos) + raise + else: + self._buffer.purge() + return result + + def _read_response(self, disable_decoding=False): raw = self._buffer.readline() if not raw: raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) @@ -355,7 +381,7 @@ def read_response(self, disable_decoding=False): if length == -1: return None response = [ - self.read_response(disable_decoding=disable_decoding) + self._read_response(disable_decoding=disable_decoding) for i in range(length) ] if isinstance(response, bytes) and disable_decoding is False: diff --git a/tests/test_connection.py b/tests/test_connection.py index d9251c31dc..ca1385f129 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -5,8 +5,9 @@ import pytest +import redis from redis.backoff import NoBackoff -from redis.connection import Connection +from redis.connection import Connection, PythonParser, HiredisParser from redis.exceptions import ConnectionError, InvalidResponse, TimeoutError from redis.retry import Retry from redis.utils import HIREDIS_AVAILABLE @@ -122,3 +123,82 @@ def test_connect_timeout_error_without_retry(self): assert conn._connect.call_count == 1 assert str(e.value) == "Timeout connecting to server" self.clear(conn) + + +class FakeSocket: + """ + A class simulating an readable socket, but raising a + special exception every other read. + """ + + class TestError(BaseException): + pass + + def __init__(self, data, interrupt_every=0): + self.data = data + self.counter = 0 + self.pos = 0 + self.interrupt_every = interrupt_every + + def tick(self): + self.counter += 1 + if not self.interrupt_every: + return + if (self.counter % self.interrupt_every) == 0: + raise self.TestError() + + def recv(self, bufsize): + self.tick() + bufsize = min(5, bufsize) # truncate the read size + result = self.data[self.pos : self.pos + bufsize] + self.pos += len(result) + return result + + def recv_into(self, buffer, nbytes=0, flags=0): + self.tick() + if nbytes == 0: + nbytes = len(buffer) + nbytes = min(5, nbytes) # truncate the read size + result = self.data[self.pos : self.pos + nbytes] + self.pos += len(result) + buffer[: len(result)] = result + return len(result) + + +@pytest.mark.onlynoncluster +@pytest.mark.parametrize( + "parser_class", [PythonParser, HiredisParser], ids=["PythonParser", "HiredisParser"] +) +def test_connection_parse_response_resume(r: redis.Redis, parser_class): + """ + This test verifies that the Connection parser, + be that PythonParser or HiredisParser, + can be interrupted at IO time and then resume parsing. 
+ """ + if parser_class is HiredisParser and not HIREDIS_AVAILABLE: + pytest.skip("Hiredis not available)") + args = dict(r.connection_pool.connection_kwargs) + args["parser_class"] = parser_class + conn = Connection(**args) + conn.connect() + message = ( + b"*3\r\n$7\r\nmessage\r\n$8\r\nchannel1\r\n" + b"$25\r\nhi\r\nthere\r\n+how\r\nare\r\nyou\r\n" + ) + fake_socket = FakeSocket(message, interrupt_every=2) + + if isinstance(conn._parser, PythonParser): + conn._parser._buffer._sock = fake_socket + else: + conn._parser._sock = fake_socket + for i in range(100): + try: + response = conn.read_response() + break + except FakeSocket.TestError: + pass + + else: + pytest.fail("didn't receive a response") + assert response + assert i > 0 From 989ba5afa0eacb614eda4c216c07e96702c9def6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Thu, 15 Dec 2022 13:14:11 +0000 Subject: [PATCH 10/12] add CHANGES --- CHANGES | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGES b/CHANGES index 228910f9b8..4436e303f3 100644 --- a/CHANGES +++ b/CHANGES @@ -1,3 +1,4 @@ + * Make PythonParser resumable in case of error (#2512) * Add `timeout=None` in `SentinelConnectionManager.read_response` * Documentation fix: password protected socket connection (#2374) * Allow `timeout=None` in `PubSub.get_message()` to wait forever From 1965e8d2a0e12ca76cdb2483c504dc4fc60cc23f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Thu, 15 Dec 2022 21:19:39 +0000 Subject: [PATCH 11/12] isort --- tests/test_connection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_connection.py b/tests/test_connection.py index ca1385f129..7b3d1987d4 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -7,7 +7,7 @@ import redis from redis.backoff import NoBackoff -from redis.connection import Connection, PythonParser, HiredisParser +from redis.connection import Connection, HiredisParser, PythonParser from redis.exceptions import ConnectionError, InvalidResponse, TimeoutError from redis.retry import Retry from redis.utils import HIREDIS_AVAILABLE From addd59cb964c4b7f1d379f24b4440fa6cfe65340 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Thu, 15 Dec 2022 18:23:12 +0000 Subject: [PATCH 12/12] Move MockStream and MockSocket into their own files --- tests/mocks.py | 41 ++++++++++++++++++++ tests/test_asyncio/mocks.py | 51 +++++++++++++++++++++++++ tests/test_asyncio/test_connection.py | 55 ++------------------------- tests/test_connection.py | 49 +++--------------------- 4 files changed, 101 insertions(+), 95 deletions(-) create mode 100644 tests/mocks.py create mode 100644 tests/test_asyncio/mocks.py diff --git a/tests/mocks.py b/tests/mocks.py new file mode 100644 index 0000000000..d7d450ee86 --- /dev/null +++ b/tests/mocks.py @@ -0,0 +1,41 @@ +# Various mocks for testing + + +class MockSocket: + """ + A class simulating an readable socket, optionally raising a + special exception every other read. 
+ """ + + class TestError(BaseException): + pass + + def __init__(self, data, interrupt_every=0): + self.data = data + self.counter = 0 + self.pos = 0 + self.interrupt_every = interrupt_every + + def tick(self): + self.counter += 1 + if not self.interrupt_every: + return + if (self.counter % self.interrupt_every) == 0: + raise self.TestError() + + def recv(self, bufsize): + self.tick() + bufsize = min(5, bufsize) # truncate the read size + result = self.data[self.pos : self.pos + bufsize] + self.pos += len(result) + return result + + def recv_into(self, buffer, nbytes=0, flags=0): + self.tick() + if nbytes == 0: + nbytes = len(buffer) + nbytes = min(5, nbytes) # truncate the read size + result = self.data[self.pos : self.pos + nbytes] + self.pos += len(result) + buffer[: len(result)] = result + return len(result) diff --git a/tests/test_asyncio/mocks.py b/tests/test_asyncio/mocks.py new file mode 100644 index 0000000000..89bd9c0ac4 --- /dev/null +++ b/tests/test_asyncio/mocks.py @@ -0,0 +1,51 @@ +import asyncio + +# Helper Mocking classes for the tests. + + +class MockStream: + """ + A class simulating an asyncio input buffer, optionally raising a + special exception every other read. + """ + + class TestError(BaseException): + pass + + def __init__(self, data, interrupt_every=0): + self.data = data + self.counter = 0 + self.pos = 0 + self.interrupt_every = interrupt_every + + def tick(self): + self.counter += 1 + if not self.interrupt_every: + return + if (self.counter % self.interrupt_every) == 0: + raise self.TestError() + + async def read(self, want): + self.tick() + want = 5 + result = self.data[self.pos : self.pos + want] + self.pos += len(result) + return result + + async def readline(self): + self.tick() + find = self.data.find(b"\n", self.pos) + if find >= 0: + result = self.data[self.pos : find + 1] + else: + result = self.data[self.pos :] + self.pos += len(result) + return result + + async def readexactly(self, length): + self.tick() + result = self.data[self.pos : self.pos + length] + if len(result) < length: + raise asyncio.IncompleteReadError(result, None) + self.pos += len(result) + return result diff --git a/tests/test_asyncio/test_connection.py b/tests/test_asyncio/test_connection.py index fdbaf7fdbf..bf59dbe6b0 100644 --- a/tests/test_asyncio/test_connection.py +++ b/tests/test_asyncio/test_connection.py @@ -18,6 +18,7 @@ from tests.conftest import skip_if_server_version_lt from .compat import mock +from .mocks import MockStream @pytest.mark.onlynoncluster @@ -25,7 +26,7 @@ async def test_invalid_response(create_redis): r = await create_redis(single_connection_client=True) raw = b"x" - fake_stream = FakeStream(raw + b"\r\n") + fake_stream = MockStream(raw + b"\r\n") parser: BaseParser = r.connection._parser with mock.patch.object(parser, "_stream", fake_stream): @@ -119,54 +120,6 @@ async def test_connect_timeout_error_without_retry(): assert str(e.value) == "Timeout connecting to server" -class FakeStream: - """ - A class simulating an asyncio input buffer, but raising a - special exception every other read. 
- """ - - class TestError(BaseException): - pass - - def __init__(self, data, interrupt_every=0): - self.data = data - self.counter = 0 - self.pos = 0 - self.interrupt_every = interrupt_every - - def tick(self): - self.counter += 1 - if not self.interrupt_every: - return - if (self.counter % self.interrupt_every) == 0: - raise self.TestError() - - async def read(self, want): - self.tick() - want = 5 - result = self.data[self.pos : self.pos + want] - self.pos += len(result) - return result - - async def readline(self): - self.tick() - find = self.data.find(b"\n", self.pos) - if find >= 0: - result = self.data[self.pos : find + 1] - else: - result = self.data[self.pos :] - self.pos += len(result) - return result - - async def readexactly(self, length): - self.tick() - result = self.data[self.pos : self.pos + length] - if len(result) < length: - raise asyncio.IncompleteReadError(result, None) - self.pos += len(result) - return result - - @pytest.mark.onlynoncluster async def test_connection_parse_response_resume(r: redis.Redis): """ @@ -181,12 +134,12 @@ async def test_connection_parse_response_resume(r: redis.Redis): b"$25\r\nhi\r\nthere\r\n+how\r\nare\r\nyou\r\n" ) - conn._parser._stream = FakeStream(message, interrupt_every=2) + conn._parser._stream = MockStream(message, interrupt_every=2) for i in range(100): try: response = await conn.read_response() break - except FakeStream.TestError: + except MockStream.TestError: pass else: diff --git a/tests/test_connection.py b/tests/test_connection.py index 7b3d1987d4..e0b53cdf37 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -13,6 +13,7 @@ from redis.utils import HIREDIS_AVAILABLE from .conftest import skip_if_server_version_lt +from .mocks import MockSocket @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") @@ -125,46 +126,6 @@ def test_connect_timeout_error_without_retry(self): self.clear(conn) -class FakeSocket: - """ - A class simulating an readable socket, but raising a - special exception every other read. 
- """ - - class TestError(BaseException): - pass - - def __init__(self, data, interrupt_every=0): - self.data = data - self.counter = 0 - self.pos = 0 - self.interrupt_every = interrupt_every - - def tick(self): - self.counter += 1 - if not self.interrupt_every: - return - if (self.counter % self.interrupt_every) == 0: - raise self.TestError() - - def recv(self, bufsize): - self.tick() - bufsize = min(5, bufsize) # truncate the read size - result = self.data[self.pos : self.pos + bufsize] - self.pos += len(result) - return result - - def recv_into(self, buffer, nbytes=0, flags=0): - self.tick() - if nbytes == 0: - nbytes = len(buffer) - nbytes = min(5, nbytes) # truncate the read size - result = self.data[self.pos : self.pos + nbytes] - self.pos += len(result) - buffer[: len(result)] = result - return len(result) - - @pytest.mark.onlynoncluster @pytest.mark.parametrize( "parser_class", [PythonParser, HiredisParser], ids=["PythonParser", "HiredisParser"] @@ -185,17 +146,17 @@ def test_connection_parse_response_resume(r: redis.Redis, parser_class): b"*3\r\n$7\r\nmessage\r\n$8\r\nchannel1\r\n" b"$25\r\nhi\r\nthere\r\n+how\r\nare\r\nyou\r\n" ) - fake_socket = FakeSocket(message, interrupt_every=2) + mock_socket = MockSocket(message, interrupt_every=2) if isinstance(conn._parser, PythonParser): - conn._parser._buffer._sock = fake_socket + conn._parser._buffer._sock = mock_socket else: - conn._parser._sock = fake_socket + conn._parser._sock = mock_socket for i in range(100): try: response = conn.read_response() break - except FakeSocket.TestError: + except MockSocket.TestError: pass else: