Skip to content

Commit c3ad9aa

Browse files
ShaneHarveyjuliusgeo
authored andcommitted
PYTHON-3186 Avoid SDAM heartbeat timeouts on AWS Lambda (mongodb#912)
Poll monitor socket with timeout=0 one last time after timeout expires. This avoids heartbeat timeouts and connection churn on Lambda and other FaaS envs.
1 parent fb5ca18 commit c3ad9aa

File tree

4 files changed

+148
-5
lines changed

4 files changed

+148
-5
lines changed

pymongo/network.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,7 @@ def wait_for_read(sock_info, deadline):
244244
# Only Monitor connections can be cancelled.
245245
if context:
246246
sock = sock_info.sock
247+
timed_out = False
247248
while True:
248249
# SSLSocket can have buffered data which won't be caught by select.
249250
if hasattr(sock, "pending") and sock.pending() > 0:
@@ -252,15 +253,21 @@ def wait_for_read(sock_info, deadline):
252253
# Wait up to 500ms for the socket to become readable and then
253254
# check for cancellation.
254255
if deadline:
255-
timeout = max(min(deadline - time.monotonic(), _POLL_TIMEOUT), 0.001)
256+
remaining = deadline - time.monotonic()
257+
# When the timeout has expired perform one final check to
258+
# see if the socket is readable. This helps avoid spurious
259+
# timeouts on AWS Lambda and other FaaS environments.
260+
if remaining <= 0:
261+
timed_out = True
262+
timeout = max(min(remaining, _POLL_TIMEOUT), 0)
256263
else:
257264
timeout = _POLL_TIMEOUT
258265
readable = sock_info.socket_checker.select(sock, read=True, timeout=timeout)
259266
if context.cancelled:
260267
raise _OperationCancelled("hello cancelled")
261268
if readable:
262269
return
263-
if deadline and time.monotonic() > deadline:
270+
if timed_out:
264271
raise socket.timeout("timed out")
265272

266273

test/__init__.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
from test.version import Version
4545
from typing import Dict, no_type_check
4646
from unittest import SkipTest
47+
from urllib.parse import quote_plus
4748

4849
import pymongo
4950
import pymongo.errors
@@ -279,6 +280,22 @@ def client_options(self):
279280
opts["replicaSet"] = self.replica_set_name
280281
return opts
281282

283+
@property
284+
def uri(self):
285+
"""Return the MongoClient URI for creating a duplicate client."""
286+
opts = client_context.default_client_options.copy()
287+
opts_parts = []
288+
for opt, val in opts.items():
289+
strval = str(val)
290+
if isinstance(val, bool):
291+
strval = strval.lower()
292+
opts_parts.append(f"{opt}={quote_plus(strval)}")
293+
opts_part = "&".join(opts_parts)
294+
auth_part = ""
295+
if client_context.auth_enabled:
296+
auth_part = f"{quote_plus(db_user)}:{quote_plus(db_pwd)}@"
297+
return f"mongodb://{auth_part}{self.pair}/?{opts_part}"
298+
282299
@property
283300
def hello(self):
284301
if not self._hello:
@@ -359,7 +376,7 @@ def _init_client(self):
359376
username=db_user,
360377
password=db_pwd,
361378
replicaSet=self.replica_set_name,
362-
**self.default_client_options
379+
**self.default_client_options,
363380
)
364381

365382
# May not have this if OperationFailure was raised earlier.
@@ -387,7 +404,7 @@ def _init_client(self):
387404
username=db_user,
388405
password=db_pwd,
389406
replicaSet=self.replica_set_name,
390-
**self.default_client_options
407+
**self.default_client_options,
391408
)
392409
else:
393410
self.client = pymongo.MongoClient(
@@ -490,7 +507,7 @@ def _check_user_provided(self):
490507
username=db_user,
491508
password=db_pwd,
492509
serverSelectionTimeoutMS=100,
493-
**self.default_client_options
510+
**self.default_client_options,
494511
)
495512

496513
try:

test/sigstop_sigcont.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# Copyright 2022-present MongoDB, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Used by test_client.TestClient.test_sigstop_sigcont."""
16+
17+
import logging
18+
import sys
19+
20+
sys.path[0:0] = [""]
21+
22+
from pymongo import monitoring
23+
from pymongo.mongo_client import MongoClient
24+
25+
26+
class HeartbeatLogger(monitoring.ServerHeartbeatListener):
27+
"""Log events until the listener is closed."""
28+
29+
def __init__(self):
30+
self.closed = False
31+
32+
def close(self):
33+
self.closed = True
34+
35+
def started(self, event: monitoring.ServerHeartbeatStartedEvent) -> None:
36+
if self.closed:
37+
return
38+
logging.info("%s", event)
39+
40+
def succeeded(self, event: monitoring.ServerHeartbeatSucceededEvent) -> None:
41+
if self.closed:
42+
return
43+
logging.info("%s", event)
44+
45+
def failed(self, event: monitoring.ServerHeartbeatFailedEvent) -> None:
46+
if self.closed:
47+
return
48+
logging.warning("%s", event)
49+
50+
51+
def main(uri: str) -> None:
52+
heartbeat_logger = HeartbeatLogger()
53+
client = MongoClient(
54+
uri,
55+
event_listeners=[heartbeat_logger],
56+
heartbeatFrequencyMS=500,
57+
connectTimeoutMS=500,
58+
)
59+
client.admin.command("ping")
60+
logging.info("TEST STARTED")
61+
# test_sigstop_sigcont will SIGSTOP and SIGCONT this process in this loop.
62+
while True:
63+
try:
64+
data = input('Type "q" to quit: ')
65+
except EOFError:
66+
break
67+
if data == "q":
68+
break
69+
client.admin.command("ping")
70+
logging.info("TEST COMPLETED")
71+
heartbeat_logger.close()
72+
client.close()
73+
74+
75+
if __name__ == "__main__":
76+
if len(sys.argv) != 2:
77+
print("unknown or missing options")
78+
print(f"usage: python3 {sys.argv[0]} 'mongodb://localhost'")
79+
exit(1)
80+
81+
# Enable logs in this format:
82+
# 2022-03-30 12:40:55,582 INFO <ServerHeartbeatStartedEvent ('localhost', 27017)>
83+
FORMAT = "%(asctime)s %(levelname)s %(message)s"
84+
logging.basicConfig(format=FORMAT, level=logging.INFO)
85+
main(sys.argv[1])

test/test_client.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import signal
2424
import socket
2525
import struct
26+
import subprocess
2627
import sys
2728
import threading
2829
import time
@@ -1688,6 +1689,39 @@ def test_srv_max_hosts_kwarg(self):
16881689
)
16891690
self.assertEqual(len(client.topology_description.server_descriptions()), 2)
16901691

1692+
@unittest.skipIf(
1693+
client_context.load_balancer or client_context.serverless,
1694+
"loadBalanced clients do not run SDAM",
1695+
)
1696+
@unittest.skipIf(sys.platform == "win32", "Windows does not support SIGSTOP")
1697+
def test_sigstop_sigcont(self):
1698+
test_dir = os.path.dirname(os.path.realpath(__file__))
1699+
script = os.path.join(test_dir, "sigstop_sigcont.py")
1700+
p = subprocess.Popen(
1701+
[sys.executable, script, client_context.uri],
1702+
stdin=subprocess.PIPE,
1703+
stdout=subprocess.PIPE,
1704+
stderr=subprocess.STDOUT,
1705+
)
1706+
self.addCleanup(p.wait, timeout=1)
1707+
self.addCleanup(p.kill)
1708+
time.sleep(1)
1709+
# Stop the child, sleep for twice the streaming timeout
1710+
# (heartbeatFrequencyMS + connectTimeoutMS), and restart.
1711+
os.kill(p.pid, signal.SIGSTOP)
1712+
time.sleep(2)
1713+
os.kill(p.pid, signal.SIGCONT)
1714+
time.sleep(0.5)
1715+
# Tell the script to exit gracefully.
1716+
outs, _ = p.communicate(input=b"q\n", timeout=10)
1717+
self.assertTrue(outs)
1718+
log_output = outs.decode("utf-8")
1719+
self.assertIn("TEST STARTED", log_output)
1720+
self.assertIn("ServerHeartbeatStartedEvent", log_output)
1721+
self.assertIn("ServerHeartbeatSucceededEvent", log_output)
1722+
self.assertIn("TEST COMPLETED", log_output)
1723+
self.assertNotIn("ServerHeartbeatFailedEvent", log_output)
1724+
16911725

16921726
class TestExhaustCursor(IntegrationTest):
16931727
"""Test that clients properly handle errors from exhaust cursors."""

0 commit comments

Comments
 (0)