Skip to content

PYTHON-3186 Avoid SDAM heartbeat timeouts on AWS Lambda #912

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Mar 30, 2022
Merged
11 changes: 9 additions & 2 deletions pymongo/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ def wait_for_read(sock_info, deadline):
# Only Monitor connections can be cancelled.
if context:
sock = sock_info.sock
timed_out = False
while True:
# SSLSocket can have buffered data which won't be caught by select.
if hasattr(sock, "pending") and sock.pending() > 0:
Expand All @@ -252,15 +253,21 @@ def wait_for_read(sock_info, deadline):
# Wait up to 500ms for the socket to become readable and then
# check for cancellation.
if deadline:
timeout = max(min(deadline - time.monotonic(), _POLL_TIMEOUT), 0.001)
remaining = deadline - time.monotonic()
# When the timeout has expired perform one final check to
# see if the socket is readable. This helps avoid spurious
# timeouts on AWS Lambda and other FaaS environments.
if remaining <= 0:
timed_out = True
timeout = max(min(remaining, _POLL_TIMEOUT), 0)
else:
timeout = _POLL_TIMEOUT
readable = sock_info.socket_checker.select(sock, read=True, timeout=timeout)
if context.cancelled:
raise _OperationCancelled("hello cancelled")
if readable:
return
if deadline and time.monotonic() > deadline:
if timed_out:
raise socket.timeout("timed out")


Expand Down
23 changes: 20 additions & 3 deletions test/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from test.version import Version
from typing import Dict, no_type_check
from unittest import SkipTest
from urllib.parse import quote_plus

import pymongo
import pymongo.errors
Expand Down Expand Up @@ -279,6 +280,22 @@ def client_options(self):
opts["replicaSet"] = self.replica_set_name
return opts

@property
def uri(self):
    """Return the MongoClient URI for creating a duplicate client.

    The URI is built from ``self.pair`` with every entry of
    ``self.default_client_options`` appended as a URL-encoded query
    parameter; boolean values are lower-cased (``true``/``false``). When
    ``self.auth_enabled`` is truthy, the URL-encoded ``db_user`` and
    ``db_pwd`` credentials are placed in front of the host.
    """
    # Read the options and auth flag from ``self``, matching the existing
    # ``self.pair`` / ``self.default_client_options`` usage, instead of
    # reaching for the module-level ``client_context`` global.
    opts_parts = []
    for opt, val in self.default_client_options.items():
        # str(True) is "True"; URI options expect lower-case booleans.
        strval = str(val).lower() if isinstance(val, bool) else str(val)
        opts_parts.append(f"{opt}={quote_plus(strval)}")
    opts_part = "&".join(opts_parts)
    auth_part = ""
    if self.auth_enabled:
        auth_part = f"{quote_plus(db_user)}:{quote_plus(db_pwd)}@"
    return f"mongodb://{auth_part}{self.pair}/?{opts_part}"

@property
def hello(self):
if not self._hello:
Expand Down Expand Up @@ -359,7 +376,7 @@ def _init_client(self):
username=db_user,
password=db_pwd,
replicaSet=self.replica_set_name,
**self.default_client_options
**self.default_client_options,
)

# May not have this if OperationFailure was raised earlier.
Expand Down Expand Up @@ -387,7 +404,7 @@ def _init_client(self):
username=db_user,
password=db_pwd,
replicaSet=self.replica_set_name,
**self.default_client_options
**self.default_client_options,
)
else:
self.client = pymongo.MongoClient(
Expand Down Expand Up @@ -490,7 +507,7 @@ def _check_user_provided(self):
username=db_user,
password=db_pwd,
serverSelectionTimeoutMS=100,
**self.default_client_options
**self.default_client_options,
)

try:
Expand Down
85 changes: 85 additions & 0 deletions test/sigstop_sigcont.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Copyright 2022-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Used by test_client.TestClient.test_sigstop_sigcont."""

import logging
import sys

sys.path[0:0] = [""]

from pymongo import monitoring
from pymongo.mongo_client import MongoClient


class HeartbeatLogger(monitoring.ServerHeartbeatListener):
    """Heartbeat listener that writes each event to the root logger.

    Started and succeeded events are logged at INFO, failed events at
    WARNING. Once :meth:`close` has been called, further events are ignored.
    """

    def __init__(self):
        # Set to True by close(); checked before every log call.
        self.closed = False

    def close(self):
        """Stop logging any further heartbeat events."""
        self.closed = True

    def started(self, event: monitoring.ServerHeartbeatStartedEvent) -> None:
        """Log a heartbeat-started event at INFO unless closed."""
        if not self.closed:
            logging.info("%s", event)

    def succeeded(self, event: monitoring.ServerHeartbeatSucceededEvent) -> None:
        """Log a heartbeat-succeeded event at INFO unless closed."""
        if not self.closed:
            logging.info("%s", event)

    def failed(self, event: monitoring.ServerHeartbeatFailedEvent) -> None:
        """Log a heartbeat-failed event at WARNING unless closed."""
        if not self.closed:
            logging.warning("%s", event)


def main(uri: str) -> None:
    """Ping the server at ``uri`` until "q" or end-of-file is read from stdin.

    A client is created with a HeartbeatLogger attached and 500 ms heartbeat
    and connect timeouts. "TEST STARTED" is logged after the first ping and
    "TEST COMPLETED" once the input loop ends; the logger is closed before
    the client so that no further heartbeat events are logged.
    """
    listener = HeartbeatLogger()
    client = MongoClient(
        uri,
        event_listeners=[listener],
        heartbeatFrequencyMS=500,
        connectTimeoutMS=500,
    )
    client.admin.command("ping")
    logging.info("TEST STARTED")
    # test_sigstop_sigcont will SIGSTOP and SIGCONT this process in this loop.
    while True:
        try:
            quit_requested = input('Type "q" to quit: ') == "q"
        except EOFError:
            # stdin was closed: treat it the same as an explicit quit.
            quit_requested = True
        if quit_requested:
            break
        client.admin.command("ping")
    logging.info("TEST COMPLETED")
    listener.close()
    client.close()


if __name__ == "__main__":
    # Exactly one argument is expected: the MongoDB URI to connect to.
    if len(sys.argv) != 2:
        print("unknown or missing options")
        print(f"usage: python3 {sys.argv[0]} 'mongodb://localhost'")
        # Use sys.exit rather than the bare exit() helper: exit() is added
        # by the site module for interactive use and is unavailable when
        # Python is started with -S.
        sys.exit(1)

    # Enable logs in this format:
    # 2022-03-30 12:40:55,582 INFO <ServerHeartbeatStartedEvent ('localhost', 27017)>
    FORMAT = "%(asctime)s %(levelname)s %(message)s"
    logging.basicConfig(format=FORMAT, level=logging.INFO)
    main(sys.argv[1])
34 changes: 34 additions & 0 deletions test/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import signal
import socket
import struct
import subprocess
import sys
import threading
import time
Expand Down Expand Up @@ -1688,6 +1689,39 @@ def test_srv_max_hosts_kwarg(self):
)
self.assertEqual(len(client.topology_description.server_descriptions()), 2)

@unittest.skipIf(
    client_context.load_balancer or client_context.serverless,
    "loadBalanced clients do not run SDAM",
)
@unittest.skipIf(sys.platform == "win32", "Windows does not support SIGSTOP")
def test_sigstop_sigcont(self):
    """Suspend and resume a child client and check no heartbeat fails.

    Runs sigstop_sigcont.py in a subprocess, pauses it with SIGSTOP,
    resumes it with SIGCONT, asks it to quit, and then inspects its
    combined stdout/stderr log output.
    """
    script_path = os.path.join(
        os.path.dirname(os.path.realpath(__file__)), "sigstop_sigcont.py"
    )
    child = subprocess.Popen(
        [sys.executable, script_path, client_context.uri],
        stdin=subprocess.PIPE,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
    )
    # Cleanups run in reverse registration order: kill first, then wait.
    self.addCleanup(child.wait, timeout=1)
    self.addCleanup(child.kill)
    time.sleep(1)
    # Stop the child, sleep for twice the streaming timeout
    # (heartbeatFrequencyMS + connectTimeoutMS), and restart.
    os.kill(child.pid, signal.SIGSTOP)
    time.sleep(2)
    os.kill(child.pid, signal.SIGCONT)
    time.sleep(0.5)
    # Tell the script to exit gracefully.
    captured, _ = child.communicate(input=b"q\n", timeout=10)
    self.assertTrue(captured)
    log_output = captured.decode("utf-8")
    expected_markers = (
        "TEST STARTED",
        "ServerHeartbeatStartedEvent",
        "ServerHeartbeatSucceededEvent",
        "TEST COMPLETED",
    )
    for marker in expected_markers:
        self.assertIn(marker, log_output)
    self.assertNotIn("ServerHeartbeatFailedEvent", log_output)


class TestExhaustCursor(IntegrationTest):
"""Test that clients properly handle errors from exhaust cursors."""
Expand Down