Skip to content

Commit 7d7e5d7

Browse files
committed
Wait for a send event, rather than rely on sleep time. Excpect cancel errors.
1 parent 31f3b78 commit 7d7e5d7

File tree

1 file changed

+41
-28
lines changed

1 file changed

+41
-28
lines changed

tests/test_asyncio/test_cwe_404.py

Lines changed: 41 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import asyncio
2-
import sys
32

43
import pytest
54

@@ -17,23 +16,12 @@ def redis_addr(request):
1716
return host, int(port)
1817

1918

20-
async def pipe(
21-
reader: asyncio.StreamReader, writer: asyncio.StreamWriter, delay: float, name=""
22-
):
23-
while True:
24-
data = await reader.read(1000)
25-
if not data:
26-
break
27-
await asyncio.sleep(delay)
28-
writer.write(data)
29-
await writer.drain()
30-
31-
3219
class DelayProxy:
3320
def __init__(self, addr, redis_addr, delay: float):
3421
self.addr = addr
3522
self.redis_addr = redis_addr
3623
self.delay = delay
24+
self.send_event = asyncio.Event()
3725

3826
async def start(self):
3927
# test that we can connect to redis
@@ -46,10 +34,10 @@ async def start(self):
4634
async def handle(self, reader, writer):
4735
# establish connection to redis
4836
redis_reader, redis_writer = await asyncio.open_connection(*self.redis_addr)
49-
pipe1 = asyncio.create_task(pipe(reader, redis_writer, self.delay, "to redis:"))
50-
pipe2 = asyncio.create_task(
51-
pipe(redis_reader, writer, self.delay, "from redis:")
37+
pipe1 = asyncio.create_task(
38+
self.pipe(reader, redis_writer, "to redis:", self.send_event)
5239
)
40+
pipe2 = asyncio.create_task(self.pipe(redis_reader, writer, "from redis:"))
5341
await asyncio.gather(pipe1, pipe2)
5442

5543
async def stop(self):
@@ -58,6 +46,23 @@ async def stop(self):
5846
loop = self.server.get_loop()
5947
await loop.shutdown_asyncgens()
6048

49+
async def pipe(
50+
self,
51+
reader: asyncio.StreamReader,
52+
writer: asyncio.StreamWriter,
53+
name="",
54+
event: asyncio.Event = None,
55+
):
56+
while True:
57+
data = await reader.read(1000)
58+
if not data:
59+
break
60+
if event:
61+
event.set()
62+
await asyncio.sleep(self.delay)
63+
writer.write(data)
64+
await writer.drain()
65+
6166

6267
@pytest.mark.onlynoncluster
6368
@pytest.mark.parametrize("delay", argvalues=[0.05, 0.5, 1, 2])
@@ -75,17 +80,18 @@ async def test_standalone(delay, redis_addr):
7580
await r.set("foo", "foo")
7681
await r.set("bar", "bar")
7782

83+
dp.send_event.clear()
7884
t = asyncio.create_task(r.get("foo"))
79-
await asyncio.sleep(delay)
85+
# Wait until the task has sent, and then some, to make sure it has
86+
# settled on the read.
87+
await dp.send_event.wait()
88+
await asyncio.sleep(0.01) # a little extra time for prudence
8089
t.cancel()
81-
try:
90+
with pytest.raises(asyncio.CancelledError):
8291
await t
83-
sys.stderr.write("try again, we did not cancel the task in time\n")
84-
except asyncio.CancelledError:
85-
sys.stderr.write(
86-
"canceled task, connection is left open with unread response\n"
87-
)
8892

93+
# make sure that our previous request, cancelled while waiting for
94+
# a repsponse, didn't leave the connection open andin a bad state
8995
assert await r.get("bar") == b"bar"
9096
assert await r.ping()
9197
assert await r.get("foo") == b"foo"
@@ -110,10 +116,17 @@ async def test_standalone_pipeline(delay, redis_addr):
110116
pipe2.ping()
111117
pipe2.get("foo")
112118

119+
dp.send_event.clear()
113120
t = asyncio.create_task(pipe.get("foo").execute())
114-
await asyncio.sleep(delay)
121+
# wait until task has settled on the read
122+
await dp.send_event.wait()
123+
await asyncio.sleep(0.01)
115124
t.cancel()
125+
with pytest.raises(asyncio.CancelledError):
126+
await t
116127

128+
# we have now cancelled the pieline in the middle of a request, make sure
129+
# that the connection is still usable
117130
pipe.get("bar")
118131
pipe.ping()
119132
pipe.get("foo")
@@ -144,13 +157,13 @@ async def test_cluster(request, redis_addr):
144157
await r.set("foo", "foo")
145158
await r.set("bar", "bar")
146159

160+
dp.send_event.clear()
147161
t = asyncio.create_task(r.get("foo"))
148-
await asyncio.sleep(0.050)
162+
await dp.send_event.wait()
163+
await asyncio.sleep(0.01)
149164
t.cancel()
150-
try:
165+
with pytest.raises(asyncio.CancelledError):
151166
await t
152-
except asyncio.CancelledError:
153-
pytest.fail("connection is left open with unread response")
154167

155168
assert await r.get("bar") == b"bar"
156169
assert await r.ping()

0 commit comments

Comments
 (0)