diff --git a/pymongo/asynchronous/encryption.py b/pymongo/asynchronous/encryption.py index 4802c3f54e..1cf165e6a2 100644 --- a/pymongo/asynchronous/encryption.py +++ b/pymongo/asynchronous/encryption.py @@ -19,6 +19,7 @@ import contextlib import enum import socket +import time as time # noqa: PLC0414 # needed in sync version import uuid import weakref from copy import deepcopy @@ -63,7 +64,11 @@ from pymongo.asynchronous.cursor import AsyncCursor from pymongo.asynchronous.database import AsyncDatabase from pymongo.asynchronous.mongo_client import AsyncMongoClient -from pymongo.asynchronous.pool import _configured_socket, _raise_connection_failure +from pymongo.asynchronous.pool import ( + _configured_socket, + _get_timeout_details, + _raise_connection_failure, +) from pymongo.common import CONNECT_TIMEOUT from pymongo.daemon import _spawn_daemon from pymongo.encryption_options import AutoEncryptionOpts, RangeOpts @@ -72,7 +77,7 @@ EncryptedCollectionError, EncryptionError, InvalidOperation, - PyMongoError, + NetworkTimeout, ServerSelectionTimeoutError, ) from pymongo.network_layer import BLOCKING_IO_ERRORS, async_sendall @@ -88,6 +93,9 @@ if TYPE_CHECKING: from pymongocrypt.mongocrypt import MongoCryptKmsContext + from pymongo.pyopenssl_context import _sslConn + from pymongo.typings import _Address + _IS_SYNC = False @@ -103,6 +111,13 @@ _KEY_VAULT_OPTS = CodecOptions(document_class=RawBSONDocument) +async def _connect_kms(address: _Address, opts: PoolOptions) -> Union[socket.socket, _sslConn]: + try: + return await _configured_socket(address, opts) + except Exception as exc: + _raise_connection_failure(address, exc, timeout_details=_get_timeout_details(opts)) + + @contextlib.contextmanager def _wrap_encryption_errors() -> Iterator[None]: """Context manager to wrap encryption related errors.""" @@ -166,8 +181,8 @@ async def kms_request(self, kms_context: MongoCryptKmsContext) -> None: None, # crlfile False, # allow_invalid_certificates False, # allow_invalid_hostnames - False, - ) # disable_ocsp_endpoint_check + False, # disable_ocsp_endpoint_check + ) # CSOT: set timeout for socket creation. connect_timeout = max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0.001) opts = PoolOptions( @@ -175,9 +190,13 @@ async def kms_request(self, kms_context: MongoCryptKmsContext) -> None: socket_timeout=connect_timeout, ssl_context=ctx, ) - host, port = parse_host(endpoint, _HTTPS_PORT) + address = parse_host(endpoint, _HTTPS_PORT) + sleep_u = kms_context.usleep + if sleep_u: + sleep_sec = float(sleep_u) / 1e6 + await asyncio.sleep(sleep_sec) try: - conn = await _configured_socket((host, port), opts) + conn = await _connect_kms(address, opts) try: await async_sendall(conn, message) while kms_context.bytes_needed > 0: @@ -194,20 +213,29 @@ async def kms_request(self, kms_context: MongoCryptKmsContext) -> None: if not data: raise OSError("KMS connection closed") kms_context.feed(data) - # Async raises an OSError instead of returning empty bytes - except OSError as err: - raise OSError("KMS connection closed") from err - except BLOCKING_IO_ERRORS: - raise socket.timeout("timed out") from None + except MongoCryptError: + raise # Propagate MongoCryptError errors directly. + except Exception as exc: + # Wrap I/O errors in PyMongo exceptions. + if isinstance(exc, BLOCKING_IO_ERRORS): + exc = socket.timeout("timed out") + _raise_connection_failure(address, exc, timeout_details=_get_timeout_details(opts)) finally: conn.close() - except (PyMongoError, MongoCryptError): - raise # Propagate pymongo errors directly. - except asyncio.CancelledError: - raise - except Exception as error: - # Wrap I/O errors in PyMongo exceptions. - _raise_connection_failure((host, port), error) + except MongoCryptError: + raise # Propagate MongoCryptError errors directly. + except Exception as exc: + remaining = _csot.remaining() + if isinstance(exc, NetworkTimeout) or (remaining is not None and remaining <= 0): + raise + # Mark this attempt as failed and defer to libmongocrypt to retry. + try: + kms_context.fail() + except MongoCryptError as final_err: + exc = MongoCryptError( + f"{final_err}, last attempt failed with: {exc}", final_err.code + ) + raise exc from final_err async def collection_info(self, database: str, filter: bytes) -> Optional[bytes]: """Get the collection info for a namespace. diff --git a/pymongo/synchronous/encryption.py b/pymongo/synchronous/encryption.py index 09d0c0f2fd..ef49855059 100644 --- a/pymongo/synchronous/encryption.py +++ b/pymongo/synchronous/encryption.py @@ -19,6 +19,7 @@ import contextlib import enum import socket +import time as time # noqa: PLC0414 # needed in sync version import uuid import weakref from copy import deepcopy @@ -67,7 +68,7 @@ EncryptedCollectionError, EncryptionError, InvalidOperation, - PyMongoError, + NetworkTimeout, ServerSelectionTimeoutError, ) from pymongo.network_layer import BLOCKING_IO_ERRORS, sendall @@ -80,7 +81,11 @@ from pymongo.synchronous.cursor import Cursor from pymongo.synchronous.database import Database from pymongo.synchronous.mongo_client import MongoClient -from pymongo.synchronous.pool import _configured_socket, _raise_connection_failure +from pymongo.synchronous.pool import ( + _configured_socket, + _get_timeout_details, + _raise_connection_failure, +) from pymongo.typings import _DocumentType, _DocumentTypeArg from pymongo.uri_parser import parse_host from pymongo.write_concern import WriteConcern @@ -88,6 +93,9 @@ if TYPE_CHECKING: from pymongocrypt.mongocrypt import MongoCryptKmsContext + from pymongo.pyopenssl_context import _sslConn + from pymongo.typings import _Address + _IS_SYNC = True @@ -103,6 +111,13 @@ _KEY_VAULT_OPTS = CodecOptions(document_class=RawBSONDocument) +def _connect_kms(address: _Address, opts: PoolOptions) -> Union[socket.socket, _sslConn]: + try: + return _configured_socket(address, opts) + except Exception as exc: + _raise_connection_failure(address, exc, timeout_details=_get_timeout_details(opts)) + + @contextlib.contextmanager def _wrap_encryption_errors() -> Iterator[None]: """Context manager to wrap encryption related errors.""" @@ -166,8 +181,8 @@ def kms_request(self, kms_context: MongoCryptKmsContext) -> None: None, # crlfile False, # allow_invalid_certificates False, # allow_invalid_hostnames - False, - ) # disable_ocsp_endpoint_check + False, # disable_ocsp_endpoint_check + ) # CSOT: set timeout for socket creation. connect_timeout = max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0.001) opts = PoolOptions( @@ -175,9 +190,13 @@ def kms_request(self, kms_context: MongoCryptKmsContext) -> None: socket_timeout=connect_timeout, ssl_context=ctx, ) - host, port = parse_host(endpoint, _HTTPS_PORT) + address = parse_host(endpoint, _HTTPS_PORT) + sleep_u = kms_context.usleep + if sleep_u: + sleep_sec = float(sleep_u) / 1e6 + time.sleep(sleep_sec) try: - conn = _configured_socket((host, port), opts) + conn = _connect_kms(address, opts) try: sendall(conn, message) while kms_context.bytes_needed > 0: @@ -194,20 +213,29 @@ def kms_request(self, kms_context: MongoCryptKmsContext) -> None: if not data: raise OSError("KMS connection closed") kms_context.feed(data) - # Async raises an OSError instead of returning empty bytes - except OSError as err: - raise OSError("KMS connection closed") from err - except BLOCKING_IO_ERRORS: - raise socket.timeout("timed out") from None + except MongoCryptError: + raise # Propagate MongoCryptError errors directly. + except Exception as exc: + # Wrap I/O errors in PyMongo exceptions. + if isinstance(exc, BLOCKING_IO_ERRORS): + exc = socket.timeout("timed out") + _raise_connection_failure(address, exc, timeout_details=_get_timeout_details(opts)) finally: conn.close() - except (PyMongoError, MongoCryptError): - raise # Propagate pymongo errors directly. - except asyncio.CancelledError: - raise - except Exception as error: - # Wrap I/O errors in PyMongo exceptions. - _raise_connection_failure((host, port), error) + except MongoCryptError: + raise # Propagate MongoCryptError errors directly. + except Exception as exc: + remaining = _csot.remaining() + if isinstance(exc, NetworkTimeout) or (remaining is not None and remaining <= 0): + raise + # Mark this attempt as failed and defer to libmongocrypt to retry. + try: + kms_context.fail() + except MongoCryptError as final_err: + exc = MongoCryptError( + f"{final_err}, last attempt failed with: {exc}", final_err.code + ) + raise exc from final_err def collection_info(self, database: str, filter: bytes) -> Optional[bytes]: """Get the collection info for a namespace. diff --git a/test/asynchronous/test_encryption.py b/test/asynchronous/test_encryption.py index 21cd5e2666..559b06ddf4 100644 --- a/test/asynchronous/test_encryption.py +++ b/test/asynchronous/test_encryption.py @@ -17,6 +17,8 @@ import base64 import copy +import http.client +import json import os import pathlib import re @@ -91,6 +93,7 @@ WriteError, ) from pymongo.operations import InsertOne, ReplaceOne, UpdateOne +from pymongo.ssl_support import get_ssl_context from pymongo.write_concern import WriteConcern _IS_SYNC = False @@ -1366,9 +1369,8 @@ async def test_04_aws_endpoint_invalid_port(self): "key": ("arn:aws:kms:us-east-1:579766882180:key/89fcc2c4-08b0-4bd9-9f25-e30687b580d0"), "endpoint": "kms.us-east-1.amazonaws.com:12345", } - with self.assertRaisesRegex(EncryptionError, "kms.us-east-1.amazonaws.com:12345") as ctx: + with self.assertRaisesRegex(EncryptionError, "kms.us-east-1.amazonaws.com:12345"): await self.client_encryption.create_data_key("aws", master_key=master_key) - self.assertIsInstance(ctx.exception.cause, AutoReconnect) @unittest.skipUnless(any(AWS_CREDS.values()), "AWS environment credentials are not set") async def test_05_aws_endpoint_wrong_region(self): @@ -2853,6 +2855,86 @@ async def test_accepts_trim_factor_0(self): assert len(payload) > len(self.payload_defaults) +# https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#24-kms-retry-tests +class TestKmsRetryProse(AsyncEncryptionIntegrationTest): + @unittest.skipUnless(any(AWS_CREDS.values()), "AWS environment credentials are not set") + async def asyncSetUp(self): + await super().asyncSetUp() + # 1, create client with only tlsCAFile. + providers: dict = copy.deepcopy(ALL_KMS_PROVIDERS) + providers["azure"]["identityPlatformEndpoint"] = "127.0.0.1:9003" + providers["gcp"]["endpoint"] = "127.0.0.1:9003" + kms_tls_opts = { + p: {"tlsCAFile": CA_PEM, "tlsCertificateKeyFile": CLIENT_PEM} for p in providers + } + self.client_encryption = self.create_client_encryption( + providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=kms_tls_opts + ) + + async def http_post(self, path, data=None): + # Note, the connection to the mock server needs to be closed after + # each request because the server is single threaded. + ctx: ssl.SSLContext = get_ssl_context( + CLIENT_PEM, # certfile + None, # passphrase + CA_PEM, # ca_certs + None, # crlfile + False, # allow_invalid_certificates + False, # allow_invalid_hostnames + False, # disable_ocsp_endpoint_check + ) + conn = http.client.HTTPSConnection("127.0.0.1:9003", context=ctx) + try: + if data is not None: + headers = {"Content-type": "application/json"} + body = json.dumps(data) + else: + headers = {} + body = None + conn.request("POST", path, body, headers) + res = conn.getresponse() + res.read() + finally: + conn.close() + + async def _test(self, provider, master_key): + await self.http_post("/reset") + # Case 1: createDataKey and encrypt with TCP retry + await self.http_post("/set_failpoint/network", {"count": 1}) + key_id = await self.client_encryption.create_data_key(provider, master_key=master_key) + await self.http_post("/set_failpoint/network", {"count": 1}) + await self.client_encryption.encrypt( + 123, Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_id + ) + + # Case 2: createDataKey and encrypt with HTTP retry + await self.http_post("/set_failpoint/http", {"count": 1}) + key_id = await self.client_encryption.create_data_key(provider, master_key=master_key) + await self.http_post("/set_failpoint/http", {"count": 1}) + await self.client_encryption.encrypt( + 123, Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_id + ) + + # Case 3: createDataKey fails after too many retries + await self.http_post("/set_failpoint/network", {"count": 4}) + with self.assertRaisesRegex(EncryptionError, "KMS request failed after"): + await self.client_encryption.create_data_key(provider, master_key=master_key) + + async def test_kms_retry(self): + await self._test("aws", {"region": "foo", "key": "bar", "endpoint": "127.0.0.1:9003"}) + await self._test("azure", {"keyVaultEndpoint": "127.0.0.1:9003", "keyName": "foo"}) + await self._test( + "gcp", + { + "projectId": "foo", + "location": "bar", + "keyRing": "baz", + "keyName": "qux", + "endpoint": "127.0.0.1:9003", + }, + ) + + # https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#automatic-data-encryption-keys class TestAutomaticDecryptionKeys(AsyncEncryptionIntegrationTest): @async_client_context.require_no_standalone diff --git a/test/test_encryption.py b/test/test_encryption.py index 18e21fe6a7..7a9929b7fd 100644 --- a/test/test_encryption.py +++ b/test/test_encryption.py @@ -17,6 +17,8 @@ import base64 import copy +import http.client +import json import os import pathlib import re @@ -88,6 +90,7 @@ WriteError, ) from pymongo.operations import InsertOne, ReplaceOne, UpdateOne +from pymongo.ssl_support import get_ssl_context from pymongo.synchronous import encryption from pymongo.synchronous.encryption import Algorithm, ClientEncryption, QueryType from pymongo.synchronous.mongo_client import MongoClient @@ -1360,9 +1363,8 @@ def test_04_aws_endpoint_invalid_port(self): "key": ("arn:aws:kms:us-east-1:579766882180:key/89fcc2c4-08b0-4bd9-9f25-e30687b580d0"), "endpoint": "kms.us-east-1.amazonaws.com:12345", } - with self.assertRaisesRegex(EncryptionError, "kms.us-east-1.amazonaws.com:12345") as ctx: + with self.assertRaisesRegex(EncryptionError, "kms.us-east-1.amazonaws.com:12345"): self.client_encryption.create_data_key("aws", master_key=master_key) - self.assertIsInstance(ctx.exception.cause, AutoReconnect) @unittest.skipUnless(any(AWS_CREDS.values()), "AWS environment credentials are not set") def test_05_aws_endpoint_wrong_region(self): @@ -2835,6 +2837,86 @@ def test_accepts_trim_factor_0(self): assert len(payload) > len(self.payload_defaults) +# https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#24-kms-retry-tests +class TestKmsRetryProse(EncryptionIntegrationTest): + @unittest.skipUnless(any(AWS_CREDS.values()), "AWS environment credentials are not set") + def setUp(self): + super().setUp() + # 1, create client with only tlsCAFile. + providers: dict = copy.deepcopy(ALL_KMS_PROVIDERS) + providers["azure"]["identityPlatformEndpoint"] = "127.0.0.1:9003" + providers["gcp"]["endpoint"] = "127.0.0.1:9003" + kms_tls_opts = { + p: {"tlsCAFile": CA_PEM, "tlsCertificateKeyFile": CLIENT_PEM} for p in providers + } + self.client_encryption = self.create_client_encryption( + providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=kms_tls_opts + ) + + def http_post(self, path, data=None): + # Note, the connection to the mock server needs to be closed after + # each request because the server is single threaded. + ctx: ssl.SSLContext = get_ssl_context( + CLIENT_PEM, # certfile + None, # passphrase + CA_PEM, # ca_certs + None, # crlfile + False, # allow_invalid_certificates + False, # allow_invalid_hostnames + False, # disable_ocsp_endpoint_check + ) + conn = http.client.HTTPSConnection("127.0.0.1:9003", context=ctx) + try: + if data is not None: + headers = {"Content-type": "application/json"} + body = json.dumps(data) + else: + headers = {} + body = None + conn.request("POST", path, body, headers) + res = conn.getresponse() + res.read() + finally: + conn.close() + + def _test(self, provider, master_key): + self.http_post("/reset") + # Case 1: createDataKey and encrypt with TCP retry + self.http_post("/set_failpoint/network", {"count": 1}) + key_id = self.client_encryption.create_data_key(provider, master_key=master_key) + self.http_post("/set_failpoint/network", {"count": 1}) + self.client_encryption.encrypt( + 123, Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_id + ) + + # Case 2: createDataKey and encrypt with HTTP retry + self.http_post("/set_failpoint/http", {"count": 1}) + key_id = self.client_encryption.create_data_key(provider, master_key=master_key) + self.http_post("/set_failpoint/http", {"count": 1}) + self.client_encryption.encrypt( + 123, Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_id + ) + + # Case 3: createDataKey fails after too many retries + self.http_post("/set_failpoint/network", {"count": 4}) + with self.assertRaisesRegex(EncryptionError, "KMS request failed after"): + self.client_encryption.create_data_key(provider, master_key=master_key) + + def test_kms_retry(self): + self._test("aws", {"region": "foo", "key": "bar", "endpoint": "127.0.0.1:9003"}) + self._test("azure", {"keyVaultEndpoint": "127.0.0.1:9003", "keyName": "foo"}) + self._test( + "gcp", + { + "projectId": "foo", + "location": "bar", + "keyRing": "baz", + "keyName": "qux", + "endpoint": "127.0.0.1:9003", + }, + ) + + # https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#automatic-data-encryption-keys class TestAutomaticDecryptionKeys(EncryptionIntegrationTest): @client_context.require_no_standalone