From fe476e17b43b7ebc9eb898a247c719e10b380177 Mon Sep 17 00:00:00 2001 From: Shane Harvey Date: Tue, 17 Sep 2024 16:40:43 -0700 Subject: [PATCH] PYTHON-4740 Convert asyncio.TimeoutError to socket.timeout for compat --- pymongo/asynchronous/bulk.py | 4 - pymongo/asynchronous/client_bulk.py | 6 +- pymongo/network_layer.py | 110 ++++++++++++++-------------- pymongo/synchronous/bulk.py | 4 - pymongo/synchronous/client_bulk.py | 6 +- 5 files changed, 59 insertions(+), 71 deletions(-) diff --git a/pymongo/asynchronous/bulk.py b/pymongo/asynchronous/bulk.py index 9fd673693f..9d33a990ed 100644 --- a/pymongo/asynchronous/bulk.py +++ b/pymongo/asynchronous/bulk.py @@ -313,8 +313,6 @@ async def write_command( if isinstance(exc, (NotPrimaryError, OperationFailure)): await client._process_response(exc.details, bwc.session) # type: ignore[arg-type] raise - finally: - bwc.start_time = datetime.datetime.now() return reply # type: ignore[return-value] async def unack_write( @@ -403,8 +401,6 @@ async def unack_write( assert bwc.start_time is not None bwc._fail(request_id, failure, duration) raise - finally: - bwc.start_time = datetime.datetime.now() return result # type: ignore[return-value] async def _execute_batch_unack( diff --git a/pymongo/asynchronous/client_bulk.py b/pymongo/asynchronous/client_bulk.py index 15a0369f41..dc800c9549 100644 --- a/pymongo/asynchronous/client_bulk.py +++ b/pymongo/asynchronous/client_bulk.py @@ -319,8 +319,6 @@ async def write_command( await self.client._process_response(exc.details, bwc.session) # type: ignore[arg-type] else: await self.client._process_response({}, bwc.session) # type: ignore[arg-type] - finally: - bwc.start_time = datetime.datetime.now() return reply # type: ignore[return-value] async def unack_write( @@ -410,9 +408,7 @@ async def unack_write( bwc._fail(request_id, failure, duration) # Top-level error will be embedded in ClientBulkWriteException. reply = {"error": exc} - finally: - bwc.start_time = datetime.datetime.now() - return result # type: ignore[return-value] + return reply async def _execute_batch_unack( self, diff --git a/pymongo/network_layer.py b/pymongo/network_layer.py index d99b4fee41..82a6228acc 100644 --- a/pymongo/network_layer.py +++ b/pymongo/network_layer.py @@ -64,65 +64,69 @@ async def async_sendall(sock: Union[socket.socket, _sslConn], buf: bytes) -> Non loop = asyncio.get_event_loop() try: if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)): - if sys.platform == "win32": - await asyncio.wait_for(_async_sendall_ssl_windows(sock, buf), timeout=timeout) - else: - await asyncio.wait_for(_async_sendall_ssl(sock, buf, loop), timeout=timeout) + await asyncio.wait_for(_async_sendall_ssl(sock, buf, loop), timeout=timeout) else: await asyncio.wait_for(loop.sock_sendall(sock, buf), timeout=timeout) # type: ignore[arg-type] + except asyncio.TimeoutError as exc: + # Convert the asyncio.wait_for timeout error to socket.timeout which pool.py understands. + raise socket.timeout("timed out") from exc finally: sock.settimeout(timeout) -async def _async_sendall_ssl( - sock: Union[socket.socket, _sslConn], buf: bytes, loop: AbstractEventLoop -) -> None: - view = memoryview(buf) - fd = sock.fileno() - sent = 0 - - def _is_ready(fut: Future) -> None: - loop.remove_writer(fd) - loop.remove_reader(fd) - if fut.done(): - return - fut.set_result(None) - - while sent < len(buf): - try: - sent += sock.send(view[sent:]) - except BLOCKING_IO_ERRORS as exc: - fd = sock.fileno() - # Check for closed socket. - if fd == -1: - raise SSLError("Underlying socket has been closed") from None - if isinstance(exc, BLOCKING_IO_READ_ERROR): - fut = loop.create_future() - loop.add_reader(fd, _is_ready, fut) - await fut - if isinstance(exc, BLOCKING_IO_WRITE_ERROR): - fut = loop.create_future() - loop.add_writer(fd, _is_ready, fut) - await fut - if _HAVE_PYOPENSSL and isinstance(exc, BLOCKING_IO_LOOKUP_ERROR): - fut = loop.create_future() - loop.add_reader(fd, _is_ready, fut) - loop.add_writer(fd, _is_ready, fut) - await fut - - -# The default Windows asyncio event loop does not support loop.add_reader/add_writer: https://docs.python.org/3/library/asyncio-platforms.html#asyncio-platform-support -async def _async_sendall_ssl_windows(sock: Union[socket.socket, _sslConn], buf: bytes) -> None: - view = memoryview(buf) - total_length = len(buf) - total_sent = 0 - while total_sent < total_length: - try: - sent = sock.send(view[total_sent:]) - except BLOCKING_IO_ERRORS: - await asyncio.sleep(0.5) - sent = 0 - total_sent += sent +if sys.platform != "win32": + + async def _async_sendall_ssl( + sock: Union[socket.socket, _sslConn], buf: bytes, loop: AbstractEventLoop + ) -> None: + view = memoryview(buf) + fd = sock.fileno() + sent = 0 + + def _is_ready(fut: Future) -> None: + loop.remove_writer(fd) + loop.remove_reader(fd) + if fut.done(): + return + fut.set_result(None) + + while sent < len(buf): + try: + sent += sock.send(view[sent:]) + except BLOCKING_IO_ERRORS as exc: + fd = sock.fileno() + # Check for closed socket. + if fd == -1: + raise SSLError("Underlying socket has been closed") from None + if isinstance(exc, BLOCKING_IO_READ_ERROR): + fut = loop.create_future() + loop.add_reader(fd, _is_ready, fut) + await fut + if isinstance(exc, BLOCKING_IO_WRITE_ERROR): + fut = loop.create_future() + loop.add_writer(fd, _is_ready, fut) + await fut + if _HAVE_PYOPENSSL and isinstance(exc, BLOCKING_IO_LOOKUP_ERROR): + fut = loop.create_future() + loop.add_reader(fd, _is_ready, fut) + loop.add_writer(fd, _is_ready, fut) + await fut +else: + # The default Windows asyncio event loop does not support loop.add_reader/add_writer: + # https://docs.python.org/3/library/asyncio-platforms.html#asyncio-platform-support + async def _async_sendall_ssl( + sock: Union[socket.socket, _sslConn], buf: bytes, dummy: AbstractEventLoop + ) -> None: + view = memoryview(buf) + total_length = len(buf) + total_sent = 0 + while total_sent < total_length: + try: + sent = sock.send(view[total_sent:]) + except BLOCKING_IO_ERRORS: + await asyncio.sleep(0.5) + sent = 0 + total_sent += sent def sendall(sock: Union[socket.socket, _sslConn], buf: bytes) -> None: diff --git a/pymongo/synchronous/bulk.py b/pymongo/synchronous/bulk.py index 27fcff620c..c658157ea1 100644 --- a/pymongo/synchronous/bulk.py +++ b/pymongo/synchronous/bulk.py @@ -313,8 +313,6 @@ def write_command( if isinstance(exc, (NotPrimaryError, OperationFailure)): client._process_response(exc.details, bwc.session) # type: ignore[arg-type] raise - finally: - bwc.start_time = datetime.datetime.now() return reply # type: ignore[return-value] def unack_write( @@ -403,8 +401,6 @@ def unack_write( assert bwc.start_time is not None bwc._fail(request_id, failure, duration) raise - finally: - bwc.start_time = datetime.datetime.now() return result # type: ignore[return-value] def _execute_batch_unack( diff --git a/pymongo/synchronous/client_bulk.py b/pymongo/synchronous/client_bulk.py index 23af231d16..f41f0203f2 100644 --- a/pymongo/synchronous/client_bulk.py +++ b/pymongo/synchronous/client_bulk.py @@ -319,8 +319,6 @@ def write_command( self.client._process_response(exc.details, bwc.session) # type: ignore[arg-type] else: self.client._process_response({}, bwc.session) # type: ignore[arg-type] - finally: - bwc.start_time = datetime.datetime.now() return reply # type: ignore[return-value] def unack_write( @@ -410,9 +408,7 @@ def unack_write( bwc._fail(request_id, failure, duration) # Top-level error will be embedded in ClientBulkWriteException. reply = {"error": exc} - finally: - bwc.start_time = datetime.datetime.now() - return result # type: ignore[return-value] + return reply def _execute_batch_unack( self,