Skip to content

Commit adb1670

Browse files
committed
PYTHON-4636 Stop blocking the I/O Loop for socket reads
1 parent 9df635f commit adb1670

File tree

4 files changed

+125
-149
lines changed

4 files changed

+125
-149
lines changed

pymongo/asynchronous/network.py

Lines changed: 5 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,8 @@
1515
"""Internal network layer helper methods."""
1616
from __future__ import annotations
1717

18-
import asyncio
1918
import datetime
20-
import errno
2119
import logging
22-
import socket
2320
import time
2421
from typing import (
2522
TYPE_CHECKING,
@@ -40,19 +37,16 @@
4037
NotPrimaryError,
4138
OperationFailure,
4239
ProtocolError,
43-
_OperationCancelled,
4440
)
4541
from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log
4642
from pymongo.message import _UNPACK_REPLY, _OpMsg, _OpReply
4743
from pymongo.monitoring import _is_speculative_authenticate
4844
from pymongo.network_layer import (
49-
_POLL_TIMEOUT,
5045
_UNPACK_COMPRESSION_HEADER,
5146
_UNPACK_HEADER,
52-
BLOCKING_IO_ERRORS,
47+
async_receive_data,
5348
async_sendall,
5449
)
55-
from pymongo.socket_checker import _errno_from_exception
5650

5751
if TYPE_CHECKING:
5852
from bson import CodecOptions
@@ -318,9 +312,7 @@ async def receive_message(
318312
else:
319313
deadline = None
320314
# Ignore the response's request id.
321-
length, _, response_to, op_code = _UNPACK_HEADER(
322-
await _receive_data_on_socket(conn, 16, deadline)
323-
)
315+
length, _, response_to, op_code = _UNPACK_HEADER(await async_receive_data(conn, 16, deadline))
324316
# No request_id for exhaust cursor "getMore".
325317
if request_id is not None:
326318
if request_id != response_to:
@@ -336,11 +328,11 @@ async def receive_message(
336328
)
337329
if op_code == 2012:
338330
op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(
339-
await _receive_data_on_socket(conn, 9, deadline)
331+
await async_receive_data(conn, 9, deadline)
340332
)
341-
data = decompress(await _receive_data_on_socket(conn, length - 25, deadline), compressor_id)
333+
data = decompress(await async_receive_data(conn, length - 25, deadline), compressor_id)
342334
else:
343-
data = await _receive_data_on_socket(conn, length - 16, deadline)
335+
data = await async_receive_data(conn, length - 16, deadline)
344336

345337
try:
346338
unpack_reply = _UNPACK_REPLY[op_code]
@@ -349,66 +341,3 @@ async def receive_message(
349341
f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}"
350342
) from None
351343
return unpack_reply(data)
352-
353-
354-
async def wait_for_read(conn: AsyncConnection, deadline: Optional[float]) -> None:
355-
"""Block until at least one byte is read, or a timeout, or a cancel."""
356-
sock = conn.conn
357-
timed_out = False
358-
# Check if the connection's socket has been manually closed
359-
if sock.fileno() == -1:
360-
return
361-
while True:
362-
# SSLSocket can have buffered data which won't be caught by select.
363-
if hasattr(sock, "pending") and sock.pending() > 0:
364-
readable = True
365-
else:
366-
# Wait up to 500ms for the socket to become readable and then
367-
# check for cancellation.
368-
if deadline:
369-
remaining = deadline - time.monotonic()
370-
# When the timeout has expired perform one final check to
371-
# see if the socket is readable. This helps avoid spurious
372-
# timeouts on AWS Lambda and other FaaS environments.
373-
if remaining <= 0:
374-
timed_out = True
375-
timeout = max(min(remaining, _POLL_TIMEOUT), 0)
376-
else:
377-
timeout = _POLL_TIMEOUT
378-
readable = conn.socket_checker.select(sock, read=True, timeout=timeout)
379-
if conn.cancel_context.cancelled:
380-
raise _OperationCancelled("operation cancelled")
381-
if readable:
382-
return
383-
if timed_out:
384-
raise socket.timeout("timed out")
385-
await asyncio.sleep(0)
386-
387-
388-
async def _receive_data_on_socket(
389-
conn: AsyncConnection, length: int, deadline: Optional[float]
390-
) -> memoryview:
391-
buf = bytearray(length)
392-
mv = memoryview(buf)
393-
bytes_read = 0
394-
while bytes_read < length:
395-
try:
396-
await wait_for_read(conn, deadline)
397-
# CSOT: Update timeout. When the timeout has expired perform one
398-
# final non-blocking recv. This helps avoid spurious timeouts when
399-
# the response is actually already buffered on the client.
400-
if _csot.get_timeout() and deadline is not None:
401-
conn.set_conn_timeout(max(deadline - time.monotonic(), 0))
402-
chunk_length = conn.conn.recv_into(mv[bytes_read:])
403-
except BLOCKING_IO_ERRORS:
404-
raise socket.timeout("timed out") from None
405-
except OSError as exc:
406-
if _errno_from_exception(exc) == errno.EINTR:
407-
continue
408-
raise
409-
if chunk_length == 0:
410-
raise OSError("connection closed")
411-
412-
bytes_read += chunk_length
413-
414-
return mv

pymongo/network_layer.py

Lines changed: 114 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,21 @@
1616
from __future__ import annotations
1717

1818
import asyncio
19+
import errno
1920
import socket
2021
import struct
2122
import sys
23+
import time
2224
from asyncio import AbstractEventLoop, Future
2325
from typing import (
26+
TYPE_CHECKING,
27+
Optional,
2428
Union,
2529
)
2630

27-
from pymongo import ssl_support
31+
from pymongo import _csot, ssl_support
32+
from pymongo.errors import _OperationCancelled
33+
from pymongo.socket_checker import _errno_from_exception
2834

2935
try:
3036
from ssl import SSLError, SSLSocket
@@ -51,6 +57,10 @@
5157
BLOCKING_IO_WRITE_ERROR,
5258
)
5359

60+
if TYPE_CHECKING:
61+
from pymongo.asynchronous.pool import AsyncConnection
62+
from pymongo.synchronous.pool import Connection
63+
5464
_UNPACK_HEADER = struct.Struct("<iiii").unpack
5565
_UNPACK_COMPRESSION_HEADER = struct.Struct("<iiB").unpack
5666
_POLL_TIMEOUT = 0.5
@@ -131,3 +141,106 @@ async def _async_sendall_ssl(
131141

132142
def sendall(sock: Union[socket.socket, _sslConn], buf: bytes) -> None:
133143
sock.sendall(buf)
144+
145+
146+
async def async_receive_data(
147+
conn: AsyncConnection, length: int, deadline: Optional[float]
148+
) -> memoryview:
149+
sock = conn.conn
150+
sock_timeout = sock.gettimeout()
151+
if deadline:
152+
# When the timeout has expired perform one final check to
153+
# see if the socket is readable. This helps avoid spurious
154+
# timeouts on AWS Lambda and other FaaS environments.
155+
timeout = max(deadline - time.monotonic(), 0)
156+
else:
157+
timeout = sock_timeout
158+
159+
sock.settimeout(0.0)
160+
loop = asyncio.get_event_loop()
161+
try:
162+
if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)):
163+
return await asyncio.wait_for(_async_receive_ssl(sock, length, loop), timeout=timeout)
164+
else:
165+
return await asyncio.wait_for(_async_receive(sock, length, loop), timeout=timeout) # type: ignore[arg-type]
166+
except asyncio.TimeoutError as exc:
167+
# Convert the asyncio.wait_for timeout error to socket.timeout which pool.py understands.
168+
raise socket.timeout("timed out") from exc
169+
finally:
170+
sock.settimeout(sock_timeout)
171+
172+
173+
async def _async_receive(conn: socket.socket, length: int, loop: AbstractEventLoop) -> memoryview:
174+
mv = memoryview(bytearray(length))
175+
bytes_read = 0
176+
while bytes_read < length:
177+
chunk_length = await loop.sock_recv_into(conn, mv[bytes_read:])
178+
if chunk_length == 0:
179+
raise OSError("connection closed")
180+
bytes_read += chunk_length
181+
return mv
182+
183+
184+
async def _async_receive_ssl(conn: _sslConn, length: int, loop: AbstractEventLoop) -> memoryview: # noqa: ARG001
185+
return memoryview(b"")
186+
187+
188+
# Sync version:
189+
def wait_for_read(conn: Connection, deadline: Optional[float]) -> None:
190+
"""Block until at least one byte is read, or a timeout, or a cancel."""
191+
sock = conn.conn
192+
timed_out = False
193+
# Check if the connection's socket has been manually closed
194+
if sock.fileno() == -1:
195+
return
196+
while True:
197+
# SSLSocket can have buffered data which won't be caught by select.
198+
if hasattr(sock, "pending") and sock.pending() > 0:
199+
readable = True
200+
else:
201+
# Wait up to 500ms for the socket to become readable and then
202+
# check for cancellation.
203+
if deadline:
204+
remaining = deadline - time.monotonic()
205+
# When the timeout has expired perform one final check to
206+
# see if the socket is readable. This helps avoid spurious
207+
# timeouts on AWS Lambda and other FaaS environments.
208+
if remaining <= 0:
209+
timed_out = True
210+
timeout = max(min(remaining, _POLL_TIMEOUT), 0)
211+
else:
212+
timeout = _POLL_TIMEOUT
213+
readable = conn.socket_checker.select(sock, read=True, timeout=timeout)
214+
if conn.cancel_context.cancelled:
215+
raise _OperationCancelled("operation cancelled")
216+
if readable:
217+
return
218+
if timed_out:
219+
raise socket.timeout("timed out")
220+
221+
222+
def receive_data(conn: Connection, length: int, deadline: Optional[float]) -> memoryview:
223+
buf = bytearray(length)
224+
mv = memoryview(buf)
225+
bytes_read = 0
226+
while bytes_read < length:
227+
try:
228+
wait_for_read(conn, deadline)
229+
# CSOT: Update timeout. When the timeout has expired perform one
230+
# final non-blocking recv. This helps avoid spurious timeouts when
231+
# the response is actually already buffered on the client.
232+
if _csot.get_timeout() and deadline is not None:
233+
conn.set_conn_timeout(max(deadline - time.monotonic(), 0))
234+
chunk_length = conn.conn.recv_into(mv[bytes_read:])
235+
except BLOCKING_IO_ERRORS:
236+
raise socket.timeout("timed out") from None
237+
except OSError as exc:
238+
if _errno_from_exception(exc) == errno.EINTR:
239+
continue
240+
raise
241+
if chunk_length == 0:
242+
raise OSError("connection closed")
243+
244+
bytes_read += chunk_length
245+
246+
return mv

pymongo/synchronous/network.py

Lines changed: 5 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,7 @@
1616
from __future__ import annotations
1717

1818
import datetime
19-
import errno
2019
import logging
21-
import socket
2220
import time
2321
from typing import (
2422
TYPE_CHECKING,
@@ -39,19 +37,16 @@
3937
NotPrimaryError,
4038
OperationFailure,
4139
ProtocolError,
42-
_OperationCancelled,
4340
)
4441
from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log
4542
from pymongo.message import _UNPACK_REPLY, _OpMsg, _OpReply
4643
from pymongo.monitoring import _is_speculative_authenticate
4744
from pymongo.network_layer import (
48-
_POLL_TIMEOUT,
4945
_UNPACK_COMPRESSION_HEADER,
5046
_UNPACK_HEADER,
51-
BLOCKING_IO_ERRORS,
47+
receive_data,
5248
sendall,
5349
)
54-
from pymongo.socket_checker import _errno_from_exception
5550

5651
if TYPE_CHECKING:
5752
from bson import CodecOptions
@@ -317,7 +312,7 @@ def receive_message(
317312
else:
318313
deadline = None
319314
# Ignore the response's request id.
320-
length, _, response_to, op_code = _UNPACK_HEADER(_receive_data_on_socket(conn, 16, deadline))
315+
length, _, response_to, op_code = _UNPACK_HEADER(receive_data(conn, 16, deadline))
321316
# No request_id for exhaust cursor "getMore".
322317
if request_id is not None:
323318
if request_id != response_to:
@@ -332,12 +327,10 @@ def receive_message(
332327
f"message size ({max_message_size!r})"
333328
)
334329
if op_code == 2012:
335-
op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(
336-
_receive_data_on_socket(conn, 9, deadline)
337-
)
338-
data = decompress(_receive_data_on_socket(conn, length - 25, deadline), compressor_id)
330+
op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(receive_data(conn, 9, deadline))
331+
data = decompress(receive_data(conn, length - 25, deadline), compressor_id)
339332
else:
340-
data = _receive_data_on_socket(conn, length - 16, deadline)
333+
data = receive_data(conn, length - 16, deadline)
341334

342335
try:
343336
unpack_reply = _UNPACK_REPLY[op_code]
@@ -346,63 +339,3 @@ def receive_message(
346339
f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}"
347340
) from None
348341
return unpack_reply(data)
349-
350-
351-
def wait_for_read(conn: Connection, deadline: Optional[float]) -> None:
352-
"""Block until at least one byte is read, or a timeout, or a cancel."""
353-
sock = conn.conn
354-
timed_out = False
355-
# Check if the connection's socket has been manually closed
356-
if sock.fileno() == -1:
357-
return
358-
while True:
359-
# SSLSocket can have buffered data which won't be caught by select.
360-
if hasattr(sock, "pending") and sock.pending() > 0:
361-
readable = True
362-
else:
363-
# Wait up to 500ms for the socket to become readable and then
364-
# check for cancellation.
365-
if deadline:
366-
remaining = deadline - time.monotonic()
367-
# When the timeout has expired perform one final check to
368-
# see if the socket is readable. This helps avoid spurious
369-
# timeouts on AWS Lambda and other FaaS environments.
370-
if remaining <= 0:
371-
timed_out = True
372-
timeout = max(min(remaining, _POLL_TIMEOUT), 0)
373-
else:
374-
timeout = _POLL_TIMEOUT
375-
readable = conn.socket_checker.select(sock, read=True, timeout=timeout)
376-
if conn.cancel_context.cancelled:
377-
raise _OperationCancelled("operation cancelled")
378-
if readable:
379-
return
380-
if timed_out:
381-
raise socket.timeout("timed out")
382-
383-
384-
def _receive_data_on_socket(conn: Connection, length: int, deadline: Optional[float]) -> memoryview:
385-
buf = bytearray(length)
386-
mv = memoryview(buf)
387-
bytes_read = 0
388-
while bytes_read < length:
389-
try:
390-
wait_for_read(conn, deadline)
391-
# CSOT: Update timeout. When the timeout has expired perform one
392-
# final non-blocking recv. This helps avoid spurious timeouts when
393-
# the response is actually already buffered on the client.
394-
if _csot.get_timeout() and deadline is not None:
395-
conn.set_conn_timeout(max(deadline - time.monotonic(), 0))
396-
chunk_length = conn.conn.recv_into(mv[bytes_read:])
397-
except BLOCKING_IO_ERRORS:
398-
raise socket.timeout("timed out") from None
399-
except OSError as exc:
400-
if _errno_from_exception(exc) == errno.EINTR:
401-
continue
402-
raise
403-
if chunk_length == 0:
404-
raise OSError("connection closed")
405-
406-
bytes_read += chunk_length
407-
408-
return mv

tools/synchro.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
"AsyncConnection": "Connection",
4444
"async_command": "command",
4545
"async_receive_message": "receive_message",
46+
"async_receive_data": "receive_data",
4647
"async_sendall": "sendall",
4748
"asynchronous": "synchronous",
4849
"Asynchronous": "Synchronous",

0 commit comments

Comments
 (0)