Skip to content

Commit ff2f959

Browse files
authored
PYTHON-2560 Retry KMS requests on transient errors (#2024)
1 parent ce1c49a commit ff2f959

File tree

4 files changed

+260
-40
lines changed

4 files changed

+260
-40
lines changed

pymongo/asynchronous/encryption.py

Lines changed: 46 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import contextlib
2020
import enum
2121
import socket
22+
import time as time # noqa: PLC0414 # needed in sync version
2223
import uuid
2324
import weakref
2425
from copy import deepcopy
@@ -63,7 +64,11 @@
6364
from pymongo.asynchronous.cursor import AsyncCursor
6465
from pymongo.asynchronous.database import AsyncDatabase
6566
from pymongo.asynchronous.mongo_client import AsyncMongoClient
66-
from pymongo.asynchronous.pool import _configured_socket, _raise_connection_failure
67+
from pymongo.asynchronous.pool import (
68+
_configured_socket,
69+
_get_timeout_details,
70+
_raise_connection_failure,
71+
)
6772
from pymongo.common import CONNECT_TIMEOUT
6873
from pymongo.daemon import _spawn_daemon
6974
from pymongo.encryption_options import AutoEncryptionOpts, RangeOpts
@@ -72,7 +77,7 @@
7277
EncryptedCollectionError,
7378
EncryptionError,
7479
InvalidOperation,
75-
PyMongoError,
80+
NetworkTimeout,
7681
ServerSelectionTimeoutError,
7782
)
7883
from pymongo.network_layer import BLOCKING_IO_ERRORS, async_sendall
@@ -88,6 +93,9 @@
8893
if TYPE_CHECKING:
8994
from pymongocrypt.mongocrypt import MongoCryptKmsContext
9095

96+
from pymongo.pyopenssl_context import _sslConn
97+
from pymongo.typings import _Address
98+
9199

92100
_IS_SYNC = False
93101

@@ -103,6 +111,13 @@
103111
_KEY_VAULT_OPTS = CodecOptions(document_class=RawBSONDocument)
104112

105113

114+
async def _connect_kms(address: _Address, opts: PoolOptions) -> Union[socket.socket, _sslConn]:
115+
try:
116+
return await _configured_socket(address, opts)
117+
except Exception as exc:
118+
_raise_connection_failure(address, exc, timeout_details=_get_timeout_details(opts))
119+
120+
106121
@contextlib.contextmanager
107122
def _wrap_encryption_errors() -> Iterator[None]:
108123
"""Context manager to wrap encryption related errors."""
@@ -166,18 +181,22 @@ async def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
166181
None, # crlfile
167182
False, # allow_invalid_certificates
168183
False, # allow_invalid_hostnames
169-
False,
170-
) # disable_ocsp_endpoint_check
184+
False, # disable_ocsp_endpoint_check
185+
)
171186
# CSOT: set timeout for socket creation.
172187
connect_timeout = max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0.001)
173188
opts = PoolOptions(
174189
connect_timeout=connect_timeout,
175190
socket_timeout=connect_timeout,
176191
ssl_context=ctx,
177192
)
178-
host, port = parse_host(endpoint, _HTTPS_PORT)
193+
address = parse_host(endpoint, _HTTPS_PORT)
194+
sleep_u = kms_context.usleep
195+
if sleep_u:
196+
sleep_sec = float(sleep_u) / 1e6
197+
await asyncio.sleep(sleep_sec)
179198
try:
180-
conn = await _configured_socket((host, port), opts)
199+
conn = await _connect_kms(address, opts)
181200
try:
182201
await async_sendall(conn, message)
183202
while kms_context.bytes_needed > 0:
@@ -194,20 +213,29 @@ async def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
194213
if not data:
195214
raise OSError("KMS connection closed")
196215
kms_context.feed(data)
197-
# Async raises an OSError instead of returning empty bytes
198-
except OSError as err:
199-
raise OSError("KMS connection closed") from err
200-
except BLOCKING_IO_ERRORS:
201-
raise socket.timeout("timed out") from None
216+
except MongoCryptError:
217+
raise # Propagate MongoCryptError directly.
218+
except Exception as exc:
219+
# Wrap I/O errors in PyMongo exceptions.
220+
if isinstance(exc, BLOCKING_IO_ERRORS):
221+
exc = socket.timeout("timed out")
222+
_raise_connection_failure(address, exc, timeout_details=_get_timeout_details(opts))
202223
finally:
203224
conn.close()
204-
except (PyMongoError, MongoCryptError):
205-
raise # Propagate pymongo errors directly.
206-
except asyncio.CancelledError:
207-
raise
208-
except Exception as error:
209-
# Wrap I/O errors in PyMongo exceptions.
210-
_raise_connection_failure((host, port), error)
225+
except MongoCryptError:
226+
raise # Propagate MongoCryptError directly.
227+
except Exception as exc:
228+
remaining = _csot.remaining()
229+
if isinstance(exc, NetworkTimeout) or (remaining is not None and remaining <= 0):
230+
raise
231+
# Mark this attempt as failed and defer to libmongocrypt to retry.
232+
try:
233+
kms_context.fail()
234+
except MongoCryptError as final_err:
235+
exc = MongoCryptError(
236+
f"{final_err}, last attempt failed with: {exc}", final_err.code
237+
)
238+
raise exc from final_err
211239

212240
async def collection_info(self, database: str, filter: bytes) -> Optional[bytes]:
213241
"""Get the collection info for a namespace.

pymongo/synchronous/encryption.py

Lines changed: 46 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import contextlib
2020
import enum
2121
import socket
22+
import time as time # noqa: PLC0414 # needed in sync version
2223
import uuid
2324
import weakref
2425
from copy import deepcopy
@@ -67,7 +68,7 @@
6768
EncryptedCollectionError,
6869
EncryptionError,
6970
InvalidOperation,
70-
PyMongoError,
71+
NetworkTimeout,
7172
ServerSelectionTimeoutError,
7273
)
7374
from pymongo.network_layer import BLOCKING_IO_ERRORS, sendall
@@ -80,14 +81,21 @@
8081
from pymongo.synchronous.cursor import Cursor
8182
from pymongo.synchronous.database import Database
8283
from pymongo.synchronous.mongo_client import MongoClient
83-
from pymongo.synchronous.pool import _configured_socket, _raise_connection_failure
84+
from pymongo.synchronous.pool import (
85+
_configured_socket,
86+
_get_timeout_details,
87+
_raise_connection_failure,
88+
)
8489
from pymongo.typings import _DocumentType, _DocumentTypeArg
8590
from pymongo.uri_parser import parse_host
8691
from pymongo.write_concern import WriteConcern
8792

8893
if TYPE_CHECKING:
8994
from pymongocrypt.mongocrypt import MongoCryptKmsContext
9095

96+
from pymongo.pyopenssl_context import _sslConn
97+
from pymongo.typings import _Address
98+
9199

92100
_IS_SYNC = True
93101

@@ -103,6 +111,13 @@
103111
_KEY_VAULT_OPTS = CodecOptions(document_class=RawBSONDocument)
104112

105113

114+
def _connect_kms(address: _Address, opts: PoolOptions) -> Union[socket.socket, _sslConn]:
115+
try:
116+
return _configured_socket(address, opts)
117+
except Exception as exc:
118+
_raise_connection_failure(address, exc, timeout_details=_get_timeout_details(opts))
119+
120+
106121
@contextlib.contextmanager
107122
def _wrap_encryption_errors() -> Iterator[None]:
108123
"""Context manager to wrap encryption related errors."""
@@ -166,18 +181,22 @@ def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
166181
None, # crlfile
167182
False, # allow_invalid_certificates
168183
False, # allow_invalid_hostnames
169-
False,
170-
) # disable_ocsp_endpoint_check
184+
False, # disable_ocsp_endpoint_check
185+
)
171186
# CSOT: set timeout for socket creation.
172187
connect_timeout = max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0.001)
173188
opts = PoolOptions(
174189
connect_timeout=connect_timeout,
175190
socket_timeout=connect_timeout,
176191
ssl_context=ctx,
177192
)
178-
host, port = parse_host(endpoint, _HTTPS_PORT)
193+
address = parse_host(endpoint, _HTTPS_PORT)
194+
sleep_u = kms_context.usleep
195+
if sleep_u:
196+
sleep_sec = float(sleep_u) / 1e6
197+
time.sleep(sleep_sec)
179198
try:
180-
conn = _configured_socket((host, port), opts)
199+
conn = _connect_kms(address, opts)
181200
try:
182201
sendall(conn, message)
183202
while kms_context.bytes_needed > 0:
@@ -194,20 +213,29 @@ def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
194213
if not data:
195214
raise OSError("KMS connection closed")
196215
kms_context.feed(data)
197-
# Async raises an OSError instead of returning empty bytes
198-
except OSError as err:
199-
raise OSError("KMS connection closed") from err
200-
except BLOCKING_IO_ERRORS:
201-
raise socket.timeout("timed out") from None
216+
except MongoCryptError:
217+
raise # Propagate MongoCryptError directly.
218+
except Exception as exc:
219+
# Wrap I/O errors in PyMongo exceptions.
220+
if isinstance(exc, BLOCKING_IO_ERRORS):
221+
exc = socket.timeout("timed out")
222+
_raise_connection_failure(address, exc, timeout_details=_get_timeout_details(opts))
202223
finally:
203224
conn.close()
204-
except (PyMongoError, MongoCryptError):
205-
raise # Propagate pymongo errors directly.
206-
except asyncio.CancelledError:
207-
raise
208-
except Exception as error:
209-
# Wrap I/O errors in PyMongo exceptions.
210-
_raise_connection_failure((host, port), error)
225+
except MongoCryptError:
226+
raise # Propagate MongoCryptError directly.
227+
except Exception as exc:
228+
remaining = _csot.remaining()
229+
if isinstance(exc, NetworkTimeout) or (remaining is not None and remaining <= 0):
230+
raise
231+
# Mark this attempt as failed and defer to libmongocrypt to retry.
232+
try:
233+
kms_context.fail()
234+
except MongoCryptError as final_err:
235+
exc = MongoCryptError(
236+
f"{final_err}, last attempt failed with: {exc}", final_err.code
237+
)
238+
raise exc from final_err
211239

212240
def collection_info(self, database: str, filter: bytes) -> Optional[bytes]:
213241
"""Get the collection info for a namespace.

test/asynchronous/test_encryption.py

Lines changed: 84 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
import base64
1919
import copy
20+
import http.client
21+
import json
2022
import os
2123
import pathlib
2224
import re
@@ -91,6 +93,7 @@
9193
WriteError,
9294
)
9395
from pymongo.operations import InsertOne, ReplaceOne, UpdateOne
96+
from pymongo.ssl_support import get_ssl_context
9497
from pymongo.write_concern import WriteConcern
9598

9699
_IS_SYNC = False
@@ -1366,9 +1369,8 @@ async def test_04_aws_endpoint_invalid_port(self):
13661369
"key": ("arn:aws:kms:us-east-1:579766882180:key/89fcc2c4-08b0-4bd9-9f25-e30687b580d0"),
13671370
"endpoint": "kms.us-east-1.amazonaws.com:12345",
13681371
}
1369-
with self.assertRaisesRegex(EncryptionError, "kms.us-east-1.amazonaws.com:12345") as ctx:
1372+
with self.assertRaisesRegex(EncryptionError, "kms.us-east-1.amazonaws.com:12345"):
13701373
await self.client_encryption.create_data_key("aws", master_key=master_key)
1371-
self.assertIsInstance(ctx.exception.cause, AutoReconnect)
13721374

13731375
@unittest.skipUnless(any(AWS_CREDS.values()), "AWS environment credentials are not set")
13741376
async def test_05_aws_endpoint_wrong_region(self):
@@ -2853,6 +2855,86 @@ async def test_accepts_trim_factor_0(self):
28532855
assert len(payload) > len(self.payload_defaults)
28542856

28552857

2858+
# https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#24-kms-retry-tests
2859+
class TestKmsRetryProse(AsyncEncryptionIntegrationTest):
2860+
@unittest.skipUnless(any(AWS_CREDS.values()), "AWS environment credentials are not set")
2861+
async def asyncSetUp(self):
2862+
await super().asyncSetUp()
2863+
# 1, create client with only tlsCAFile.
2864+
providers: dict = copy.deepcopy(ALL_KMS_PROVIDERS)
2865+
providers["azure"]["identityPlatformEndpoint"] = "127.0.0.1:9003"
2866+
providers["gcp"]["endpoint"] = "127.0.0.1:9003"
2867+
kms_tls_opts = {
2868+
p: {"tlsCAFile": CA_PEM, "tlsCertificateKeyFile": CLIENT_PEM} for p in providers
2869+
}
2870+
self.client_encryption = self.create_client_encryption(
2871+
providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=kms_tls_opts
2872+
)
2873+
2874+
async def http_post(self, path, data=None):
2875+
# Note, the connection to the mock server needs to be closed after
2876+
# each request because the server is single threaded.
2877+
ctx: ssl.SSLContext = get_ssl_context(
2878+
CLIENT_PEM, # certfile
2879+
None, # passphrase
2880+
CA_PEM, # ca_certs
2881+
None, # crlfile
2882+
False, # allow_invalid_certificates
2883+
False, # allow_invalid_hostnames
2884+
False, # disable_ocsp_endpoint_check
2885+
)
2886+
conn = http.client.HTTPSConnection("127.0.0.1:9003", context=ctx)
2887+
try:
2888+
if data is not None:
2889+
headers = {"Content-type": "application/json"}
2890+
body = json.dumps(data)
2891+
else:
2892+
headers = {}
2893+
body = None
2894+
conn.request("POST", path, body, headers)
2895+
res = conn.getresponse()
2896+
res.read()
2897+
finally:
2898+
conn.close()
2899+
2900+
async def _test(self, provider, master_key):
2901+
await self.http_post("/reset")
2902+
# Case 1: createDataKey and encrypt with TCP retry
2903+
await self.http_post("/set_failpoint/network", {"count": 1})
2904+
key_id = await self.client_encryption.create_data_key(provider, master_key=master_key)
2905+
await self.http_post("/set_failpoint/network", {"count": 1})
2906+
await self.client_encryption.encrypt(
2907+
123, Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_id
2908+
)
2909+
2910+
# Case 2: createDataKey and encrypt with HTTP retry
2911+
await self.http_post("/set_failpoint/http", {"count": 1})
2912+
key_id = await self.client_encryption.create_data_key(provider, master_key=master_key)
2913+
await self.http_post("/set_failpoint/http", {"count": 1})
2914+
await self.client_encryption.encrypt(
2915+
123, Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_id
2916+
)
2917+
2918+
# Case 3: createDataKey fails after too many retries
2919+
await self.http_post("/set_failpoint/network", {"count": 4})
2920+
with self.assertRaisesRegex(EncryptionError, "KMS request failed after"):
2921+
await self.client_encryption.create_data_key(provider, master_key=master_key)
2922+
2923+
async def test_kms_retry(self):
2924+
await self._test("aws", {"region": "foo", "key": "bar", "endpoint": "127.0.0.1:9003"})
2925+
await self._test("azure", {"keyVaultEndpoint": "127.0.0.1:9003", "keyName": "foo"})
2926+
await self._test(
2927+
"gcp",
2928+
{
2929+
"projectId": "foo",
2930+
"location": "bar",
2931+
"keyRing": "baz",
2932+
"keyName": "qux",
2933+
"endpoint": "127.0.0.1:9003",
2934+
},
2935+
)
2936+
2937+
28562938
# https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#automatic-data-encryption-keys
28572939
class TestAutomaticDecryptionKeys(AsyncEncryptionIntegrationTest):
28582940
@async_client_context.require_no_standalone

0 commit comments

Comments
 (0)