Skip to content

Commit 6cd5173

Browse files
chayimjamestiotiodvora-h
committed
Fixing cancelled async futures (#2666)
Co-authored-by: James R T <jamestiotio@gmail.com> Co-authored-by: dvora-h <dvora.heller@redis.com>
1 parent 7b48b1b commit 6cd5173

File tree

5 files changed

+226
-75
lines changed

5 files changed

+226
-75
lines changed

redis/asyncio/client.py

Lines changed: 66 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -475,24 +475,32 @@ async def _disconnect_raise(self, conn: Connection, error: Exception):
475475
):
476476
raise error
477477

478-
# COMMAND EXECUTION AND PROTOCOL PARSING
479-
async def execute_command(self, *args, **options):
480-
"""Execute a command and return a parsed response"""
481-
await self.initialize()
482-
pool = self.connection_pool
483-
command_name = args[0]
484-
conn = self.connection or await pool.get_connection(command_name, **options)
485-
478+
async def _try_send_command_parse_response(self, conn, *args, **options):
486479
try:
487480
return await conn.retry.call_with_retry(
488481
lambda: self._send_command_parse_response(
489-
conn, command_name, *args, **options
482+
conn, args[0], *args, **options
490483
),
491484
lambda error: self._disconnect_raise(conn, error),
492485
)
486+
except asyncio.CancelledError:
487+
await conn.disconnect(nowait=True)
488+
raise
493489
finally:
494490
if not self.connection:
495-
await pool.release(conn)
491+
await self.connection_pool.release(conn)
492+
493+
# COMMAND EXECUTION AND PROTOCOL PARSING
494+
async def execute_command(self, *args, **options):
495+
"""Execute a command and return a parsed response"""
496+
await self.initialize()
497+
pool = self.connection_pool
498+
command_name = args[0]
499+
conn = self.connection or await pool.get_connection(command_name, **options)
500+
501+
return await asyncio.shield(
502+
self._try_send_command_parse_response(conn, *args, **options)
503+
)
496504

497505
async def parse_response(
498506
self, connection: Connection, command_name: Union[str, bytes], **options
@@ -726,10 +734,18 @@ async def _disconnect_raise_connect(self, conn, error):
726734
is not a TimeoutError. Otherwise, try to reconnect
727735
"""
728736
await conn.disconnect()
737+
729738
if not (conn.retry_on_timeout and isinstance(error, TimeoutError)):
730739
raise error
731740
await conn.connect()
732741

742+
async def _try_execute(self, conn, command, *arg, **kwargs):
743+
try:
744+
return await command(*arg, **kwargs)
745+
except asyncio.CancelledError:
746+
await conn.disconnect()
747+
raise
748+
733749
async def _execute(self, conn, command, *args, **kwargs):
734750
"""
735751
Connect manually upon disconnection. If the Redis server is down,
@@ -738,9 +754,11 @@ async def _execute(self, conn, command, *args, **kwargs):
738754
called by the # connection to resubscribe us to any channels and
739755
patterns we were previously listening to
740756
"""
741-
return await conn.retry.call_with_retry(
742-
lambda: command(*args, **kwargs),
743-
lambda error: self._disconnect_raise_connect(conn, error),
757+
return await asyncio.shield(
758+
conn.retry.call_with_retry(
759+
lambda: self._try_execute(conn, command, *args, **kwargs),
760+
lambda error: self._disconnect_raise_connect(conn, error),
761+
)
744762
)
745763

746764
async def parse_response(self, block: bool = True, timeout: float = 0):
@@ -1140,6 +1158,18 @@ async def _disconnect_reset_raise(self, conn, error):
11401158
await self.reset()
11411159
raise
11421160

1161+
async def _try_send_command_parse_response(self, conn, *args, **options):
1162+
try:
1163+
return await conn.retry.call_with_retry(
1164+
lambda: self._send_command_parse_response(
1165+
conn, args[0], *args, **options
1166+
),
1167+
lambda error: self._disconnect_reset_raise(conn, error),
1168+
)
1169+
except asyncio.CancelledError:
1170+
await conn.disconnect()
1171+
raise
1172+
11431173
async def immediate_execute_command(self, *args, **options):
11441174
"""
11451175
Execute a command immediately, but don't auto-retry on a
@@ -1155,13 +1185,13 @@ async def immediate_execute_command(self, *args, **options):
11551185
command_name, self.shard_hint
11561186
)
11571187
self.connection = conn
1158-
1159-
return await conn.retry.call_with_retry(
1160-
lambda: self._send_command_parse_response(
1161-
conn, command_name, *args, **options
1162-
),
1163-
lambda error: self._disconnect_reset_raise(conn, error),
1164-
)
1188+
try:
1189+
return await asyncio.shield(
1190+
self._try_send_command_parse_response(conn, *args, **options)
1191+
)
1192+
except asyncio.CancelledError:
1193+
await conn.disconnect()
1194+
raise
11651195

11661196
def pipeline_execute_command(self, *args, **options):
11671197
"""
@@ -1328,6 +1358,19 @@ async def _disconnect_raise_reset(self, conn: Connection, error: Exception):
13281358
await self.reset()
13291359
raise
13301360

1361+
async def _try_execute(self, conn, execute, stack, raise_on_error):
1362+
try:
1363+
return await conn.retry.call_with_retry(
1364+
lambda: execute(conn, stack, raise_on_error),
1365+
lambda error: self._disconnect_raise_reset(conn, error),
1366+
)
1367+
except asyncio.CancelledError:
1368+
# not supposed to be possible, yet here we are
1369+
await conn.disconnect(nowait=True)
1370+
raise
1371+
finally:
1372+
await self.reset()
1373+
13311374
async def execute(self, raise_on_error: bool = True):
13321375
"""Execute all the commands in the current pipeline"""
13331376
stack = self.command_stack
@@ -1350,15 +1393,10 @@ async def execute(self, raise_on_error: bool = True):
13501393

13511394
try:
13521395
return await asyncio.shield(
1353-
conn.retry.call_with_retry(
1354-
lambda: execute(conn, stack, raise_on_error),
1355-
lambda error: self._disconnect_raise_reset(conn, error),
1356-
)
1396+
self._try_execute(conn, execute, stack, raise_on_error)
13571397
)
1358-
except asyncio.CancelledError:
1359-
# not supposed to be possible, yet here we are
1360-
await conn.disconnect(nowait=True)
1361-
raise
1398+
except RuntimeError:
1399+
await self.reset()
13621400
finally:
13631401
await self.reset()
13641402

redis/asyncio/cluster.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -893,6 +893,19 @@ async def _parse_and_release(self, connection, *args, **kwargs):
893893
finally:
894894
self._free.append(connection)
895895

896+
async def _try_parse_response(self, cmd, connection, ret):
897+
try:
898+
cmd.result = await asyncio.shield(
899+
self.parse_response(connection, cmd.args[0], **cmd.kwargs)
900+
)
901+
except asyncio.CancelledError:
902+
await connection.disconnect(nowait=True)
903+
raise
904+
except Exception as e:
905+
cmd.result = e
906+
ret = True
907+
return ret
908+
896909
async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool:
897910
# Acquire connection
898911
connection = self.acquire_connection()
@@ -905,13 +918,7 @@ async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool:
905918
# Read responses
906919
ret = False
907920
for cmd in commands:
908-
try:
909-
cmd.result = await self.parse_response(
910-
connection, cmd.args[0], **cmd.kwargs
911-
)
912-
except Exception as e:
913-
cmd.result = e
914-
ret = True
921+
ret = await asyncio.shield(self._try_parse_response(cmd, connection, ret))
915922

916923
# Release connection
917924
self._free.append(connection)

tests/test_asyncio/test_cluster.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -333,23 +333,6 @@ async def test_execute_command_node_flag_random(self, r: RedisCluster) -> None:
333333
called_count += 1
334334
assert called_count == 1
335335

336-
async def test_asynckills(self, r) -> None:
337-
338-
await r.set("foo", "foo")
339-
await r.set("bar", "bar")
340-
341-
t = asyncio.create_task(r.get("foo"))
342-
await asyncio.sleep(1)
343-
t.cancel()
344-
try:
345-
await t
346-
except asyncio.CancelledError:
347-
pytest.fail("connection is left open with unread response")
348-
349-
assert await r.get("bar") == b"bar"
350-
assert await r.ping()
351-
assert await r.get("foo") == b"foo"
352-
353336
async def test_execute_command_default_node(self, r: RedisCluster) -> None:
354337
"""
355338
Test command execution without node flag is being executed on the

tests/test_asyncio/test_connection.py

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -28,29 +28,6 @@ async def test_invalid_response(create_redis):
2828
assert str(cm.value) == f"Protocol Error: {raw!r}"
2929

3030

31-
@pytest.mark.onlynoncluster
32-
async def test_asynckills():
33-
from redis.asyncio.client import Redis
34-
35-
for b in [True, False]:
36-
r = Redis(single_connection_client=b)
37-
38-
await r.set("foo", "foo")
39-
await r.set("bar", "bar")
40-
41-
t = asyncio.create_task(r.get("foo"))
42-
await asyncio.sleep(1)
43-
t.cancel()
44-
try:
45-
await t
46-
except asyncio.CancelledError:
47-
pytest.fail("connection left open with unread response")
48-
49-
assert await r.get("bar") == b"bar"
50-
assert await r.ping()
51-
assert await r.get("foo") == b"foo"
52-
53-
5431
@skip_if_server_version_lt("4.0.0")
5532
@pytest.mark.redismod
5633
@pytest.mark.onlynoncluster

tests/test_asyncio/test_cwe_404.py

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
import asyncio
2+
import sys
3+
4+
import pytest
5+
6+
from redis.asyncio import Redis
7+
from redis.asyncio.cluster import RedisCluster
8+
9+
10+
async def pipe(
11+
reader: asyncio.StreamReader, writer: asyncio.StreamWriter, delay: float, name=""
12+
):
13+
while True:
14+
data = await reader.read(1000)
15+
if not data:
16+
break
17+
await asyncio.sleep(delay)
18+
writer.write(data)
19+
await writer.drain()
20+
21+
22+
class DelayProxy:
23+
def __init__(self, addr, redis_addr, delay: float):
24+
self.addr = addr
25+
self.redis_addr = redis_addr
26+
self.delay = delay
27+
28+
async def start(self):
29+
self.server = await asyncio.start_server(self.handle, *self.addr)
30+
self.ROUTINE = asyncio.create_task(self.server.serve_forever())
31+
32+
async def handle(self, reader, writer):
33+
# establish connection to redis
34+
redis_reader, redis_writer = await asyncio.open_connection(*self.redis_addr)
35+
pipe1 = asyncio.create_task(pipe(reader, redis_writer, self.delay, "to redis:"))
36+
pipe2 = asyncio.create_task(
37+
pipe(redis_reader, writer, self.delay, "from redis:")
38+
)
39+
await asyncio.gather(pipe1, pipe2)
40+
41+
async def stop(self):
42+
# clean up enough so that we can reuse the looper
43+
self.ROUTINE.cancel()
44+
loop = self.server.get_loop()
45+
await loop.shutdown_asyncgens()
46+
47+
48+
@pytest.mark.onlynoncluster
49+
@pytest.mark.parametrize("delay", argvalues=[0.05, 0.5, 1, 2])
50+
async def test_standalone(delay):
51+
52+
# create a tcp socket proxy that relays data to Redis and back,
53+
# inserting 0.1 seconds of delay
54+
dp = DelayProxy(
55+
addr=("localhost", 5380), redis_addr=("localhost", 6379), delay=delay * 2
56+
)
57+
await dp.start()
58+
59+
for b in [True, False]:
60+
# note that we connect to proxy, rather than to Redis directly
61+
async with Redis(host="localhost", port=5380, single_connection_client=b) as r:
62+
63+
await r.set("foo", "foo")
64+
await r.set("bar", "bar")
65+
66+
t = asyncio.create_task(r.get("foo"))
67+
await asyncio.sleep(delay)
68+
t.cancel()
69+
try:
70+
await t
71+
sys.stderr.write("try again, we did not cancel the task in time\n")
72+
except asyncio.CancelledError:
73+
sys.stderr.write(
74+
"canceled task, connection is left open with unread response\n"
75+
)
76+
77+
assert await r.get("bar") == b"bar"
78+
assert await r.ping()
79+
assert await r.get("foo") == b"foo"
80+
81+
await dp.stop()
82+
83+
84+
@pytest.mark.onlynoncluster
85+
@pytest.mark.parametrize("delay", argvalues=[0.05, 0.5, 1, 2])
86+
async def test_standalone_pipeline(delay):
87+
dp = DelayProxy(
88+
addr=("localhost", 5380), redis_addr=("localhost", 6379), delay=delay * 2
89+
)
90+
await dp.start()
91+
async with Redis(host="localhost", port=5380) as r:
92+
await r.set("foo", "foo")
93+
await r.set("bar", "bar")
94+
95+
pipe = r.pipeline()
96+
97+
pipe2 = r.pipeline()
98+
pipe2.get("bar")
99+
pipe2.ping()
100+
pipe2.get("foo")
101+
102+
t = asyncio.create_task(pipe.get("foo").execute())
103+
await asyncio.sleep(delay)
104+
t.cancel()
105+
106+
pipe.get("bar")
107+
pipe.ping()
108+
pipe.get("foo")
109+
pipe.reset()
110+
111+
assert await pipe.execute() is None
112+
113+
# validating that the pipeline can be used as it could previously
114+
pipe.get("bar")
115+
pipe.ping()
116+
pipe.get("foo")
117+
assert await pipe.execute() == [b"bar", True, b"foo"]
118+
assert await pipe2.execute() == [b"bar", True, b"foo"]
119+
120+
await dp.stop()
121+
122+
123+
@pytest.mark.onlycluster
124+
async def test_cluster(request):
125+
126+
dp = DelayProxy(addr=("localhost", 5381), redis_addr=("localhost", 6372), delay=0.1)
127+
await dp.start()
128+
129+
r = RedisCluster.from_url("redis://localhost:5381")
130+
await r.initialize()
131+
await r.set("foo", "foo")
132+
await r.set("bar", "bar")
133+
134+
t = asyncio.create_task(r.get("foo"))
135+
await asyncio.sleep(0.050)
136+
t.cancel()
137+
try:
138+
await t
139+
except asyncio.CancelledError:
140+
pytest.fail("connection is left open with unread response")
141+
142+
assert await r.get("bar") == b"bar"
143+
assert await r.ping()
144+
assert await r.get("foo") == b"foo"
145+
146+
await dp.stop()

0 commit comments

Comments
 (0)