Skip to content

Commit f014dc3

Browse files
Dev/no can read (#2360)
* make can_read() destructive for simplicity, and rename the method. Remove timeout argument, always timeout immediately. * don't use can_read in pubsub * connection.connect() now has its own retry, don't need it inside a retry loop
1 parent 652ca79 commit f014dc3

File tree

5 files changed

+41
-44
lines changed

5 files changed

+41
-44
lines changed

redis/asyncio/client.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
cast,
2525
)
2626

27+
import async_timeout
28+
2729
from redis.asyncio.connection import (
2830
Connection,
2931
ConnectionPool,
@@ -754,15 +756,21 @@ async def parse_response(self, block: bool = True, timeout: float = 0):
754756

755757
await self.check_health()
756758

757-
async def try_read():
758-
if not block:
759-
if not await conn.can_read(timeout=timeout):
759+
if not conn.is_connected:
760+
await conn.connect()
761+
762+
if not block:
763+
764+
async def read_with_timeout():
765+
try:
766+
async with async_timeout.timeout(timeout):
767+
return await conn.read_response()
768+
except asyncio.TimeoutError:
760769
return None
761-
else:
762-
await conn.connect()
763-
return await conn.read_response()
764770

765-
response = await self._execute(conn, try_read)
771+
response = await self._execute(conn, read_with_timeout)
772+
else:
773+
response = await self._execute(conn, conn.read_response)
766774

767775
if conn.health_check_interval and response == self.health_check_response:
768776
# ignore the health check message as user might not expect it

redis/asyncio/connection.py

Lines changed: 20 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ def on_disconnect(self):
208208
def on_connect(self, connection: "Connection"):
209209
raise NotImplementedError()
210210

211-
async def can_read(self, timeout: float) -> bool:
211+
async def can_read_destructive(self) -> bool:
212212
raise NotImplementedError()
213213

214214
async def read_response(
@@ -286,9 +286,9 @@ async def _read_from_socket(
286286
return False
287287
raise ConnectionError(f"Error while reading from socket: {ex.args}")
288288

289-
async def can_read(self, timeout: float) -> bool:
289+
async def can_read_destructive(self) -> bool:
290290
return bool(self.length) or await self._read_from_socket(
291-
timeout=timeout, raise_on_timeout=False
291+
timeout=0, raise_on_timeout=False
292292
)
293293

294294
async def read(self, length: int) -> bytes:
@@ -386,8 +386,8 @@ def on_disconnect(self):
386386
self._buffer = None
387387
self.encoder = None
388388

389-
async def can_read(self, timeout: float):
390-
return self._buffer and bool(await self._buffer.can_read(timeout))
389+
async def can_read_destructive(self):
390+
return self._buffer and bool(await self._buffer.can_read_destructive())
391391

392392
async def read_response(
393393
self, disable_decoding: bool = False
@@ -444,9 +444,7 @@ async def read_response(
444444
class HiredisParser(BaseParser):
445445
"""Parser class for connections using Hiredis"""
446446

447-
__slots__ = BaseParser.__slots__ + ("_next_response", "_reader", "_socket_timeout")
448-
449-
_next_response: bool
447+
__slots__ = BaseParser.__slots__ + ("_reader", "_socket_timeout")
450448

451449
def __init__(self, socket_read_size: int):
452450
if not HIREDIS_AVAILABLE:
@@ -466,23 +464,18 @@ def on_connect(self, connection: "Connection"):
466464
kwargs["errors"] = connection.encoder.encoding_errors
467465

468466
self._reader = hiredis.Reader(**kwargs)
469-
self._next_response = False
470467
self._socket_timeout = connection.socket_timeout
471468

472469
def on_disconnect(self):
473470
self._stream = None
474471
self._reader = None
475-
self._next_response = False
476472

477-
async def can_read(self, timeout: float):
473+
async def can_read_destructive(self):
478474
if not self._stream or not self._reader:
479475
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
480-
481-
if self._next_response is False:
482-
self._next_response = self._reader.gets()
483-
if self._next_response is False:
484-
return await self.read_from_socket(timeout=timeout, raise_on_timeout=False)
485-
return True
476+
if self._reader.gets():
477+
return True
478+
return await self.read_from_socket(timeout=0, raise_on_timeout=False)
486479

487480
async def read_from_socket(
488481
self,
@@ -523,12 +516,6 @@ async def read_response(
523516
self.on_disconnect()
524517
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None
525518

526-
# _next_response might be cached from a can_read() call
527-
if self._next_response is not False:
528-
response = self._next_response
529-
self._next_response = False
530-
return response
531-
532519
response = self._reader.gets()
533520
while response is False:
534521
await self.read_from_socket()
@@ -925,12 +912,10 @@ async def send_command(self, *args: Any, **kwargs: Any) -> None:
925912
self.pack_command(*args), check_health=kwargs.get("check_health", True)
926913
)
927914

928-
async def can_read(self, timeout: float = 0):
915+
async def can_read_destructive(self):
929916
"""Poll the socket to see if there's data that can be read."""
930-
if not self.is_connected:
931-
await self.connect()
932917
try:
933-
return await self._parser.can_read(timeout)
918+
return await self._parser.can_read_destructive()
934919
except OSError as e:
935920
await self.disconnect(nowait=True)
936921
raise ConnectionError(
@@ -957,6 +942,10 @@ async def read_response(self, disable_decoding: bool = False):
957942
raise ConnectionError(
958943
f"Error while reading from {self.host}:{self.port} : {e.args}"
959944
)
945+
except asyncio.CancelledError:
946+
# need this check for 3.7, where CancelledError
947+
# is subclass of Exception, not BaseException
948+
raise
960949
except Exception:
961950
await self.disconnect(nowait=True)
962951
raise
@@ -1498,12 +1487,12 @@ async def get_connection(self, command_name, *keys, **options):
14981487
# pool before all data has been read or the socket has been
14991488
# closed. either way, reconnect and verify everything is good.
15001489
try:
1501-
if await connection.can_read():
1490+
if await connection.can_read_destructive():
15021491
raise ConnectionError("Connection has data") from None
15031492
except ConnectionError:
15041493
await connection.disconnect()
15051494
await connection.connect()
1506-
if await connection.can_read():
1495+
if await connection.can_read_destructive():
15071496
raise ConnectionError("Connection not ready") from None
15081497
except BaseException:
15091498
# release the connection back to the pool so that we don't
@@ -1699,12 +1688,12 @@ async def get_connection(self, command_name, *keys, **options):
16991688
# pool before all data has been read or the socket has been
17001689
# closed. either way, reconnect and verify everything is good.
17011690
try:
1702-
if await connection.can_read():
1691+
if await connection.can_read_destructive():
17031692
raise ConnectionError("Connection has data") from None
17041693
except ConnectionError:
17051694
await connection.disconnect()
17061695
await connection.connect()
1707-
if await connection.can_read():
1696+
if await connection.can_read_destructive():
17081697
raise ConnectionError("Connection not ready") from None
17091698
except BaseException:
17101699
# release the connection back to the pool so that we don't leak it

tests/test_asyncio/test_cluster.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -433,7 +433,7 @@ async def test_refresh_using_specific_nodes(
433433
Connection,
434434
send_packed_command=mock.DEFAULT,
435435
connect=mock.DEFAULT,
436-
can_read=mock.DEFAULT,
436+
can_read_destructive=mock.DEFAULT,
437437
) as mocks:
438438
# simulate 7006 as a failed node
439439
def execute_command_mock(self, *args, **options):
@@ -473,7 +473,7 @@ def map_7007(self):
473473
execute_command.successful_calls = 0
474474
execute_command.failed_calls = 0
475475
initialize.side_effect = initialize_mock
476-
mocks["can_read"].return_value = False
476+
mocks["can_read_destructive"].return_value = False
477477
mocks["send_packed_command"].return_value = "MOCK_OK"
478478
mocks["connect"].return_value = None
479479
with mock.patch.object(
@@ -514,7 +514,7 @@ async def test_reading_from_replicas_in_round_robin(self) -> None:
514514
send_command=mock.DEFAULT,
515515
read_response=mock.DEFAULT,
516516
_connect=mock.DEFAULT,
517-
can_read=mock.DEFAULT,
517+
can_read_destructive=mock.DEFAULT,
518518
on_connect=mock.DEFAULT,
519519
) as mocks:
520520
with mock.patch.object(
@@ -546,7 +546,7 @@ def execute_command_mock_third(self, *args, **options):
546546
mocks["send_command"].return_value = True
547547
mocks["read_response"].return_value = "OK"
548548
mocks["_connect"].return_value = True
549-
mocks["can_read"].return_value = False
549+
mocks["can_read_destructive"].return_value = False
550550
mocks["on_connect"].return_value = True
551551

552552
# Create a cluster with reading from replications

tests/test_asyncio/test_connection_pool.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ async def connect(self):
103103
async def disconnect(self):
104104
pass
105105

106-
async def can_read(self, timeout: float = 0):
106+
async def can_read_destructive(self, timeout: float = 0):
107107
return False
108108

109109

tests/test_asyncio/test_pubsub.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -847,7 +847,7 @@ async def test_reconnect_socket_error(self, r: redis.Redis, method):
847847
self.state = 1
848848
with mock.patch.object(self.pubsub.connection, "_parser") as m:
849849
m.read_response.side_effect = socket.error
850-
m.can_read.side_effect = socket.error
850+
m.can_read_destructive.side_effect = socket.error
851851
# wait until task noticies the disconnect until we
852852
# undo the patch
853853
await self.cond.wait_for(lambda: self.state >= 2)

0 commit comments

Comments
 (0)