Skip to content

Commit b111cbf

Browse files
PYTHON-4636 - Avoid blocking I/O calls in async code paths (#1870)
Co-authored-by: Shane Harvey <shnhrv@gmail.com>
1 parent 7380097 commit b111cbf

File tree

7 files changed

+248
-166
lines changed

7 files changed

+248
-166
lines changed

pymongo/asynchronous/network.py

Lines changed: 5 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,8 @@
1515
"""Internal network layer helper methods."""
1616
from __future__ import annotations
1717

18-
import asyncio
1918
import datetime
20-
import errno
2119
import logging
22-
import socket
2320
import time
2421
from typing import (
2522
TYPE_CHECKING,
@@ -40,19 +37,16 @@
4037
NotPrimaryError,
4138
OperationFailure,
4239
ProtocolError,
43-
_OperationCancelled,
4440
)
4541
from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log
4642
from pymongo.message import _UNPACK_REPLY, _OpMsg, _OpReply
4743
from pymongo.monitoring import _is_speculative_authenticate
4844
from pymongo.network_layer import (
49-
_POLL_TIMEOUT,
5045
_UNPACK_COMPRESSION_HEADER,
5146
_UNPACK_HEADER,
52-
BLOCKING_IO_ERRORS,
47+
async_receive_data,
5348
async_sendall,
5449
)
55-
from pymongo.socket_checker import _errno_from_exception
5650

5751
if TYPE_CHECKING:
5852
from bson import CodecOptions
@@ -318,9 +312,7 @@ async def receive_message(
318312
else:
319313
deadline = None
320314
# Ignore the response's request id.
321-
length, _, response_to, op_code = _UNPACK_HEADER(
322-
await _receive_data_on_socket(conn, 16, deadline)
323-
)
315+
length, _, response_to, op_code = _UNPACK_HEADER(await async_receive_data(conn, 16, deadline))
324316
# No request_id for exhaust cursor "getMore".
325317
if request_id is not None:
326318
if request_id != response_to:
@@ -336,11 +328,11 @@ async def receive_message(
336328
)
337329
if op_code == 2012:
338330
op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(
339-
await _receive_data_on_socket(conn, 9, deadline)
331+
await async_receive_data(conn, 9, deadline)
340332
)
341-
data = decompress(await _receive_data_on_socket(conn, length - 25, deadline), compressor_id)
333+
data = decompress(await async_receive_data(conn, length - 25, deadline), compressor_id)
342334
else:
343-
data = await _receive_data_on_socket(conn, length - 16, deadline)
335+
data = await async_receive_data(conn, length - 16, deadline)
344336

345337
try:
346338
unpack_reply = _UNPACK_REPLY[op_code]
@@ -349,66 +341,3 @@ async def receive_message(
349341
f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}"
350342
) from None
351343
return unpack_reply(data)
352-
353-
354-
async def wait_for_read(conn: AsyncConnection, deadline: Optional[float]) -> None:
355-
"""Block until at least one byte is read, or a timeout, or a cancel."""
356-
sock = conn.conn
357-
timed_out = False
358-
# Check if the connection's socket has been manually closed
359-
if sock.fileno() == -1:
360-
return
361-
while True:
362-
# SSLSocket can have buffered data which won't be caught by select.
363-
if hasattr(sock, "pending") and sock.pending() > 0:
364-
readable = True
365-
else:
366-
# Wait up to 500ms for the socket to become readable and then
367-
# check for cancellation.
368-
if deadline:
369-
remaining = deadline - time.monotonic()
370-
# When the timeout has expired perform one final check to
371-
# see if the socket is readable. This helps avoid spurious
372-
# timeouts on AWS Lambda and other FaaS environments.
373-
if remaining <= 0:
374-
timed_out = True
375-
timeout = max(min(remaining, _POLL_TIMEOUT), 0)
376-
else:
377-
timeout = _POLL_TIMEOUT
378-
readable = conn.socket_checker.select(sock, read=True, timeout=timeout)
379-
if conn.cancel_context.cancelled:
380-
raise _OperationCancelled("operation cancelled")
381-
if readable:
382-
return
383-
if timed_out:
384-
raise socket.timeout("timed out")
385-
await asyncio.sleep(0)
386-
387-
388-
async def _receive_data_on_socket(
389-
conn: AsyncConnection, length: int, deadline: Optional[float]
390-
) -> memoryview:
391-
buf = bytearray(length)
392-
mv = memoryview(buf)
393-
bytes_read = 0
394-
while bytes_read < length:
395-
try:
396-
await wait_for_read(conn, deadline)
397-
# CSOT: Update timeout. When the timeout has expired perform one
398-
# final non-blocking recv. This helps avoid spurious timeouts when
399-
# the response is actually already buffered on the client.
400-
if _csot.get_timeout() and deadline is not None:
401-
conn.set_conn_timeout(max(deadline - time.monotonic(), 0))
402-
chunk_length = conn.conn.recv_into(mv[bytes_read:])
403-
except BLOCKING_IO_ERRORS:
404-
raise socket.timeout("timed out") from None
405-
except OSError as exc:
406-
if _errno_from_exception(exc) == errno.EINTR:
407-
continue
408-
raise
409-
if chunk_length == 0:
410-
raise OSError("connection closed")
411-
412-
bytes_read += chunk_length
413-
414-
return mv

0 commit comments

Comments
 (0)