Skip to content

Commit 210a137

Browse files
authored
bpo-30064: Fix asyncio loop.sock_* race condition issue (#20369)
1 parent 526e23f commit 210a137

File tree

3 files changed

+157
-16
lines changed

3 files changed

+157
-16
lines changed

Lib/asyncio/selector_events.py

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,7 @@ def _add_reader(self, fd, callback, *args):
266266
(handle, writer))
267267
if reader is not None:
268268
reader.cancel()
269+
return handle
269270

270271
def _remove_reader(self, fd):
271272
if self.is_closed():
@@ -302,6 +303,7 @@ def _add_writer(self, fd, callback, *args):
302303
(reader, handle))
303304
if writer is not None:
304305
writer.cancel()
306+
return handle
305307

306308
def _remove_writer(self, fd):
307309
"""Remove a writer callback."""
@@ -329,7 +331,7 @@ def _remove_writer(self, fd):
329331
def add_reader(self, fd, callback, *args):
330332
"""Add a reader callback."""
331333
self._ensure_fd_no_transport(fd)
332-
return self._add_reader(fd, callback, *args)
334+
self._add_reader(fd, callback, *args)
333335

334336
def remove_reader(self, fd):
335337
"""Remove a reader callback."""
@@ -339,7 +341,7 @@ def remove_reader(self, fd):
339341
def add_writer(self, fd, callback, *args):
340342
"""Add a writer callback.."""
341343
self._ensure_fd_no_transport(fd)
342-
return self._add_writer(fd, callback, *args)
344+
self._add_writer(fd, callback, *args)
343345

344346
def remove_writer(self, fd):
345347
"""Remove a writer callback."""
@@ -362,13 +364,15 @@ async def sock_recv(self, sock, n):
362364
pass
363365
fut = self.create_future()
364366
fd = sock.fileno()
365-
self.add_reader(fd, self._sock_recv, fut, sock, n)
367+
self._ensure_fd_no_transport(fd)
368+
handle = self._add_reader(fd, self._sock_recv, fut, sock, n)
366369
fut.add_done_callback(
367-
functools.partial(self._sock_read_done, fd))
370+
functools.partial(self._sock_read_done, fd, handle=handle))
368371
return await fut
369372

370-
def _sock_read_done(self, fd, fut):
371-
self.remove_reader(fd)
373+
def _sock_read_done(self, fd, fut, handle=None):
374+
if handle is None or not handle.cancelled():
375+
self.remove_reader(fd)
372376

373377
def _sock_recv(self, fut, sock, n):
374378
# _sock_recv() can add itself as an I/O callback if the operation can't
@@ -401,9 +405,10 @@ async def sock_recv_into(self, sock, buf):
401405
pass
402406
fut = self.create_future()
403407
fd = sock.fileno()
404-
self.add_reader(fd, self._sock_recv_into, fut, sock, buf)
408+
self._ensure_fd_no_transport(fd)
409+
handle = self._add_reader(fd, self._sock_recv_into, fut, sock, buf)
405410
fut.add_done_callback(
406-
functools.partial(self._sock_read_done, fd))
411+
functools.partial(self._sock_read_done, fd, handle=handle))
407412
return await fut
408413

409414
def _sock_recv_into(self, fut, sock, buf):
@@ -446,11 +451,12 @@ async def sock_sendall(self, sock, data):
446451

447452
fut = self.create_future()
448453
fd = sock.fileno()
449-
fut.add_done_callback(
450-
functools.partial(self._sock_write_done, fd))
454+
self._ensure_fd_no_transport(fd)
451455
# use a trick with a list in closure to store a mutable state
452-
self.add_writer(fd, self._sock_sendall, fut, sock,
453-
memoryview(data), [n])
456+
handle = self._add_writer(fd, self._sock_sendall, fut, sock,
457+
memoryview(data), [n])
458+
fut.add_done_callback(
459+
functools.partial(self._sock_write_done, fd, handle=handle))
454460
return await fut
455461

456462
def _sock_sendall(self, fut, sock, view, pos):
@@ -502,18 +508,21 @@ def _sock_connect(self, fut, sock, address):
502508
# connection runs in background. We have to wait until the socket
503509
# becomes writable to be notified when the connection succeed or
504510
# fails.
511+
self._ensure_fd_no_transport(fd)
512+
handle = self._add_writer(
513+
fd, self._sock_connect_cb, fut, sock, address)
505514
fut.add_done_callback(
506-
functools.partial(self._sock_write_done, fd))
507-
self.add_writer(fd, self._sock_connect_cb, fut, sock, address)
515+
functools.partial(self._sock_write_done, fd, handle=handle))
508516
except (SystemExit, KeyboardInterrupt):
509517
raise
510518
except BaseException as exc:
511519
fut.set_exception(exc)
512520
else:
513521
fut.set_result(None)
514522

515-
def _sock_write_done(self, fd, fut):
516-
self.remove_writer(fd)
523+
def _sock_write_done(self, fd, fut, handle=None):
524+
if handle is None or not handle.cancelled():
525+
self.remove_writer(fd)
517526

518527
def _sock_connect_cb(self, fut, sock, address):
519528
if fut.done():

Lib/test/test_asyncio/test_sock_lowlevel.py

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import socket
2+
import time
23
import asyncio
34
import sys
45
from asyncio import proactor_events
@@ -122,6 +123,136 @@ def test_sock_client_ops(self):
122123
sock = socket.socket()
123124
self._basetest_sock_recv_into(httpd, sock)
124125

126+
async def _basetest_sock_recv_racing(self, httpd, sock):
127+
sock.setblocking(False)
128+
await self.loop.sock_connect(sock, httpd.address)
129+
130+
task = asyncio.create_task(self.loop.sock_recv(sock, 1024))
131+
await asyncio.sleep(0)
132+
task.cancel()
133+
134+
asyncio.create_task(
135+
self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
136+
data = await self.loop.sock_recv(sock, 1024)
137+
# consume data
138+
await self.loop.sock_recv(sock, 1024)
139+
140+
self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))
141+
142+
async def _basetest_sock_recv_into_racing(self, httpd, sock):
143+
sock.setblocking(False)
144+
await self.loop.sock_connect(sock, httpd.address)
145+
146+
data = bytearray(1024)
147+
with memoryview(data) as buf:
148+
task = asyncio.create_task(
149+
self.loop.sock_recv_into(sock, buf[:1024]))
150+
await asyncio.sleep(0)
151+
task.cancel()
152+
153+
task = asyncio.create_task(
154+
self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
155+
nbytes = await self.loop.sock_recv_into(sock, buf[:1024])
156+
# consume data
157+
await self.loop.sock_recv_into(sock, buf[nbytes:])
158+
self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))
159+
160+
await task
161+
162+
async def _basetest_sock_send_racing(self, listener, sock):
163+
listener.bind(('127.0.0.1', 0))
164+
listener.listen(1)
165+
166+
# make connection
167+
sock.setblocking(False)
168+
task = asyncio.create_task(
169+
self.loop.sock_connect(sock, listener.getsockname()))
170+
await asyncio.sleep(0)
171+
server = listener.accept()[0]
172+
server.setblocking(False)
173+
174+
with server:
175+
await task
176+
177+
# fill the buffer
178+
with self.assertRaises(BlockingIOError):
179+
while True:
180+
sock.send(b' ' * 5)
181+
182+
# cancel a blocked sock_sendall
183+
task = asyncio.create_task(
184+
self.loop.sock_sendall(sock, b'hello'))
185+
await asyncio.sleep(0)
186+
task.cancel()
187+
188+
# clear the buffer
189+
async def recv_until():
190+
data = b''
191+
while not data:
192+
data = await self.loop.sock_recv(server, 1024)
193+
data = data.strip()
194+
return data
195+
task = asyncio.create_task(recv_until())
196+
197+
# immediately register another sock_sendall
198+
await self.loop.sock_sendall(sock, b'world')
199+
data = await task
200+
# ProactorEventLoop could deliver hello
201+
self.assertTrue(data.endswith(b'world'))
202+
203+
async def _basetest_sock_connect_racing(self, listener, sock):
204+
listener.bind(('127.0.0.1', 0))
205+
addr = listener.getsockname()
206+
sock.setblocking(False)
207+
208+
task = asyncio.create_task(self.loop.sock_connect(sock, addr))
209+
await asyncio.sleep(0)
210+
task.cancel()
211+
212+
listener.listen(1)
213+
i = 0
214+
while True:
215+
try:
216+
await self.loop.sock_connect(sock, addr)
217+
break
218+
except ConnectionRefusedError: # on Linux we need another retry
219+
await self.loop.sock_connect(sock, addr)
220+
break
221+
except OSError as e: # on Windows we need more retries
222+
# A connect request was made on an already connected socket
223+
if getattr(e, 'winerror', 0) == 10056:
224+
break
225+
226+
# https://stackoverflow.com/a/54437602/3316267
227+
if getattr(e, 'winerror', 0) != 10022:
228+
raise
229+
i += 1
230+
if i >= 128:
231+
raise # too many retries
232+
# avoid touching event loop to maintain race condition
233+
time.sleep(0.01)
234+
235+
def test_sock_client_racing(self):
236+
with test_utils.run_test_server() as httpd:
237+
sock = socket.socket()
238+
with sock:
239+
self.loop.run_until_complete(asyncio.wait_for(
240+
self._basetest_sock_recv_racing(httpd, sock), 10))
241+
sock = socket.socket()
242+
with sock:
243+
self.loop.run_until_complete(asyncio.wait_for(
244+
self._basetest_sock_recv_into_racing(httpd, sock), 10))
245+
listener = socket.socket()
246+
sock = socket.socket()
247+
with listener, sock:
248+
self.loop.run_until_complete(asyncio.wait_for(
249+
self._basetest_sock_send_racing(listener, sock), 10))
250+
listener = socket.socket()
251+
sock = socket.socket()
252+
with listener, sock:
253+
self.loop.run_until_complete(asyncio.wait_for(
254+
self._basetest_sock_connect_racing(listener, sock), 10))
255+
125256
async def _basetest_huge_content(self, address):
126257
sock = socket.socket()
127258
sock.setblocking(False)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix asyncio ``loop.sock_*`` race condition issue

0 commit comments

Comments
 (0)