Skip to content

PYTHON-3186 Avoid SDAM heartbeat timeouts on AWS Lambda #912

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Mar 30, 2022
Merged
11 changes: 9 additions & 2 deletions pymongo/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ def wait_for_read(sock_info, deadline):
# Only Monitor connections can be cancelled.
if context:
sock = sock_info.sock
timed_out = False
while True:
# SSLSocket can have buffered data which won't be caught by select.
if hasattr(sock, "pending") and sock.pending() > 0:
Expand All @@ -252,15 +253,21 @@ def wait_for_read(sock_info, deadline):
# Wait up to 500ms for the socket to become readable and then
# check for cancellation.
if deadline:
timeout = max(min(deadline - time.monotonic(), _POLL_TIMEOUT), 0.001)
remaining = deadline - time.monotonic()
# When the timeout has expired perform one final check to
# see if the socket is readable. This helps avoid spurious
# timeouts on AWS Lambda and other FaaS environments.
if remaining <= 0:
timed_out = True
timeout = max(min(remaining, _POLL_TIMEOUT), 0)
else:
timeout = _POLL_TIMEOUT
readable = sock_info.socket_checker.select(sock, read=True, timeout=timeout)
if context.cancelled:
raise _OperationCancelled("hello cancelled")
if readable:
return
if deadline and time.monotonic() > deadline:
if timed_out:
raise socket.timeout("timed out")


Expand Down
95 changes: 95 additions & 0 deletions test/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import sys
import threading
import time
from multiprocessing import Process, Queue
from typing import Iterable, Type, no_type_check

sys.path[0:0] = [""]
Expand Down Expand Up @@ -1688,6 +1689,100 @@ def test_srv_max_hosts_kwarg(self):
)
self.assertEqual(len(client.topology_description.server_descriptions()), 2)

@staticmethod
def sigstop_sigcont(host: str, opts: dict, event_queue: Queue, message_queue: Queue) -> None:
    """Child-process body used by test_sigstop_sigcont.

    Connects a MongoClient to ``host`` with a heartbeat listener that
    forwards every heartbeat event onto ``event_queue``, pings the server,
    then blocks until something arrives on ``message_queue``. After being
    woken it pings once more, puts the sentinel string "DONE" on
    ``event_queue`` and shuts down the client and both queues.

    Note that ``opts`` is updated in place with the listener and the
    500ms heartbeat/connect settings.
    """

    class HbListenerProxy(ServerHeartbeatListener):
        """Heartbeat listener that relays events to a queue until closed."""

        def __init__(self, queue: Queue):
            self.queue = queue
            self.closed = False

        def handle(self, event):
            # Once closed, silently drop events instead of queueing them.
            if not self.closed:
                self.queue.put(event)

        def started(self, event):
            self.handle(event)

        def succeeded(self, event):
            self.handle(event)

        def failed(self, event):
            self.handle(event)

    listener = HbListenerProxy(event_queue)
    opts.update(
        event_listeners=[listener],
        heartbeatFrequencyMS=500,
        connectTimeoutMS=500,
    )
    client = MongoClient(host, **opts)
    client.admin.command("ping")

    # Block here until the parent tells us to finish.
    message_queue.get()

    # One last command proves the client is still connected.
    client.admin.command("ping")
    event_queue.put("DONE")
    listener.closed = True
    client.close()

    # Close both queues first, then wait for their feeder threads.
    queues = (message_queue, event_queue)
    for q in queues:
        q.close()
    for q in queues:
        q.join_thread()

@unittest.skipIf(
    client_context.load_balancer or client_context.serverless,
    "loadBalanced clients do not run SDAM",
)
@unittest.skipIf(
    sys.platform == "win32",
    "multiprocessing does not work with our test suite on Windows due to the issue "
    "described in https://bugs.python.org/issue11240",
)
@unittest.skipIf(
    is_greenthread_patched(), "multiprocessing does not work with gevent or eventlet"
)
def test_sigstop_sigcont(self):
    """Check that pausing a client process causes no heartbeat failures.

    A child process running ``sigstop_sigcont`` is started, suspended with
    SIGSTOP for a few seconds and resumed with SIGCONT. The heartbeat
    events the child forwards over ``event_queue`` are collected until the
    "DONE" sentinel arrives; the test then asserts that at least one event
    was seen and that none of them is a ServerHeartbeatFailedEvent.
    """
    event_queue = Queue()
    message_queue = Queue()
    p = Process(
        target=self.sigstop_sigcont,
        args=(client_context.pair, client_context.client_options, event_queue, message_queue),
    )
    p.start()
    # Cleanups run in reverse order: terminate the child, then reap it.
    self.addCleanup(p.join, 1)
    self.addCleanup(p.terminate)
    wait_until(lambda: p.ident is not None, "start subprocess")
    pid = p.ident
    assert pid is not None
    time.sleep(1)
    # Stop the child, sleep for at least twice the streaming timeout
    # (heartbeatFrequencyMS + connectTimeoutMS), and restart.
    os.kill(pid, signal.SIGSTOP)
    time.sleep(3)
    os.kill(pid, signal.SIGCONT)
    time.sleep(1)
    message_queue.put("STOP")
    # Ensure there are no heartbeat failures in the child.
    events = []
    while True:
        # Bound the wait: if the child died before sending "DONE" an
        # unbounded get() would hang the whole test run. On timeout
        # queue.Empty propagates, failing the test and letting the
        # cleanups above terminate the child.
        event = event_queue.get(timeout=60)
        if event == "DONE":
            break
        events.append(event)

    p.join()

    self.assertTrue(events, "expected to see heartbeat events from child")
    self.assertFalse(
        [e for e in events if isinstance(e, monitoring.ServerHeartbeatFailedEvent)],
        "expected to see zero heartbeat failed events",
    )
    message_queue.close()
    event_queue.close()
    message_queue.join_thread()
    event_queue.join_thread()


class TestExhaustCursor(IntegrationTest):
"""Test that clients properly handle errors from exhaust cursors."""
Expand Down