diff --git a/test/asynchronous/test_async_cancellation.py b/test/asynchronous/test_async_cancellation.py index b7fde834a2..f450ea23cc 100644 --- a/test/asynchronous/test_async_cancellation.py +++ b/test/asynchronous/test_async_cancellation.py @@ -17,7 +17,8 @@ import asyncio import sys -from test.utils import async_get_pool, delay, one +from test.asynchronous.utils import async_get_pool +from test.utils_shared import delay, one sys.path[0:0] = [""] diff --git a/test/asynchronous/test_auth.py b/test/asynchronous/test_auth.py index 7172152d69..904674db16 100644 --- a/test/asynchronous/test_auth.py +++ b/test/asynchronous/test_auth.py @@ -30,7 +30,7 @@ async_client_context, unittest, ) -from test.utils import AllowListEventListener, delay, ignore_deprecations +from test.utils_shared import AllowListEventListener, delay, ignore_deprecations import pytest diff --git a/test/asynchronous/test_bulk.py b/test/asynchronous/test_bulk.py index 86568b666b..5573c3987f 100644 --- a/test/asynchronous/test_bulk.py +++ b/test/asynchronous/test_bulk.py @@ -24,7 +24,7 @@ sys.path[0:0] = [""] from test.asynchronous import AsyncIntegrationTest, async_client_context, remove_all_users, unittest -from test.utils import async_wait_until +from test.utils_shared import async_wait_until from bson.binary import Binary, UuidRepresentation from bson.codec_options import CodecOptions diff --git a/test/asynchronous/test_change_stream.py b/test/asynchronous/test_change_stream.py index 08da00cc1e..4025c13730 100644 --- a/test/asynchronous/test_change_stream.py +++ b/test/asynchronous/test_change_stream.py @@ -36,7 +36,7 @@ unittest, ) from test.asynchronous.unified_format import generate_test_classes -from test.utils import ( +from test.utils_shared import ( AllowListEventListener, EventListener, OvertCommandListener, diff --git a/test/asynchronous/test_client.py b/test/asynchronous/test_client.py index acc815c8a4..f9678b11e2 100644 --- a/test/asynchronous/test_client.py +++ b/test/asynchronous/test_client.py @@ -60,14 +60,16 @@ unittest, ) from test.asynchronous.pymongo_mocks import AsyncMockClient +from test.asynchronous.utils import ( + async_get_pool, + async_wait_until, + asyncAssertRaisesExactly, +) from test.test_binary import BinaryData -from test.utils import ( +from test.utils_shared import ( NTHREADS, CMAPListener, FunctionCallRecorder, - async_get_pool, - async_wait_until, - asyncAssertRaisesExactly, delay, gevent_monkey_patched, is_greenthread_patched, diff --git a/test/asynchronous/test_client_bulk_write.py b/test/asynchronous/test_client_bulk_write.py index 282009f554..f8b9465b09 100644 --- a/test/asynchronous/test_client_bulk_write.py +++ b/test/asynchronous/test_client_bulk_write.py @@ -25,7 +25,7 @@ async_client_context, unittest, ) -from test.utils import ( +from test.utils_shared import ( OvertCommandListener, ) from unittest.mock import patch diff --git a/test/asynchronous/test_collation.py b/test/asynchronous/test_collation.py index d7fd85b168..05e548c79e 100644 --- a/test/asynchronous/test_collation.py +++ b/test/asynchronous/test_collation.py @@ -18,7 +18,7 @@ import functools import warnings from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest -from test.utils import EventListener, OvertCommandListener +from test.utils_shared import EventListener, OvertCommandListener from typing import Any from pymongo.asynchronous.helpers import anext diff --git a/test/asynchronous/test_collection.py b/test/asynchronous/test_collection.py index beb58012a8..00ed020d88 100644 --- a/test/asynchronous/test_collection.py +++ b/test/asynchronous/test_collection.py @@ -21,6 +21,7 @@ import sys from codecs import utf_8_decode from collections import defaultdict +from test.asynchronous.utils import async_get_pool, async_is_mongos from typing import Any, Iterable, no_type_check from pymongo.asynchronous.database import AsyncDatabase @@ -33,12 +34,10 @@ AsyncUnitTest, async_client_context, ) -from test.utils import ( +from test.utils_shared import ( IMPOSSIBLE_WRITE_CONCERN, EventListener, OvertCommandListener, - async_get_pool, - async_is_mongos, async_wait_until, ) diff --git a/test/asynchronous/test_comment.py b/test/asynchronous/test_comment.py index be3626a8b8..d3ddaf2b65 100644 --- a/test/asynchronous/test_comment.py +++ b/test/asynchronous/test_comment.py @@ -22,7 +22,7 @@ sys.path[0:0] = [""] from asyncio import iscoroutinefunction from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest -from test.utils import OvertCommandListener +from test.utils_shared import OvertCommandListener from bson.dbref import DBRef from pymongo.asynchronous.command_cursor import AsyncCommandCursor diff --git a/test/asynchronous/test_concurrency.py b/test/asynchronous/test_concurrency.py index 1683b8413b..193ecf05c8 100644 --- a/test/asynchronous/test_concurrency.py +++ b/test/asynchronous/test_concurrency.py @@ -18,7 +18,7 @@ import asyncio import time from test.asynchronous import AsyncIntegrationTest, async_client_context -from test.utils import delay +from test.utils_shared import delay _IS_SYNC = False diff --git a/test/asynchronous/test_connection_monitoring.py b/test/asynchronous/test_connection_monitoring.py index cdf4887ba3..359346d984 100644 --- a/test/asynchronous/test_connection_monitoring.py +++ b/test/asynchronous/test_connection_monitoring.py @@ -20,17 +20,15 @@ import sys import time from pathlib import Path +from test.asynchronous.utils import async_get_pool, async_get_pools sys.path[0:0] = [""] -from test.asynchronous import AsyncIntegrationTest, client_knobs, unittest +from test.asynchronous import AsyncIntegrationTest, async_client_context, client_knobs, unittest from test.asynchronous.pymongo_mocks import DummyMonitor from test.asynchronous.utils_spec_runner import AsyncSpecTestCreator, SpecRunnerTask -from test.utils import ( +from test.utils_shared import ( CMAPListener, - async_client_context, - async_get_pool, - async_get_pools, async_wait_until, camel_to_snake, ) diff --git a/test/asynchronous/test_connections_survive_primary_stepdown_spec.py b/test/asynchronous/test_connections_survive_primary_stepdown_spec.py index 7c11742a90..92c750c4fe 100644 --- a/test/asynchronous/test_connections_survive_primary_stepdown_spec.py +++ b/test/asynchronous/test_connections_survive_primary_stepdown_spec.py @@ -16,6 +16,7 @@ from __future__ import annotations import sys +from test.asynchronous.utils import async_ensure_all_connected sys.path[0:0] = [""] @@ -25,9 +26,8 @@ unittest, ) from test.asynchronous.helpers import async_repl_set_step_down -from test.utils import ( +from test.utils_shared import ( CMAPListener, - async_ensure_all_connected, ) from bson import SON diff --git a/test/asynchronous/test_cursor.py b/test/asynchronous/test_cursor.py index d843ffb4aa..90d5e7801e 100644 --- a/test/asynchronous/test_cursor.py +++ b/test/asynchronous/test_cursor.py @@ -31,7 +31,7 @@ sys.path[0:0] = [""] from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest -from test.utils import ( +from test.utils_shared import ( AllowListEventListener, EventListener, OvertCommandListener, diff --git a/test/asynchronous/test_data_lake.py b/test/asynchronous/test_data_lake.py index e67782ad3f..689bf38534 100644 --- a/test/asynchronous/test_data_lake.py +++ b/test/asynchronous/test_data_lake.py @@ -25,7 +25,7 @@ from test.asynchronous import AsyncIntegrationTest, AsyncUnitTest, async_client_context, unittest from test.asynchronous.unified_format import generate_test_classes -from test.utils import ( +from test.utils_shared import ( OvertCommandListener, ) diff --git a/test/asynchronous/test_database.py b/test/asynchronous/test_database.py index 2bbf763ab3..b2ddd4122d 100644 --- a/test/asynchronous/test_database.py +++ b/test/asynchronous/test_database.py @@ -26,7 +26,7 @@ from test import unittest from test.asynchronous import AsyncIntegrationTest, async_client_context from test.test_custom_types import DECIMAL_CODECOPTS -from test.utils import ( +from test.utils_shared import ( IMPOSSIBLE_WRITE_CONCERN, OvertCommandListener, async_wait_until, diff --git a/test/asynchronous/test_discovery_and_monitoring.py b/test/asynchronous/test_discovery_and_monitoring.py index c3c2bb1a6c..b3de2c5a4d 100644 --- a/test/asynchronous/test_discovery_and_monitoring.py +++ b/test/asynchronous/test_discovery_and_monitoring.py @@ -26,25 +26,32 @@ sys.path[0:0] = [""] -from test.asynchronous import AsyncIntegrationTest, AsyncPyMongoTestCase, AsyncUnitTest, unittest +from test.asynchronous import ( + AsyncIntegrationTest, + AsyncPyMongoTestCase, + AsyncUnitTest, + async_client_context, + unittest, +) from test.asynchronous.pymongo_mocks import DummyMonitor from test.asynchronous.unified_format import generate_test_classes -from test.utils import ( +from test.asynchronous.utils import ( + async_get_pool, +) +from test.utils_shared import ( CMAPListener, HeartbeatEventListener, HeartbeatEventsListListener, assertion_context, async_barrier_wait, - async_client_context, async_create_barrier, - async_get_pool, async_wait_until, server_name_to_type, ) from unittest.mock import patch from bson import Timestamp, json_util -from pymongo import AsyncMongoClient, common, monitoring +from pymongo import common, monitoring from pymongo.asynchronous.settings import TopologySettings from pymongo.asynchronous.topology import Topology, _ErrorContext from pymongo.errors import ( @@ -291,7 +298,7 @@ async def test_ignore_stale_connection_errors(self): if not _IS_SYNC and sys.version_info < (3, 11): self.skipTest("Test requires asyncio.Barrier (added in Python 3.11)") N_TASKS = 5 - barrier = async_create_barrier(N_TASKS, timeout=30) + barrier = async_create_barrier(N_TASKS) client = await self.async_rs_or_single_client(minPoolSize=N_TASKS) # Wait for initial discovery. diff --git a/test/asynchronous/test_dns.py b/test/asynchronous/test_dns.py index e24e0fb5ce..a622062fec 100644 --- a/test/asynchronous/test_dns.py +++ b/test/asynchronous/test_dns.py @@ -29,7 +29,7 @@ async_client_context, unittest, ) -from test.utils import async_wait_until +from test.utils_shared import async_wait_until from pymongo.common import validate_read_preference_tags from pymongo.errors import ConfigurationError diff --git a/test/asynchronous/test_encryption.py b/test/asynchronous/test_encryption.py index 335aa9d81c..000d98a111 100644 --- a/test/asynchronous/test_encryption.py +++ b/test/asynchronous/test_encryption.py @@ -64,7 +64,7 @@ KMIP_CREDS, LOCAL_MASTER_KEY, ) -from test.utils import ( +from test.utils_shared import ( AllowListEventListener, OvertCommandListener, TopologyEventListener, diff --git a/test/asynchronous/test_examples.py b/test/asynchronous/test_examples.py index 7fea9d41af..1312f1e215 100644 --- a/test/asynchronous/test_examples.py +++ b/test/asynchronous/test_examples.py @@ -26,7 +26,7 @@ sys.path[0:0] = [""] from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest -from test.utils import async_wait_until +from test.utils_shared import async_wait_until import pymongo from pymongo.asynchronous.helpers import anext diff --git a/test/asynchronous/test_grid_file.py b/test/asynchronous/test_grid_file.py index affdacde91..3f864367de 100644 --- a/test/asynchronous/test_grid_file.py +++ b/test/asynchronous/test_grid_file.py @@ -33,7 +33,7 @@ sys.path[0:0] = [""] -from test.utils import OvertCommandListener +from test.utils_shared import OvertCommandListener from bson.objectid import ObjectId from gridfs.asynchronous.grid_file import ( diff --git a/test/asynchronous/test_gridfs.py b/test/asynchronous/test_gridfs.py index b1c1e754ff..f886601f36 100644 --- a/test/asynchronous/test_gridfs.py +++ b/test/asynchronous/test_gridfs.py @@ -28,7 +28,8 @@ sys.path[0:0] = [""] from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest -from test.utils import async_joinall, one +from test.asynchronous.utils import async_joinall +from test.utils_shared import one import gridfs from bson.binary import Binary diff --git a/test/asynchronous/test_gridfs_bucket.py b/test/asynchronous/test_gridfs_bucket.py index 5d1cf5beff..29877ee9c4 100644 --- a/test/asynchronous/test_gridfs_bucket.py +++ b/test/asynchronous/test_gridfs_bucket.py @@ -29,7 +29,8 @@ sys.path[0:0] = [""] from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest -from test.utils import async_joinall, joinall, one +from test.asynchronous.utils import async_joinall +from test.utils_shared import one import gridfs from bson.binary import Binary diff --git a/test/asynchronous/test_heartbeat_monitoring.py b/test/asynchronous/test_heartbeat_monitoring.py index ff595a8144..aa8a205021 100644 --- a/test/asynchronous/test_heartbeat_monitoring.py +++ b/test/asynchronous/test_heartbeat_monitoring.py @@ -16,11 +16,12 @@ from __future__ import annotations import sys +from test.asynchronous.utils import AsyncMockPool sys.path[0:0] = [""] from test.asynchronous import AsyncIntegrationTest, client_knobs, unittest -from test.utils import AsyncMockPool, HeartbeatEventListener, async_wait_until +from test.utils_shared import HeartbeatEventListener, async_wait_until from pymongo.asynchronous.monitor import Monitor from pymongo.errors import ConnectionFailure diff --git a/test/asynchronous/test_index_management.py b/test/asynchronous/test_index_management.py index c155047089..4b218de130 100644 --- a/test/asynchronous/test_index_management.py +++ b/test/asynchronous/test_index_management.py @@ -29,7 +29,7 @@ from test.asynchronous import AsyncIntegrationTest, AsyncPyMongoTestCase, unittest from test.asynchronous.unified_format import generate_test_classes -from test.utils import AllowListEventListener, OvertCommandListener +from test.utils_shared import AllowListEventListener, OvertCommandListener from pymongo.errors import OperationFailure from pymongo.operations import SearchIndexModel diff --git a/test/asynchronous/test_load_balancer.py b/test/asynchronous/test_load_balancer.py index fd50841c87..127fdfd24d 100644 --- a/test/asynchronous/test_load_balancer.py +++ b/test/asynchronous/test_load_balancer.py @@ -23,6 +23,7 @@ import threading from asyncio import Event from test.asynchronous.helpers import ConcurrentRunner, ExceptionCatchingTask +from test.asynchronous.utils import async_get_pool import pytest @@ -30,8 +31,7 @@ from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest from test.asynchronous.unified_format import generate_test_classes -from test.utils import ( - async_get_pool, +from test.utils_shared import ( async_wait_until, create_async_event, ) diff --git a/test/asynchronous/test_max_staleness.py b/test/asynchronous/test_max_staleness.py index 7dbf17021f..b6e15f9158 100644 --- a/test/asynchronous/test_max_staleness.py +++ b/test/asynchronous/test_max_staleness.py @@ -28,7 +28,7 @@ sys.path[0:0] = [""] from test.asynchronous import AsyncPyMongoTestCase, async_client_context, unittest -from test.utils_selection_tests import create_selection_tests +from test.asynchronous.utils_selection_tests import create_selection_tests from pymongo.errors import ConfigurationError from pymongo.server_selectors import writable_server_selector diff --git a/test/asynchronous/test_mongos_load_balancing.py b/test/asynchronous/test_mongos_load_balancing.py index 0bc6a405f4..97170aa9e0 100644 --- a/test/asynchronous/test_mongos_load_balancing.py +++ b/test/asynchronous/test_mongos_load_balancing.py @@ -26,7 +26,7 @@ from test.asynchronous import AsyncMockClientTest, async_client_context, connected, unittest from test.asynchronous.pymongo_mocks import AsyncMockClient -from test.utils import async_wait_until +from test.utils_shared import async_wait_until from pymongo.errors import AutoReconnect, InvalidOperation from pymongo.server_selectors import writable_server_selector diff --git a/test/asynchronous/test_monitor.py b/test/asynchronous/test_monitor.py index 2705fbda3b..195f6f9fac 100644 --- a/test/asynchronous/test_monitor.py +++ b/test/asynchronous/test_monitor.py @@ -25,10 +25,10 @@ sys.path[0:0] = [""] from test.asynchronous import AsyncIntegrationTest, async_client_context, connected, unittest -from test.utils import ( - ServerAndTopologyEventListener, +from test.asynchronous.utils import ( async_wait_until, ) +from test.utils_shared import ServerAndTopologyEventListener from pymongo.periodic_executor import _EXECUTORS diff --git a/test/asynchronous/test_monitoring.py b/test/asynchronous/test_monitoring.py index eaad60beac..a7d56a8cf7 100644 --- a/test/asynchronous/test_monitoring.py +++ b/test/asynchronous/test_monitoring.py @@ -29,7 +29,7 @@ sanitize_cmd, unittest, ) -from test.utils import ( +from test.utils_shared import ( EventListener, OvertCommandListener, async_wait_until, diff --git a/test/asynchronous/test_pooling.py b/test/asynchronous/test_pooling.py index 812b5a48e0..8213c794fe 100644 --- a/test/asynchronous/test_pooling.py +++ b/test/asynchronous/test_pooling.py @@ -21,6 +21,7 @@ import socket import sys import time +from test.asynchronous.utils import async_get_pool, async_joinall from bson.codec_options import DEFAULT_CODEC_OPTIONS from bson.son import SON @@ -33,7 +34,7 @@ from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest from test.asynchronous.helpers import ConcurrentRunner -from test.utils import async_get_pool, async_joinall, delay +from test.utils_shared import delay from pymongo.asynchronous.pool import Pool, PoolOptions from pymongo.socket_checker import SocketChecker diff --git a/test/asynchronous/test_read_concern.py b/test/asynchronous/test_read_concern.py index fbc07a5c36..8659bf80b2 100644 --- a/test/asynchronous/test_read_concern.py +++ b/test/asynchronous/test_read_concern.py @@ -21,7 +21,7 @@ sys.path[0:0] = [""] from test.asynchronous import AsyncIntegrationTest, async_client_context -from test.utils import OvertCommandListener +from test.utils_shared import OvertCommandListener from bson.son import SON from pymongo.errors import OperationFailure diff --git a/test/asynchronous/test_read_preferences.py b/test/asynchronous/test_read_preferences.py index 077bc21eaf..5bea174058 100644 --- a/test/asynchronous/test_read_preferences.py +++ b/test/asynchronous/test_read_preferences.py @@ -33,7 +33,7 @@ connected, unittest, ) -from test.utils import ( +from test.utils_shared import ( OvertCommandListener, async_wait_until, one, diff --git a/test/asynchronous/test_read_write_concern_spec.py b/test/asynchronous/test_read_write_concern_spec.py index 3fb13ba194..86f79fd28d 100644 --- a/test/asynchronous/test_read_write_concern_spec.py +++ b/test/asynchronous/test_read_write_concern_spec.py @@ -25,7 +25,7 @@ from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest from test.asynchronous.unified_format import generate_test_classes -from test.utils import OvertCommandListener +from test.utils_shared import OvertCommandListener from pymongo import DESCENDING from pymongo.asynchronous.mongo_client import AsyncMongoClient diff --git a/test/asynchronous/test_retryable_reads.py b/test/asynchronous/test_retryable_reads.py index bde7a9f2ee..10d9e738b4 100644 --- a/test/asynchronous/test_retryable_reads.py +++ b/test/asynchronous/test_retryable_reads.py @@ -19,6 +19,7 @@ import pprint import sys import threading +from test.asynchronous.utils import async_set_fail_point from pymongo.errors import AutoReconnect @@ -31,10 +32,9 @@ client_knobs, unittest, ) -from test.utils import ( +from test.utils_shared import ( CMAPListener, OvertCommandListener, - async_set_fail_point, ) from pymongo.monitoring import ( diff --git a/test/asynchronous/test_retryable_writes.py b/test/asynchronous/test_retryable_writes.py index 738ce04192..2f6cb2b575 100644 --- a/test/asynchronous/test_retryable_writes.py +++ b/test/asynchronous/test_retryable_writes.py @@ -20,6 +20,7 @@ import pprint import sys import threading +from test.asynchronous.utils import async_set_fail_point sys.path[0:0] = [""] @@ -30,12 +31,11 @@ unittest, ) from test.asynchronous.helpers import client_knobs -from test.utils import ( +from test.utils_shared import ( CMAPListener, DeprecationFilter, EventListener, OvertCommandListener, - async_set_fail_point, ) from test.version import Version diff --git a/test/asynchronous/test_sdam_monitoring_spec.py b/test/asynchronous/test_sdam_monitoring_spec.py index 8b0ec63cfe..71ec6c6b46 100644 --- a/test/asynchronous/test_sdam_monitoring_spec.py +++ b/test/asynchronous/test_sdam_monitoring_spec.py @@ -25,7 +25,7 @@ sys.path[0:0] = [""] from test.asynchronous import AsyncIntegrationTest, async_client_context, client_knobs, unittest -from test.utils import ( +from test.utils_shared import ( ServerAndTopologyEventListener, async_wait_until, server_name_to_type, diff --git a/test/asynchronous/test_server_selection.py b/test/asynchronous/test_server_selection.py index f0451841cd..f98a05ee91 100644 --- a/test/asynchronous/test_server_selection.py +++ b/test/asynchronous/test_server_selection.py @@ -31,17 +31,18 @@ sys.path[0:0] = [""] from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest +from test.asynchronous.utils import async_wait_until from test.asynchronous.utils_selection_tests import ( create_selection_tests, - get_addresses, get_topology_settings_dict, +) +from test.utils_selection_tests_shared import ( + get_addresses, make_server_description, ) -from test.utils import ( - EventListener, +from test.utils_shared import ( FunctionCallRecorder, OvertCommandListener, - async_wait_until, ) _IS_SYNC = False diff --git a/test/asynchronous/test_server_selection_in_window.py b/test/asynchronous/test_server_selection_in_window.py index e2ae92a27c..3fe448d4dd 100644 --- a/test/asynchronous/test_server_selection_in_window.py +++ b/test/asynchronous/test_server_selection_in_window.py @@ -23,10 +23,9 @@ from test.asynchronous.helpers import ConcurrentRunner from test.asynchronous.utils_selection_tests import create_topology from test.asynchronous.utils_spec_runner import AsyncSpecTestCreator -from test.utils import ( +from test.utils_shared import ( CMAPListener, OvertCommandListener, - async_get_pool, async_wait_until, ) diff --git a/test/asynchronous/test_session.py b/test/asynchronous/test_session.py index 568d392cd5..4431cbcb16 100644 --- a/test/asynchronous/test_session.py +++ b/test/asynchronous/test_session.py @@ -30,14 +30,13 @@ from test.asynchronous import ( AsyncIntegrationTest, - AsyncPyMongoTestCase, AsyncUnitTest, SkipTest, async_client_context, unittest, ) from test.asynchronous.helpers import client_knobs -from test.utils import ( +from test.utils_shared import ( EventListener, HeartbeatEventListener, OvertCommandListener, diff --git a/test/asynchronous/test_srv_polling.py b/test/asynchronous/test_srv_polling.py index 763c80e665..bf7807eb97 100644 --- a/test/asynchronous/test_srv_polling.py +++ b/test/asynchronous/test_srv_polling.py @@ -18,12 +18,13 @@ import asyncio import sys import time +from test.utils_shared import FunctionCallRecorder from typing import Any sys.path[0:0] = [""] from test.asynchronous import AsyncPyMongoTestCase, client_knobs, unittest -from test.utils import FunctionCallRecorder, async_wait_until +from test.asynchronous.utils import async_wait_until import pymongo from pymongo import common diff --git a/test/asynchronous/test_ssl.py b/test/asynchronous/test_ssl.py index d50bb220b1..d920b77ac2 100644 --- a/test/asynchronous/test_ssl.py +++ b/test/asynchronous/test_ssl.py @@ -32,7 +32,7 @@ remove_all_users, unittest, ) -from test.utils import ( +from test.utils_shared import ( EventListener, OvertCommandListener, cat_files, diff --git a/test/asynchronous/test_streaming_protocol.py b/test/asynchronous/test_streaming_protocol.py index fd890d29fb..1206e7b2fa 100644 --- a/test/asynchronous/test_streaming_protocol.py +++ b/test/asynchronous/test_streaming_protocol.py @@ -21,7 +21,7 @@ sys.path[0:0] = [""] from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest -from test.utils import ( +from test.utils_shared import ( HeartbeatEventListener, ServerEventListener, async_wait_until, diff --git a/test/asynchronous/test_transactions.py b/test/asynchronous/test_transactions.py index 5f75746a4d..884110cd45 100644 --- a/test/asynchronous/test_transactions.py +++ b/test/asynchronous/test_transactions.py @@ -24,7 +24,7 @@ sys.path[0:0] = [""] from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest -from test.utils import ( +from test.utils_shared import ( OvertCommandListener, async_wait_until, ) diff --git a/test/asynchronous/test_versioned_api_integration.py b/test/asynchronous/test_versioned_api_integration.py index 7e9a79da90..46e62d5c14 100644 --- a/test/asynchronous/test_versioned_api_integration.py +++ b/test/asynchronous/test_versioned_api_integration.py @@ -21,7 +21,7 @@ sys.path[0:0] = [""] from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest -from test.utils import OvertCommandListener +from test.utils_shared import OvertCommandListener from pymongo.server_api import ServerApi diff --git a/test/asynchronous/unified_format.py b/test/asynchronous/unified_format.py index ce0b9979e2..886b31e4a6 100644 --- a/test/asynchronous/unified_format.py +++ b/test/asynchronous/unified_format.py @@ -35,6 +35,7 @@ client_knobs, unittest, ) +from test.asynchronous.utils import async_get_pool from test.asynchronous.utils_spec_runner import SpecRunnerTask from test.unified_format_shared import ( KMS_TLS_OPTS, @@ -49,8 +50,7 @@ parse_collection_or_database_options, with_metaclass, ) -from test.utils import ( - async_get_pool, +from test.utils_shared import ( async_wait_until, camel_to_snake, camel_to_snake_args, diff --git a/test/asynchronous/utils.py b/test/asynchronous/utils.py new file mode 100644 index 0000000000..4b68595397 --- /dev/null +++ b/test/asynchronous/utils.py @@ -0,0 +1,211 @@ +# Copyright 2012-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for testing pymongo that require synchronization.""" +from __future__ import annotations + +import asyncio +import contextlib +import random +import threading # Used in the synchronized version of this file +import time +from asyncio import iscoroutinefunction + +from bson.son import SON +from pymongo import AsyncMongoClient +from pymongo.errors import ConfigurationError +from pymongo.hello import HelloCompat +from pymongo.lock import _async_create_lock +from pymongo.operations import _Op +from pymongo.read_preferences import ReadPreference +from pymongo.server_selectors import any_server_selector, writable_server_selector +from pymongo.synchronous.pool import _CancellationContext, _PoolGeneration + +_IS_SYNC = False + + +async def async_get_pool(client): + """Get the standalone, primary, or mongos pool.""" + topology = await client._get_topology() + server = await topology._select_server(writable_server_selector, _Op.TEST) + return server.pool + + +async def async_get_pools(client): + """Get all pools.""" + return [ + server.pool + for server in await (await client._get_topology()).select_servers( + any_server_selector, _Op.TEST + ) + ] + + +async def async_wait_until(predicate, success_description, timeout=10): + """Wait up to 10 seconds (by default) for predicate to be true. + + E.g.: + + wait_until(lambda: client.primary == ('a', 1), + 'connect to the primary') + + If the lambda-expression isn't true after 10 seconds, we raise + AssertionError("Didn't ever connect to the primary"). + + Returns the predicate's first true value. + """ + start = time.time() + interval = min(float(timeout) / 100, 0.1) + while True: + if iscoroutinefunction(predicate): + retval = await predicate() + else: + retval = predicate() + if retval: + return retval + + if time.time() - start > timeout: + raise AssertionError("Didn't ever %s" % success_description) + + await asyncio.sleep(interval) + + +async def async_is_mongos(client): + res = await client.admin.command(HelloCompat.LEGACY_CMD) + return res.get("msg", "") == "isdbgrid" + + +async def async_ensure_all_connected(client: AsyncMongoClient) -> None: + """Ensure that the client's connection pool has socket connections to all + members of a replica set. Raises ConfigurationError when called with a + non-replica set client. + + Depending on the use-case, the caller may need to clear any event listeners + that are configured on the client. + """ + hello: dict = await client.admin.command(HelloCompat.LEGACY_CMD) + if "setName" not in hello: + raise ConfigurationError("cluster is not a replica set") + + target_host_list = set(hello["hosts"] + hello.get("passives", [])) + connected_host_list = {hello["me"]} + + # Run hello until we have connected to each host at least once. + async def discover(): + i = 0 + while i < 100 and connected_host_list != target_host_list: + hello: dict = await client.admin.command( + HelloCompat.LEGACY_CMD, read_preference=ReadPreference.SECONDARY + ) + connected_host_list.update([hello["me"]]) + i += 1 + return connected_host_list + + try: + + async def predicate(): + return target_host_list == await discover() + + await async_wait_until(predicate, "connected to all hosts") + except AssertionError as exc: + raise AssertionError( + f"{exc}, {connected_host_list} != {target_host_list}, {client.topology_description}" + ) + + +async def asyncAssertRaisesExactly(cls, fn, *args, **kwargs): + """ + Unlike the standard assertRaises, this checks that a function raises a + specific class of exception, and not a subclass. E.g., check that + MongoClient() raises ConnectionFailure but not its subclass, AutoReconnect. + """ + try: + await fn(*args, **kwargs) + except Exception as e: + assert e.__class__ == cls, f"got {e.__class__.__name__}, expected {cls.__name__}" + else: + raise AssertionError("%s not raised" % cls) + + +async def async_set_fail_point(client, command_args): + cmd = SON([("configureFailPoint", "failCommand")]) + cmd.update(command_args) + await client.admin.command(cmd) + + +async def async_joinall(tasks): + """Join threads with a 5-minute timeout, assert joins succeeded""" + if _IS_SYNC: + for t in tasks: + t.join(300) + assert not t.is_alive(), "Thread %s hung" % t + else: + await asyncio.wait([t.task for t in tasks if t is not None], timeout=300) + + +class AsyncMockConnection: + def __init__(self): + self.cancel_context = _CancellationContext() + self.more_to_come = False + self.id = random.randint(0, 100) + + def close_conn(self, reason): + pass + + def __aenter__(self): + return self + + def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + +class AsyncMockPool: + def __init__(self, address, options, handshake=True, client_id=None): + self.gen = _PoolGeneration() + self._lock = _async_create_lock() + self.opts = options + self.operation_count = 0 + self.conns = [] + + def stale_generation(self, gen, service_id): + return self.gen.stale(gen, service_id) + + @contextlib.asynccontextmanager + async def checkout(self, handler=None): + yield AsyncMockConnection() + + async def checkin(self, *args, **kwargs): + pass + + async def _reset(self, service_id=None): + async with self._lock: + self.gen.inc(service_id) + + async def ready(self): + pass + + async def reset(self, service_id=None, interrupt_connections=False): + await self._reset() + + async def reset_without_pause(self): + await self._reset() + + async def close(self): + await self._reset() + + async def update_is_writable(self, is_writable): + pass + + async def remove_stale_sockets(self, *args, **kwargs): + pass diff --git a/test/asynchronous/utils_selection_tests.py b/test/asynchronous/utils_selection_tests.py index 71e287569a..d6b92fadb4 100644 --- a/test/asynchronous/utils_selection_tests.py +++ b/test/asynchronous/utils_selection_tests.py @@ -19,17 +19,18 @@ import os import sys from test.asynchronous import AsyncPyMongoTestCase +from test.asynchronous.utils import AsyncMockPool sys.path[0:0] = [""] from test import unittest from test.pymongo_mocks import DummyMonitor -from test.utils import AsyncMockPool, parse_read_preference from test.utils_selection_tests_shared import ( get_addresses, get_topology_type_name, make_server_description, ) +from test.utils_shared import parse_read_preference from bson import json_util from pymongo.asynchronous.settings import TopologySettings diff --git a/test/asynchronous/utils_spec_runner.py b/test/asynchronous/utils_spec_runner.py index f1c6deb690..c83636a734 100644 --- a/test/asynchronous/utils_spec_runner.py +++ b/test/asynchronous/utils_spec_runner.py @@ -24,7 +24,7 @@ from collections import abc from test.asynchronous import AsyncIntegrationTest, async_client_context, client_knobs from test.asynchronous.helpers import ConcurrentRunner -from test.utils import ( +from test.utils_shared import ( CMAPListener, CompareType, EventListener, diff --git a/test/auth_oidc/test_auth_oidc.py b/test/auth_oidc/test_auth_oidc.py index 7a78f3d2f6..a5334d79bd 100644 --- a/test/auth_oidc/test_auth_oidc.py +++ b/test/auth_oidc/test_auth_oidc.py @@ -31,7 +31,7 @@ sys.path[0:0] = [""] from test.unified_format import generate_test_classes -from test.utils import EventListener, OvertCommandListener +from test.utils_shared import EventListener, OvertCommandListener from bson import SON from pymongo import MongoClient diff --git a/test/test_auth.py b/test/test_auth.py index 345d16121b..27f6743fae 100644 --- a/test/test_auth.py +++ b/test/test_auth.py @@ -30,7 +30,7 @@ client_context, unittest, ) -from test.utils import AllowListEventListener, delay, ignore_deprecations +from test.utils_shared import AllowListEventListener, delay, ignore_deprecations import pytest diff --git a/test/test_bulk.py b/test/test_bulk.py index 6a72bddfc0..8a863cc49b 100644 --- a/test/test_bulk.py +++ b/test/test_bulk.py @@ -24,7 +24,7 @@ sys.path[0:0] = [""] from test import IntegrationTest, client_context, remove_all_users, unittest -from test.utils import wait_until +from test.utils_shared import wait_until from bson.binary import Binary, UuidRepresentation from bson.codec_options import CodecOptions diff --git a/test/test_change_stream.py b/test/test_change_stream.py index 4ed21f55cf..e50f4667f6 100644 --- a/test/test_change_stream.py +++ b/test/test_change_stream.py @@ -36,7 +36,7 @@ unittest, ) from test.unified_format import generate_test_classes -from test.utils import ( +from test.utils_shared import ( AllowListEventListener, EventListener, OvertCommandListener, diff --git a/test/test_client.py b/test/test_client.py index 8e99866cc8..a340263937 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -61,17 +61,19 @@ from test.pymongo_mocks import MockClient from test.test_binary import BinaryData from test.utils import ( + assertRaisesExactly, + get_pool, + wait_until, +) +from test.utils_shared import ( NTHREADS, CMAPListener, FunctionCallRecorder, - assertRaisesExactly, delay, - get_pool, gevent_monkey_patched, is_greenthread_patched, lazy_client_trial, one, - wait_until, ) import bson diff --git a/test/test_client_bulk_write.py b/test/test_client_bulk_write.py index f8d92668ea..b00b2c1b03 100644 --- a/test/test_client_bulk_write.py +++ b/test/test_client_bulk_write.py @@ -25,7 +25,7 @@ client_context, unittest, ) -from test.utils import ( +from test.utils_shared import ( OvertCommandListener, ) from unittest.mock import patch diff --git a/test/test_collation.py b/test/test_collation.py index 06436f0638..5425551dc6 100644 --- a/test/test_collation.py +++ b/test/test_collation.py @@ -18,7 +18,7 @@ import functools import warnings from test import IntegrationTest, client_context, unittest -from test.utils import EventListener, OvertCommandListener +from test.utils_shared import EventListener, OvertCommandListener from typing import Any from pymongo.collation import ( diff --git a/test/test_collection.py b/test/test_collection.py index 8a862646eb..75c11383d0 100644 --- a/test/test_collection.py +++ b/test/test_collection.py @@ -21,6 +21,7 @@ import sys from codecs import utf_8_decode from collections import defaultdict +from test.utils import get_pool, is_mongos from typing import Any, Iterable, no_type_check from pymongo.synchronous.database import Database @@ -33,12 +34,10 @@ client_context, unittest, ) -from test.utils import ( +from test.utils_shared import ( IMPOSSIBLE_WRITE_CONCERN, EventListener, OvertCommandListener, - get_pool, - is_mongos, wait_until, ) diff --git a/test/test_comment.py b/test/test_comment.py index 9f9bf98640..b6c17c14fe 100644 --- a/test/test_comment.py +++ b/test/test_comment.py @@ -22,7 +22,7 @@ sys.path[0:0] = [""] from asyncio import iscoroutinefunction from test import IntegrationTest, client_context, unittest -from test.utils import OvertCommandListener +from test.utils_shared import OvertCommandListener from bson.dbref import DBRef from pymongo.operations import IndexModel diff --git a/test/test_connection_monitoring.py b/test/test_connection_monitoring.py index 3987f2b68b..1405824453 100644 --- a/test/test_connection_monitoring.py +++ b/test/test_connection_monitoring.py @@ -20,17 +20,15 @@ import sys import time from pathlib import Path +from test.utils import get_pool, get_pools sys.path[0:0] = [""] -from test import IntegrationTest, client_knobs, unittest +from test import IntegrationTest, client_context, client_knobs, unittest from test.pymongo_mocks import DummyMonitor -from test.utils import ( +from test.utils_shared import ( CMAPListener, camel_to_snake, - client_context, - get_pool, - get_pools, wait_until, ) from test.utils_spec_runner import SpecRunnerThread, SpecTestCreator diff --git a/test/test_connections_survive_primary_stepdown_spec.py b/test/test_connections_survive_primary_stepdown_spec.py index 9cac633301..d923a477b5 100644 --- a/test/test_connections_survive_primary_stepdown_spec.py +++ b/test/test_connections_survive_primary_stepdown_spec.py @@ -16,6 +16,7 @@ from __future__ import annotations import sys +from test.utils import ensure_all_connected sys.path[0:0] = [""] @@ -25,9 +26,8 @@ unittest, ) from test.helpers import repl_set_step_down -from test.utils import ( +from test.utils_shared import ( CMAPListener, - ensure_all_connected, ) from bson import SON diff --git a/test/test_cursor.py b/test/test_cursor.py index 84e431f8cb..a9cbe99942 100644 --- a/test/test_cursor.py +++ b/test/test_cursor.py @@ -31,7 +31,7 @@ sys.path[0:0] = [""] from test import IntegrationTest, client_context, unittest -from test.utils import ( +from test.utils_shared import ( AllowListEventListener, EventListener, OvertCommandListener, diff --git a/test/test_data_lake.py b/test/test_data_lake.py index c8b76eb1ca..d6d2007007 100644 --- a/test/test_data_lake.py +++ b/test/test_data_lake.py @@ -25,7 +25,7 @@ from test import IntegrationTest, UnitTest, client_context, unittest from test.unified_format import generate_test_classes -from test.utils import ( +from test.utils_shared import ( OvertCommandListener, ) diff --git a/test/test_database.py b/test/test_database.py index 48cca921b1..4c09b421cf 100644 --- a/test/test_database.py +++ b/test/test_database.py @@ -25,7 +25,7 @@ from test import IntegrationTest, client_context, unittest from test.test_custom_types import DECIMAL_CODECOPTS -from test.utils import ( +from test.utils_shared import ( IMPOSSIBLE_WRITE_CONCERN, OvertCommandListener, wait_until, diff --git a/test/test_discovery_and_monitoring.py b/test/test_discovery_and_monitoring.py index bfe0b24387..00021310c9 100644 --- a/test/test_discovery_and_monitoring.py +++ b/test/test_discovery_and_monitoring.py @@ -26,25 +26,32 @@ sys.path[0:0] = [""] -from test import IntegrationTest, PyMongoTestCase, UnitTest, unittest +from test import ( + IntegrationTest, + PyMongoTestCase, + UnitTest, + client_context, + unittest, +) from test.pymongo_mocks import DummyMonitor from test.unified_format import generate_test_classes from test.utils import ( + get_pool, +) +from test.utils_shared import ( CMAPListener, HeartbeatEventListener, HeartbeatEventsListListener, assertion_context, barrier_wait, - client_context, create_barrier, - get_pool, server_name_to_type, wait_until, ) from unittest.mock import patch from bson import Timestamp, json_util -from pymongo import MongoClient, common, monitoring +from pymongo import common, monitoring from pymongo.errors import ( AutoReconnect, ConfigurationError, @@ -291,7 +298,7 @@ def test_ignore_stale_connection_errors(self): if not _IS_SYNC and sys.version_info < (3, 11): self.skipTest("Test requires asyncio.Barrier (added in Python 3.11)") N_TASKS = 5 - barrier = create_barrier(N_TASKS, timeout=30) + barrier = create_barrier(N_TASKS) client = self.rs_or_single_client(minPoolSize=N_TASKS) # Wait for initial discovery. diff --git a/test/test_dns.py b/test/test_dns.py index 6f4736fd5e..71326ae49e 100644 --- a/test/test_dns.py +++ b/test/test_dns.py @@ -29,7 +29,7 @@ client_context, unittest, ) -from test.utils import wait_until +from test.utils_shared import wait_until from pymongo.common import validate_read_preference_tags from pymongo.errors import ConfigurationError diff --git a/test/test_encryption.py b/test/test_encryption.py index 9224310144..6efb167442 100644 --- a/test/test_encryption.py +++ b/test/test_encryption.py @@ -63,7 +63,7 @@ ) from test.test_bulk import BulkTestBase from test.unified_format import generate_test_classes -from test.utils import ( +from test.utils_shared import ( AllowListEventListener, OvertCommandListener, TopologyEventListener, diff --git a/test/test_examples.py b/test/test_examples.py index 9bcc276248..ef06a77b9a 100644 --- a/test/test_examples.py +++ b/test/test_examples.py @@ -26,7 +26,7 @@ sys.path[0:0] = [""] from test import IntegrationTest, client_context, unittest -from test.utils import wait_until +from test.utils_shared import wait_until import pymongo from pymongo.errors import ConnectionFailure, OperationFailure diff --git a/test/test_fork.py b/test/test_fork.py index 1a89159435..fe88d778d2 100644 --- a/test/test_fork.py +++ b/test/test_fork.py @@ -24,7 +24,7 @@ sys.path[0:0] = [""] from test import IntegrationTest -from test.utils import is_greenthread_patched +from test.utils_shared import is_greenthread_patched from bson.objectid import ObjectId diff --git a/test/test_grid_file.py b/test/test_grid_file.py index 6534bc11bf..0baeb5ae19 100644 --- a/test/test_grid_file.py +++ b/test/test_grid_file.py @@ -33,7 +33,7 @@ sys.path[0:0] = [""] -from test.utils import OvertCommandListener +from test.utils_shared import OvertCommandListener from bson.objectid import ObjectId from gridfs.errors import NoFile diff --git a/test/test_gridfs.py b/test/test_gridfs.py index 47e38141b2..75342ee437 100644 --- a/test/test_gridfs.py +++ b/test/test_gridfs.py @@ -28,7 +28,8 @@ sys.path[0:0] = [""] from test import IntegrationTest, client_context, unittest -from test.utils import joinall, one +from test.utils import joinall +from test.utils_shared import one import gridfs from bson.binary import Binary diff --git a/test/test_gridfs_bucket.py b/test/test_gridfs_bucket.py index e7486cb237..d68c9f6ba2 100644 --- a/test/test_gridfs_bucket.py +++ b/test/test_gridfs_bucket.py @@ -29,7 +29,8 @@ sys.path[0:0] = [""] from test import IntegrationTest, client_context, unittest -from test.utils import joinall, one +from test.utils import joinall +from test.utils_shared import one import gridfs from bson.binary import Binary diff --git a/test/test_heartbeat_monitoring.py b/test/test_heartbeat_monitoring.py index 0523d0ba4d..7864caf6e1 100644 --- a/test/test_heartbeat_monitoring.py +++ b/test/test_heartbeat_monitoring.py @@ -16,11 +16,12 @@ from __future__ import annotations import sys +from test.utils import MockPool sys.path[0:0] = [""] from test import IntegrationTest, client_knobs, unittest -from test.utils import HeartbeatEventListener, MockPool, wait_until +from test.utils_shared import HeartbeatEventListener, wait_until from pymongo.errors import ConnectionFailure from pymongo.hello import Hello, HelloCompat diff --git a/test/test_index_management.py b/test/test_index_management.py index e4b931cf00..3a2b17cd3d 100644 --- a/test/test_index_management.py +++ b/test/test_index_management.py @@ -29,7 +29,7 @@ from test import IntegrationTest, PyMongoTestCase, unittest from test.unified_format import generate_test_classes -from test.utils import AllowListEventListener, OvertCommandListener +from test.utils_shared import AllowListEventListener, OvertCommandListener from pymongo.errors import OperationFailure from pymongo.operations import SearchIndexModel diff --git a/test/test_load_balancer.py b/test/test_load_balancer.py index 7db19b46b5..d7f1d596cc 100644 --- a/test/test_load_balancer.py +++ b/test/test_load_balancer.py @@ -23,6 +23,7 @@ import threading from asyncio import Event from test.helpers import ConcurrentRunner, ExceptionCatchingTask +from test.utils import get_pool import pytest @@ -30,9 +31,8 @@ from test import IntegrationTest, client_context, unittest from test.unified_format import generate_test_classes -from test.utils import ( +from test.utils_shared import ( create_event, - get_pool, wait_until, ) diff --git a/test/test_mongos_load_balancing.py b/test/test_mongos_load_balancing.py index ca2f3cfd1e..8c31854343 100644 --- a/test/test_mongos_load_balancing.py +++ b/test/test_mongos_load_balancing.py @@ -26,7 +26,7 @@ from test import MockClientTest, client_context, connected, unittest from test.pymongo_mocks import MockClient -from test.utils import wait_until +from test.utils_shared import wait_until from pymongo.errors import AutoReconnect, InvalidOperation from pymongo.server_selectors import writable_server_selector diff --git a/test/test_monitor.py b/test/test_monitor.py index 0fb7eb9cae..25620a99e8 100644 --- a/test/test_monitor.py +++ b/test/test_monitor.py @@ -26,9 +26,9 @@ from test import IntegrationTest, client_context, connected, unittest from test.utils import ( - ServerAndTopologyEventListener, wait_until, ) +from test.utils_shared import ServerAndTopologyEventListener from pymongo.periodic_executor import _EXECUTORS diff --git a/test/test_monitoring.py b/test/test_monitoring.py index 670558c0a0..ae3e50db77 100644 --- a/test/test_monitoring.py +++ b/test/test_monitoring.py @@ -29,7 +29,7 @@ sanitize_cmd, unittest, ) -from test.utils import ( +from test.utils_shared import ( EventListener, OvertCommandListener, wait_until, diff --git a/test/test_objectid.py b/test/test_objectid.py index 26670832f6..d7db7229ea 100644 --- a/test/test_objectid.py +++ b/test/test_objectid.py @@ -23,7 +23,7 @@ sys.path[0:0] = [""] from test import SkipTest, unittest -from test.utils import oid_generated_on_process +from test.utils_shared import oid_generated_on_process from bson.errors import InvalidId from bson.objectid import _MAX_COUNTER_VALUE, ObjectId diff --git a/test/test_pooling.py b/test/test_pooling.py index 1755365f80..44e8c4afe5 100644 --- a/test/test_pooling.py +++ b/test/test_pooling.py @@ -21,6 +21,7 @@ import socket import sys import time +from test.utils import get_pool, joinall from bson.codec_options import DEFAULT_CODEC_OPTIONS from bson.son import SON @@ -33,7 +34,7 @@ from test import IntegrationTest, client_context, unittest from test.helpers import ConcurrentRunner -from test.utils import delay, get_pool, joinall +from test.utils_shared import delay from pymongo.socket_checker import SocketChecker from pymongo.synchronous.pool import Pool, PoolOptions diff --git a/test/test_read_concern.py b/test/test_read_concern.py index 8ec9865eaa..62b2491475 100644 --- a/test/test_read_concern.py +++ b/test/test_read_concern.py @@ -21,7 +21,7 @@ sys.path[0:0] = [""] from test import IntegrationTest, client_context -from test.utils import OvertCommandListener +from test.utils_shared import OvertCommandListener from bson.son import SON from pymongo.errors import OperationFailure diff --git a/test/test_read_preferences.py b/test/test_read_preferences.py index 0d38f3f00d..e754c896ad 100644 --- a/test/test_read_preferences.py +++ b/test/test_read_preferences.py @@ -33,7 +33,7 @@ connected, unittest, ) -from test.utils import ( +from test.utils_shared import ( OvertCommandListener, one, wait_until, diff --git a/test/test_read_write_concern_spec.py b/test/test_read_write_concern_spec.py index 8543991f72..383dc70902 100644 --- a/test/test_read_write_concern_spec.py +++ b/test/test_read_write_concern_spec.py @@ -25,7 +25,7 @@ from test import IntegrationTest, client_context, unittest from test.unified_format import generate_test_classes -from test.utils import OvertCommandListener +from test.utils_shared import OvertCommandListener from pymongo import DESCENDING from pymongo.errors import ( diff --git a/test/test_replica_set_reconfig.py b/test/test_replica_set_reconfig.py index 4c23d71b69..3371543f27 100644 --- a/test/test_replica_set_reconfig.py +++ b/test/test_replica_set_reconfig.py @@ -21,7 +21,7 @@ from test import MockClientTest, client_context, client_knobs, unittest from test.pymongo_mocks import MockClient -from test.utils import wait_until +from test.utils_shared import wait_until from pymongo import ReadPreference from pymongo.errors import ConnectionFailure, ServerSelectionTimeoutError diff --git a/test/test_retryable_reads.py b/test/test_retryable_reads.py index 9c3f6b170f..7ae4c41e70 100644 --- a/test/test_retryable_reads.py +++ b/test/test_retryable_reads.py @@ -19,6 +19,7 @@ import pprint import sys import threading +from test.utils import set_fail_point from pymongo.errors import AutoReconnect @@ -31,10 +32,9 @@ client_knobs, unittest, ) -from test.utils import ( +from test.utils_shared import ( CMAPListener, OvertCommandListener, - set_fail_point, ) from pymongo.monitoring import ( diff --git a/test/test_retryable_writes.py b/test/test_retryable_writes.py index 07bd1db0ba..b099820a45 100644 --- a/test/test_retryable_writes.py +++ b/test/test_retryable_writes.py @@ -20,6 +20,7 @@ import pprint import sys import threading +from test.utils import set_fail_point sys.path[0:0] = [""] @@ -30,12 +31,11 @@ unittest, ) from test.helpers import client_knobs -from test.utils import ( +from test.utils_shared import ( CMAPListener, DeprecationFilter, EventListener, OvertCommandListener, - set_fail_point, ) from test.version import Version diff --git a/test/test_sdam_monitoring_spec.py b/test/test_sdam_monitoring_spec.py index 6a53c062cc..2167e561cf 100644 --- a/test/test_sdam_monitoring_spec.py +++ b/test/test_sdam_monitoring_spec.py @@ -25,7 +25,7 @@ sys.path[0:0] = [""] from test import IntegrationTest, client_context, client_knobs, unittest -from test.utils import ( +from test.utils_shared import ( ServerAndTopologyEventListener, server_name_to_type, wait_until, diff --git a/test/test_server_selection.py b/test/test_server_selection.py index 3e7f9a8671..aec8e2e47a 100644 --- a/test/test_server_selection.py +++ b/test/test_server_selection.py @@ -31,18 +31,19 @@ sys.path[0:0] = [""] from test import IntegrationTest, client_context, unittest -from test.utils import ( - EventListener, - FunctionCallRecorder, - OvertCommandListener, - wait_until, -) +from test.utils import wait_until from test.utils_selection_tests import ( create_selection_tests, - get_addresses, get_topology_settings_dict, +) +from test.utils_selection_tests_shared import ( + get_addresses, make_server_description, ) +from test.utils_shared import ( + FunctionCallRecorder, + OvertCommandListener, +) _IS_SYNC = True diff --git a/test/test_server_selection_in_window.py b/test/test_server_selection_in_window.py index 7ccd4b529e..4aad34050c 100644 --- a/test/test_server_selection_in_window.py +++ b/test/test_server_selection_in_window.py @@ -21,13 +21,12 @@ from pathlib import Path from test import IntegrationTest, client_context, unittest from test.helpers import ConcurrentRunner -from test.utils import ( +from test.utils_selection_tests import create_topology +from test.utils_shared import ( CMAPListener, OvertCommandListener, - get_pool, wait_until, ) -from test.utils_selection_tests import create_topology from test.utils_spec_runner import SpecTestCreator from pymongo.common import clean_node diff --git a/test/test_session.py b/test/test_session.py index e80ab41896..905539a1f8 100644 --- a/test/test_session.py +++ b/test/test_session.py @@ -30,14 +30,13 @@ from test import ( IntegrationTest, - PyMongoTestCase, SkipTest, UnitTest, client_context, unittest, ) from test.helpers import client_knobs -from test.utils import ( +from test.utils_shared import ( EventListener, HeartbeatEventListener, OvertCommandListener, diff --git a/test/test_srv_polling.py b/test/test_srv_polling.py index 86fad6d90e..6812465074 100644 --- a/test/test_srv_polling.py +++ b/test/test_srv_polling.py @@ -18,12 +18,13 @@ import asyncio import sys import time +from test.utils_shared import FunctionCallRecorder from typing import Any sys.path[0:0] = [""] from test import PyMongoTestCase, client_knobs, unittest -from test.utils import FunctionCallRecorder, wait_until +from test.utils import wait_until import pymongo from pymongo import common diff --git a/test/test_ssl.py b/test/test_ssl.py index 7d6c3f7cd1..a66fe21be5 100644 --- a/test/test_ssl.py +++ b/test/test_ssl.py @@ -32,7 +32,7 @@ remove_all_users, unittest, ) -from test.utils import ( +from test.utils_shared import ( EventListener, OvertCommandListener, cat_files, diff --git a/test/test_streaming_protocol.py b/test/test_streaming_protocol.py index 894e89e208..acf7610c94 100644 --- a/test/test_streaming_protocol.py +++ b/test/test_streaming_protocol.py @@ -21,7 +21,7 @@ sys.path[0:0] = [""] from test import IntegrationTest, client_context, unittest -from test.utils import ( +from test.utils_shared import ( HeartbeatEventListener, ServerEventListener, wait_until, diff --git a/test/test_topology.py b/test/test_topology.py index 86aa87c2cc..22e94739ee 100644 --- a/test/test_topology.py +++ b/test/test_topology.py @@ -23,7 +23,8 @@ from test import client_knobs, unittest from test.pymongo_mocks import DummyMonitor -from test.utils import MockPool, wait_until +from test.utils import MockPool +from test.utils_shared import wait_until from bson.objectid import ObjectId from pymongo import common diff --git a/test/test_transactions.py b/test/test_transactions.py index 7a8dcd0f00..80b3e3765e 100644 --- a/test/test_transactions.py +++ b/test/test_transactions.py @@ -24,7 +24,7 @@ sys.path[0:0] = [""] from test import IntegrationTest, client_context, unittest -from test.utils import ( +from test.utils_shared import ( OvertCommandListener, wait_until, ) diff --git a/test/test_versioned_api_integration.py b/test/test_versioned_api_integration.py index 502198576a..0066ecd977 100644 --- a/test/test_versioned_api_integration.py +++ b/test/test_versioned_api_integration.py @@ -21,7 +21,7 @@ sys.path[0:0] = [""] from test import IntegrationTest, client_context, unittest -from test.utils import OvertCommandListener +from test.utils_shared import OvertCommandListener from pymongo.server_api import ServerApi diff --git a/test/unified_format.py b/test/unified_format.py index 682a6105f3..471a067bee 100644 --- a/test/unified_format.py +++ b/test/unified_format.py @@ -48,10 +48,10 @@ parse_collection_or_database_options, with_metaclass, ) -from test.utils import ( +from test.utils import get_pool +from test.utils_shared import ( camel_to_snake, camel_to_snake_args, - get_pool, parse_spec_options, prepare_spec_arguments, snake_to_camel, diff --git a/test/unified_format_shared.py b/test/unified_format_shared.py index 0c685366f4..009c5c7e28 100644 --- a/test/unified_format_shared.py +++ b/test/unified_format_shared.py @@ -35,7 +35,7 @@ KMIP_CREDS, LOCAL_MASTER_KEY, ) -from test.utils import CMAPListener, camel_to_snake, parse_collection_options +from test.utils_shared import CMAPListener, camel_to_snake, parse_collection_options from typing import Any, Union from bson import ( diff --git a/test/utils.py b/test/utils.py index ae316d0387..1459a8fba7 100644 --- a/test/utils.py +++ b/test/utils.py @@ -12,476 +12,76 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Utilities for testing pymongo""" +"""Utilities for testing pymongo that require synchronization.""" from __future__ import annotations import asyncio import contextlib -import copy -import functools -import os import random -import re -import shutil -import sys -import threading +import threading # Used in the synchronized version of this file import time -import unittest -import warnings from asyncio import iscoroutinefunction -from collections import abc, defaultdict -from functools import partial -from test import client_context, db_pwd, db_user -from test.asynchronous import async_client_context -from typing import Any, List -from bson import json_util -from bson.objectid import ObjectId from bson.son import SON -from pymongo import AsyncMongoClient, monitoring, operations, read_preferences -from pymongo._asyncio_task import create_task -from pymongo.cursor_shared import CursorType -from pymongo.errors import ConfigurationError, OperationFailure +from pymongo import MongoClient +from pymongo.errors import ConfigurationError from pymongo.hello import HelloCompat -from pymongo.helpers_shared import _SENSITIVE_COMMANDS -from pymongo.lock import _async_create_lock, _create_lock -from pymongo.monitoring import ( - ConnectionCheckedInEvent, - ConnectionCheckedOutEvent, - ConnectionCheckOutFailedEvent, - ConnectionCheckOutStartedEvent, - ConnectionClosedEvent, - ConnectionCreatedEvent, - ConnectionReadyEvent, - PoolClearedEvent, - PoolClosedEvent, - PoolCreatedEvent, - PoolReadyEvent, -) +from pymongo.lock import _create_lock from pymongo.operations import _Op -from pymongo.read_concern import ReadConcern from pymongo.read_preferences import ReadPreference from pymongo.server_selectors import any_server_selector, writable_server_selector -from pymongo.server_type import SERVER_TYPE -from pymongo.synchronous.collection import ReturnDocument -from pymongo.synchronous.mongo_client import MongoClient from pymongo.synchronous.pool import _CancellationContext, _PoolGeneration -from pymongo.uri_parser import parse_uri -from pymongo.write_concern import WriteConcern -IMPOSSIBLE_WRITE_CONCERN = WriteConcern(w=50) +_IS_SYNC = True -class BaseListener: - def __init__(self): - self.events = [] - - def reset(self): - self.events = [] - - def add_event(self, event): - self.events.append(event) - - def event_count(self, event_type): - return len(self.events_by_type(event_type)) - - def events_by_type(self, event_type): - """Return the matching events by event class. - - event_type can be a single class or a tuple of classes. - """ - return self.matching(lambda e: isinstance(e, event_type)) - - def matching(self, matcher): - """Return the matching events.""" - return [event for event in self.events[:] if matcher(event)] - - def wait_for_event(self, event, count): - """Wait for a number of events to be published, or fail.""" - wait_until(lambda: self.event_count(event) >= count, f"find {count} {event} event(s)") - - async def async_wait_for_event(self, event, count): - """Wait for a number of events to be published, or fail.""" - await async_wait_until( - lambda: self.event_count(event) >= count, f"find {count} {event} event(s)" - ) - - -class CMAPListener(BaseListener, monitoring.ConnectionPoolListener): - def connection_created(self, event): - assert isinstance(event, ConnectionCreatedEvent) - self.add_event(event) - - def connection_ready(self, event): - assert isinstance(event, ConnectionReadyEvent) - self.add_event(event) - - def connection_closed(self, event): - assert isinstance(event, ConnectionClosedEvent) - self.add_event(event) - - def connection_check_out_started(self, event): - assert isinstance(event, ConnectionCheckOutStartedEvent) - self.add_event(event) - - def connection_check_out_failed(self, event): - assert isinstance(event, ConnectionCheckOutFailedEvent) - self.add_event(event) - - def connection_checked_out(self, event): - assert isinstance(event, ConnectionCheckedOutEvent) - self.add_event(event) - - def connection_checked_in(self, event): - assert isinstance(event, ConnectionCheckedInEvent) - self.add_event(event) - - def pool_created(self, event): - assert isinstance(event, PoolCreatedEvent) - self.add_event(event) - - def pool_ready(self, event): - assert isinstance(event, PoolReadyEvent) - self.add_event(event) - - def pool_cleared(self, event): - assert isinstance(event, PoolClearedEvent) - self.add_event(event) - - def pool_closed(self, event): - assert isinstance(event, PoolClosedEvent) - self.add_event(event) - - -class EventListener(BaseListener, monitoring.CommandListener): - def __init__(self): - super().__init__() - self.results = defaultdict(list) - - @property - def started_events(self) -> List[monitoring.CommandStartedEvent]: - return self.results["started"] - - @property - def succeeded_events(self) -> List[monitoring.CommandSucceededEvent]: - return self.results["succeeded"] - - @property - def failed_events(self) -> List[monitoring.CommandFailedEvent]: - return self.results["failed"] - - def started(self, event: monitoring.CommandStartedEvent) -> None: - self.started_events.append(event) - self.add_event(event) - - def succeeded(self, event: monitoring.CommandSucceededEvent) -> None: - self.succeeded_events.append(event) - self.add_event(event) - - def failed(self, event: monitoring.CommandFailedEvent) -> None: - self.failed_events.append(event) - self.add_event(event) - - def started_command_names(self) -> List[str]: - """Return list of command names started.""" - return [event.command_name for event in self.started_events] - - def reset(self) -> None: - """Reset the state of this listener.""" - self.results.clear() - super().reset() - - -class TopologyEventListener(monitoring.TopologyListener): - def __init__(self): - self.results = defaultdict(list) - - def closed(self, event): - self.results["closed"].append(event) - - def description_changed(self, event): - self.results["description_changed"].append(event) - - def opened(self, event): - self.results["opened"].append(event) - - def reset(self): - """Reset the state of this listener.""" - self.results.clear() - - -class AllowListEventListener(EventListener): - def __init__(self, *commands): - self.commands = set(commands) - super().__init__() - - def started(self, event): - if event.command_name in self.commands: - super().started(event) - - def succeeded(self, event): - if event.command_name in self.commands: - super().succeeded(event) - - def failed(self, event): - if event.command_name in self.commands: - super().failed(event) - - -class OvertCommandListener(EventListener): - """A CommandListener that ignores sensitive commands.""" - - ignore_list_collections = False - - def started(self, event): - if event.command_name.lower() not in _SENSITIVE_COMMANDS: - super().started(event) - - def succeeded(self, event): - if event.command_name.lower() not in _SENSITIVE_COMMANDS: - super().succeeded(event) - - def failed(self, event): - if event.command_name.lower() not in _SENSITIVE_COMMANDS: - super().failed(event) - - -class _ServerEventListener: - """Listens to all events.""" - - def __init__(self): - self.results = [] - - def opened(self, event): - self.results.append(event) - - def description_changed(self, event): - self.results.append(event) - - def closed(self, event): - self.results.append(event) - - def matching(self, matcher): - """Return the matching events.""" - results = self.results[:] - return [event for event in results if matcher(event)] - - def reset(self): - self.results = [] - - -class ServerEventListener(_ServerEventListener, monitoring.ServerListener): - """Listens to Server events.""" - - -class ServerAndTopologyEventListener( # type: ignore[misc] - ServerEventListener, monitoring.TopologyListener -): - """Listens to Server and Topology events.""" - - -class HeartbeatEventListener(BaseListener, monitoring.ServerHeartbeatListener): - """Listens to only server heartbeat events.""" - - def started(self, event): - self.add_event(event) - - def succeeded(self, event): - self.add_event(event) - - def failed(self, event): - self.add_event(event) - - -class HeartbeatEventsListListener(HeartbeatEventListener): - """Listens to only server heartbeat events and publishes them to a provided list.""" - - def __init__(self, events): - super().__init__() - self.event_list = events - - def started(self, event): - self.add_event(event) - self.event_list.append("serverHeartbeatStartedEvent") - - def succeeded(self, event): - self.add_event(event) - self.event_list.append("serverHeartbeatSucceededEvent") - - def failed(self, event): - self.add_event(event) - self.event_list.append("serverHeartbeatFailedEvent") - - -class AsyncMockConnection: - def __init__(self): - self.cancel_context = _CancellationContext() - self.more_to_come = False - self.id = random.randint(0, 100) - - def close_conn(self, reason): - pass - - def __aenter__(self): - return self - - def __aexit__(self, exc_type, exc_val, exc_tb): - pass - - -class MockConnection: - def __init__(self): - self.cancel_context = _CancellationContext() - self.more_to_come = False - self.id = random.randint(0, 100) - - def close_conn(self, reason): - pass - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - pass - - -class AsyncMockPool: - def __init__(self, address, options, handshake=True, client_id=None): - self.gen = _PoolGeneration() - self._lock = _async_create_lock() - self.opts = options - self.operation_count = 0 - self.conns = [] - - def stale_generation(self, gen, service_id): - return self.gen.stale(gen, service_id) - - @contextlib.asynccontextmanager - async def checkout(self, handler=None): - yield AsyncMockConnection() - - async def checkin(self, *args, **kwargs): - pass - - async def _reset(self, service_id=None): - async with self._lock: - self.gen.inc(service_id) - - async def ready(self): - pass - - async def reset(self, service_id=None, interrupt_connections=False): - await self._reset() - - async def reset_without_pause(self): - await self._reset() - - async def close(self): - await self._reset() - - async def update_is_writable(self, is_writable): - pass - - async def remove_stale_sockets(self, *args, **kwargs): - pass - - -class MockPool: - def __init__(self, address, options, handshake=True, client_id=None): - self.gen = _PoolGeneration() - self._lock = _create_lock() - self.opts = options - self.operation_count = 0 - self.conns = [] - - def stale_generation(self, gen, service_id): - return self.gen.stale(gen, service_id) - - def checkout(self, handler=None): - return MockConnection() - - def checkin(self, *args, **kwargs): - pass - - def _reset(self, service_id=None): - with self._lock: - self.gen.inc(service_id) - - def ready(self): - pass - - def reset(self, service_id=None, interrupt_connections=False): - self._reset() - - def reset_without_pause(self): - self._reset() - - def close(self): - self._reset() - - def update_is_writable(self, is_writable): - pass - - def remove_stale_sockets(self, *args, **kwargs): - pass - - -class ScenarioDict(dict): - """Dict that returns {} for any unknown key, recursively.""" - - def __init__(self, data): - def convert(v): - if isinstance(v, abc.Mapping): - return ScenarioDict(v) - if isinstance(v, (str, bytes)): - return v - if isinstance(v, abc.Sequence): - return [convert(item) for item in v] - return v - - dict.__init__(self, [(k, convert(v)) for k, v in data.items()]) +def get_pool(client): + """Get the standalone, primary, or mongos pool.""" + topology = client._get_topology() + server = topology._select_server(writable_server_selector, _Op.TEST) + return server.pool - def __getitem__(self, item): - try: - return dict.__getitem__(self, item) - except KeyError: - # Unlike a defaultdict, don't set the key, just return a dict. - return ScenarioDict({}) +def get_pools(client): + """Get all pools.""" + return [ + server.pool + for server in (client._get_topology()).select_servers(any_server_selector, _Op.TEST) + ] -class CompareType: - """Class that compares equal to any object of the given type(s).""" - def __init__(self, types): - self.types = types +def wait_until(predicate, success_description, timeout=10): + """Wait up to 10 seconds (by default) for predicate to be true. - def __eq__(self, other): - return isinstance(other, self.types) + E.g.: + wait_until(lambda: client.primary == ('a', 1), + 'connect to the primary') -class FunctionCallRecorder: - """Utility class to wrap a callable and record its invocations.""" + If the lambda-expression isn't true after 10 seconds, we raise + AssertionError("Didn't ever connect to the primary"). - def __init__(self, function): - self._function = function - self._call_list = [] + Returns the predicate's first true value. + """ + start = time.time() + interval = min(float(timeout) / 100, 0.1) + while True: + if iscoroutinefunction(predicate): + retval = predicate() + else: + retval = predicate() + if retval: + return retval - def __call__(self, *args, **kwargs): - self._call_list.append((args, kwargs)) - return self._function(*args, **kwargs) + if time.time() - start > timeout: + raise AssertionError("Didn't ever %s" % success_description) - def reset(self): - """Wipes the call list.""" - self._call_list = [] + time.sleep(interval) - def call_list(self): - """Returns a copy of the call list.""" - return self._call_list[:] - @property - def call_count(self): - """Returns the number of times the function has been called.""" - return len(self._call_list) +def is_mongos(client): + res = client.admin.command(HelloCompat.LEGACY_CMD) + return res.get("msg", "") == "isdbgrid" def ensure_all_connected(client: MongoClient) -> None: @@ -511,231 +111,17 @@ def discover(): return connected_host_list try: - wait_until(lambda: target_host_list == discover(), "connected to all hosts") - except AssertionError as exc: - raise AssertionError( - f"{exc}, {connected_host_list} != {target_host_list}, {client.topology_description}" - ) - - -async def async_ensure_all_connected(client: AsyncMongoClient) -> None: - """Ensure that the client's connection pool has socket connections to all - members of a replica set. Raises ConfigurationError when called with a - non-replica set client. - Depending on the use-case, the caller may need to clear any event listeners - that are configured on the client. - """ - hello: dict = await client.admin.command(HelloCompat.LEGACY_CMD) - if "setName" not in hello: - raise ConfigurationError("cluster is not a replica set") - - target_host_list = set(hello["hosts"] + hello.get("passives", [])) - connected_host_list = {hello["me"]} + def predicate(): + return target_host_list == discover() - # Run hello until we have connected to each host at least once. - async def discover(): - i = 0 - while i < 100 and connected_host_list != target_host_list: - hello: dict = await client.admin.command( - HelloCompat.LEGACY_CMD, read_preference=ReadPreference.SECONDARY - ) - connected_host_list.update([hello["me"]]) - i += 1 - return connected_host_list - - try: - - async def predicate(): - return target_host_list == await discover() - - await async_wait_until(predicate, "connected to all hosts") + wait_until(predicate, "connected to all hosts") except AssertionError as exc: raise AssertionError( f"{exc}, {connected_host_list} != {target_host_list}, {client.topology_description}" ) -def one(s): - """Get one element of a set""" - return next(iter(s)) - - -def oid_generated_on_process(oid): - """Makes a determination as to whether the given ObjectId was generated - by the current process, based on the 5-byte random number in the ObjectId. - """ - return ObjectId._random() == oid.binary[4:9] - - -def delay(sec): - return """function() { sleep(%f * 1000); return true; }""" % sec - - -def get_command_line(client): - command_line = client.admin.command("getCmdLineOpts") - assert command_line["ok"] == 1, "getCmdLineOpts() failed" - return command_line - - -def camel_to_snake(camel): - # Regex to convert CamelCase to snake_case. - snake = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", camel) - return re.sub("([a-z0-9])([A-Z])", r"\1_\2", snake).lower() - - -def camel_to_upper_camel(camel): - return camel[0].upper() + camel[1:] - - -def camel_to_snake_args(arguments): - for arg_name in list(arguments): - c2s = camel_to_snake(arg_name) - arguments[c2s] = arguments.pop(arg_name) - return arguments - - -def snake_to_camel(snake): - # Regex to convert snake_case to lowerCamelCase. - return re.sub(r"_([a-z])", lambda m: m.group(1).upper(), snake) - - -def parse_collection_options(opts): - if "readPreference" in opts: - opts["read_preference"] = parse_read_preference(opts.pop("readPreference")) - - if "writeConcern" in opts: - opts["write_concern"] = WriteConcern(**dict(opts.pop("writeConcern"))) - - if "readConcern" in opts: - opts["read_concern"] = ReadConcern(**dict(opts.pop("readConcern"))) - - if "timeoutMS" in opts: - opts["timeout"] = int(opts.pop("timeoutMS")) / 1000.0 - return opts - - -def server_started_with_option(client, cmdline_opt, config_opt): - """Check if the server was started with a particular option. - - :Parameters: - - `cmdline_opt`: The command line option (i.e. --nojournal) - - `config_opt`: The config file option (i.e. nojournal) - """ - command_line = get_command_line(client) - if "parsed" in command_line: - parsed = command_line["parsed"] - if config_opt in parsed: - return parsed[config_opt] - argv = command_line["argv"] - return cmdline_opt in argv - - -def server_started_with_auth(client): - try: - command_line = get_command_line(client) - except OperationFailure as e: - assert e.details is not None - msg = e.details.get("errmsg", "") - if e.code == 13 or "unauthorized" in msg or "login" in msg: - # Unauthorized. - return True - raise - - # MongoDB >= 2.0 - if "parsed" in command_line: - parsed = command_line["parsed"] - # MongoDB >= 2.6 - if "security" in parsed: - security = parsed["security"] - # >= rc3 - if "authorization" in security: - return security["authorization"] == "enabled" - # < rc3 - return security.get("auth", False) or bool(security.get("keyFile")) - return parsed.get("auth", False) or bool(parsed.get("keyFile")) - # Legacy - argv = command_line["argv"] - return "--auth" in argv or "--keyFile" in argv - - -def joinall(threads): - """Join threads with a 5-minute timeout, assert joins succeeded""" - for t in threads: - t.join(300) - assert not t.is_alive(), "Thread %s hung" % t - - -async def async_joinall(tasks): - """Join threads with a 5-minute timeout, assert joins succeeded""" - await asyncio.wait([t.task for t in tasks if t is not None], timeout=300) - - -def wait_until(predicate, success_description, timeout=10): - """Wait up to 10 seconds (by default) for predicate to be true. - - E.g.: - - wait_until(lambda: client.primary == ('a', 1), - 'connect to the primary') - - If the lambda-expression isn't true after 10 seconds, we raise - AssertionError("Didn't ever connect to the primary"). - - Returns the predicate's first true value. - """ - start = time.time() - interval = min(float(timeout) / 100, 0.1) - while True: - retval = predicate() - if retval: - return retval - - if time.time() - start > timeout: - raise AssertionError("Didn't ever %s" % success_description) - - time.sleep(interval) - - -async def async_wait_until(predicate, success_description, timeout=10): - """Wait up to 10 seconds (by default) for predicate to be true. - - E.g.: - - wait_until(lambda: client.primary == ('a', 1), - 'connect to the primary') - - If the lambda-expression isn't true after 10 seconds, we raise - AssertionError("Didn't ever connect to the primary"). - - Returns the predicate's first true value. - """ - start = time.time() - interval = min(float(timeout) / 100, 0.1) - while True: - if iscoroutinefunction(predicate): - retval = await predicate() - else: - retval = predicate() - if retval: - return retval - - if time.time() - start > timeout: - raise AssertionError("Didn't ever %s" % success_description) - - await asyncio.sleep(interval) - - -def is_mongos(client): - res = client.admin.command(HelloCompat.LEGACY_CMD) - return res.get("msg", "") == "isdbgrid" - - -async def async_is_mongos(client): - res = await client.admin.command(HelloCompat.LEGACY_CMD) - return res.get("msg", "") == "isdbgrid" - - def assertRaisesExactly(cls, fn, *args, **kwargs): """ Unlike the standard assertRaises, this checks that a function raises a @@ -750,347 +136,74 @@ def assertRaisesExactly(cls, fn, *args, **kwargs): raise AssertionError("%s not raised" % cls) -async def asyncAssertRaisesExactly(cls, fn, *args, **kwargs): - """ - Unlike the standard assertRaises, this checks that a function raises a - specific class of exception, and not a subclass. E.g., check that - MongoClient() raises ConnectionFailure but not its subclass, AutoReconnect. - """ - try: - await fn(*args, **kwargs) - except Exception as e: - assert e.__class__ == cls, f"got {e.__class__.__name__}, expected {cls.__name__}" - else: - raise AssertionError("%s not raised" % cls) - - -@contextlib.contextmanager -def _ignore_deprecations(): - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - yield - - -def ignore_deprecations(wrapped=None): - """A context manager or a decorator.""" - if wrapped: - if iscoroutinefunction(wrapped): - - @functools.wraps(wrapped) - async def wrapper(*args, **kwargs): - with _ignore_deprecations(): - return await wrapped(*args, **kwargs) - else: - - @functools.wraps(wrapped) - def wrapper(*args, **kwargs): - with _ignore_deprecations(): - return wrapped(*args, **kwargs) +def set_fail_point(client, command_args): + cmd = SON([("configureFailPoint", "failCommand")]) + cmd.update(command_args) + client.admin.command(cmd) - return wrapper +def joinall(tasks): + """Join threads with a 5-minute timeout, assert joins succeeded""" + if _IS_SYNC: + for t in tasks: + t.join(300) + assert not t.is_alive(), "Thread %s hung" % t else: - return _ignore_deprecations() - - -class DeprecationFilter: - def __init__(self, action="ignore"): - """Start filtering deprecations.""" - self.warn_context = warnings.catch_warnings() - self.warn_context.__enter__() - warnings.simplefilter(action, DeprecationWarning) - - def stop(self): - """Stop filtering deprecations.""" - self.warn_context.__exit__() # type: ignore - self.warn_context = None # type: ignore + asyncio.wait([t.task for t in tasks if t is not None], timeout=300) -def get_pool(client): - """Get the standalone, primary, or mongos pool.""" - topology = client._get_topology() - server = topology._select_server(writable_server_selector, _Op.TEST) - return server.pool - - -async def async_get_pool(client): - """Get the standalone, primary, or mongos pool.""" - topology = await client._get_topology() - server = await topology._select_server(writable_server_selector, _Op.TEST) - return server.pool - - -def get_pools(client): - """Get all pools.""" - return [ - server.pool - for server in client._get_topology().select_servers(any_server_selector, _Op.TEST) - ] - - -async def async_get_pools(client): - """Get all pools.""" - return [ - server.pool - for server in await (await client._get_topology()).select_servers( - any_server_selector, _Op.TEST - ) - ] - - -# Constants for run_threads and lazy_client_trial. -NTRIALS = 5 -NTHREADS = 10 - - -def run_threads(collection, target): - """Run a target function in many threads. - - target is a function taking a Collection and an integer. - """ - threads = [] - for i in range(NTHREADS): - bound_target = partial(target, collection, i) - threads.append(threading.Thread(target=bound_target)) - - for t in threads: - t.start() - - for t in threads: - t.join(60) - assert not t.is_alive() - - -@contextlib.contextmanager -def frequent_thread_switches(): - """Make concurrency bugs more likely to manifest.""" - interval = sys.getswitchinterval() - sys.setswitchinterval(1e-6) - - try: - yield - finally: - sys.setswitchinterval(interval) - - -def lazy_client_trial(reset, target, test, get_client): - """Test concurrent operations on a lazily-connecting client. - - `reset` takes a collection and resets it for the next trial. - - `target` takes a lazily-connecting collection and an index from - 0 to NTHREADS, and performs some operation, e.g. an insert. - - `test` takes the lazily-connecting collection and asserts a - post-condition to prove `target` succeeded. - """ - collection = client_context.client.pymongo_test.test - - with frequent_thread_switches(): - for _i in range(NTRIALS): - reset(collection) - lazy_client = get_client() - lazy_collection = lazy_client.pymongo_test.test - run_threads(lazy_collection, target) - test(lazy_collection) - - -def gevent_monkey_patched(): - """Check if gevent's monkey patching is active.""" - try: - import socket - - import gevent.socket # type:ignore[import] - - return socket.socket is gevent.socket.socket - except ImportError: - return False - - -def eventlet_monkey_patched(): - """Check if eventlet's monkey patching is active.""" - import threading - - return threading.current_thread.__module__ == "eventlet.green.threading" - - -def is_greenthread_patched(): - return gevent_monkey_patched() or eventlet_monkey_patched() - - -def parse_read_preference(pref): - # Make first letter lowercase to match read_pref's modes. - mode_string = pref.get("mode", "primary") - mode_string = mode_string[:1].lower() + mode_string[1:] - mode = read_preferences.read_pref_mode_from_name(mode_string) - max_staleness = pref.get("maxStalenessSeconds", -1) - tag_sets = pref.get("tagSets") or pref.get("tag_sets") - return read_preferences.make_read_preference( - mode, tag_sets=tag_sets, max_staleness=max_staleness - ) - - -def server_name_to_type(name): - """Convert a ServerType name to the corresponding value. For SDAM tests.""" - # Special case, some tests in the spec include the PossiblePrimary - # type, but only single-threaded drivers need that type. We call - # possible primaries Unknown. - if name == "PossiblePrimary": - return SERVER_TYPE.Unknown - return getattr(SERVER_TYPE, name) - - -def cat_files(dest, *sources): - """Cat multiple files into dest.""" - with open(dest, "wb") as fdst: - for src in sources: - with open(src, "rb") as fsrc: - shutil.copyfileobj(fsrc, fdst) - - -@contextlib.contextmanager -def assertion_context(msg): - """A context manager that adds info to an assertion failure.""" - try: - yield - except AssertionError as exc: - raise AssertionError(f"{msg}: {exc}") - - -def parse_spec_options(opts): - if "readPreference" in opts: - opts["read_preference"] = parse_read_preference(opts.pop("readPreference")) - - if "writeConcern" in opts: - w_opts = opts.pop("writeConcern") - if "journal" in w_opts: - w_opts["j"] = w_opts.pop("journal") - if "wtimeoutMS" in w_opts: - w_opts["wtimeout"] = w_opts.pop("wtimeoutMS") - opts["write_concern"] = WriteConcern(**dict(w_opts)) - - if "readConcern" in opts: - opts["read_concern"] = ReadConcern(**dict(opts.pop("readConcern"))) - - if "timeoutMS" in opts: - assert isinstance(opts["timeoutMS"], int) - opts["timeout"] = int(opts.pop("timeoutMS")) / 1000.0 - - if "maxTimeMS" in opts: - opts["max_time_ms"] = opts.pop("maxTimeMS") - - if "maxCommitTimeMS" in opts: - opts["max_commit_time_ms"] = opts.pop("maxCommitTimeMS") - - return dict(opts) - - -def prepare_spec_arguments(spec, arguments, opname, entity_map, with_txn_callback): - for arg_name in list(arguments): - c2s = camel_to_snake(arg_name) - # Named "key" instead not fieldName. - if arg_name == "fieldName": - arguments["key"] = arguments.pop(arg_name) - # Aggregate uses "batchSize", while find uses batch_size. - elif (arg_name == "batchSize" or arg_name == "allowDiskUse") and opname == "aggregate": - continue - elif arg_name == "timeoutMode": - raise unittest.SkipTest("PyMongo does not support timeoutMode") - # Requires boolean returnDocument. - elif arg_name == "returnDocument": - arguments[c2s] = getattr(ReturnDocument, arguments.pop(arg_name).upper()) - elif "bulk_write" in opname and (c2s == "requests" or c2s == "models"): - # Parse each request into a bulk write model. - requests = [] - for request in arguments[c2s]: - if "name" in request: - # CRUD v2 format - bulk_model = camel_to_upper_camel(request["name"]) - bulk_class = getattr(operations, bulk_model) - bulk_arguments = camel_to_snake_args(request["arguments"]) - else: - # Unified test format - bulk_model, spec = next(iter(request.items())) - bulk_class = getattr(operations, camel_to_upper_camel(bulk_model)) - bulk_arguments = camel_to_snake_args(spec) - requests.append(bulk_class(**dict(bulk_arguments))) - arguments[c2s] = requests - elif arg_name == "session": - arguments["session"] = entity_map[arguments["session"]] - elif opname == "open_download_stream" and arg_name == "id": - arguments["file_id"] = arguments.pop(arg_name) - elif opname not in ("find", "find_one") and c2s == "max_time_ms": - # find is the only method that accepts snake_case max_time_ms. - # All other methods take kwargs which must use the server's - # camelCase maxTimeMS. See PYTHON-1855. - arguments["maxTimeMS"] = arguments.pop("max_time_ms") - elif opname == "with_transaction" and arg_name == "callback": - if "operations" in arguments[arg_name]: - # CRUD v2 format - callback_ops = arguments[arg_name]["operations"] - else: - # Unified test format - callback_ops = arguments[arg_name] - arguments["callback"] = lambda _: with_txn_callback(copy.deepcopy(callback_ops)) - elif opname == "drop_collection" and arg_name == "collection": - arguments["name_or_collection"] = arguments.pop(arg_name) - elif opname == "create_collection": - if arg_name == "collection": - arguments["name"] = arguments.pop(arg_name) - arguments["check_exists"] = False - # Any other arguments to create_collection are passed through - # **kwargs. - elif opname == "create_index" and arg_name == "keys": - arguments["keys"] = list(arguments.pop(arg_name).items()) - elif opname == "drop_index" and arg_name == "name": - arguments["index_or_name"] = arguments.pop(arg_name) - elif opname == "rename" and arg_name == "to": - arguments["new_name"] = arguments.pop(arg_name) - elif opname == "rename" and arg_name == "dropTarget": - arguments["dropTarget"] = arguments.pop(arg_name) - elif arg_name == "cursorType": - cursor_type = arguments.pop(arg_name) - if cursor_type == "tailable": - arguments["cursor_type"] = CursorType.TAILABLE - elif cursor_type == "tailableAwait": - arguments["cursor_type"] = CursorType.TAILABLE - else: - raise AssertionError(f"Unsupported cursorType: {cursor_type}") - else: - arguments[c2s] = arguments.pop(arg_name) - +class MockConnection: + def __init__(self): + self.cancel_context = _CancellationContext() + self.more_to_come = False + self.id = random.randint(0, 100) -def set_fail_point(client, command_args): - cmd = SON([("configureFailPoint", "failCommand")]) - cmd.update(command_args) - client.admin.command(cmd) + def close_conn(self, reason): + pass + def __enter__(self): + return self -async def async_set_fail_point(client, command_args): - cmd = SON([("configureFailPoint", "failCommand")]) - cmd.update(command_args) - await client.admin.command(cmd) + def __exit__(self, exc_type, exc_val, exc_tb): + pass -def create_async_event(): - return asyncio.Event() +class MockPool: + def __init__(self, address, options, handshake=True, client_id=None): + self.gen = _PoolGeneration() + self._lock = _create_lock() + self.opts = options + self.operation_count = 0 + self.conns = [] + def stale_generation(self, gen, service_id): + return self.gen.stale(gen, service_id) -def create_event(): - return threading.Event() + @contextlib.contextmanager + def checkout(self, handler=None): + yield MockConnection() + def checkin(self, *args, **kwargs): + pass -def async_create_barrier(N_TASKS, timeout: float | None = None): - return asyncio.Barrier(N_TASKS) + def _reset(self, service_id=None): + with self._lock: + self.gen.inc(service_id) + def ready(self): + pass -def create_barrier(N_TASKS, timeout: float | None = None): - return threading.Barrier(N_TASKS, timeout=timeout) + def reset(self, service_id=None, interrupt_connections=False): + self._reset() + def reset_without_pause(self): + self._reset() -async def async_barrier_wait(barrier, timeout: float | None = None): - await asyncio.wait_for(barrier.wait(), timeout=timeout) + def close(self): + self._reset() + def update_is_writable(self, is_writable): + pass -def barrier_wait(barrier, timeout: float | None = None): - barrier.wait() + def remove_stale_sockets(self, *args, **kwargs): + pass diff --git a/test/utils_selection_tests.py b/test/utils_selection_tests.py index 9667ea701b..2772f06070 100644 --- a/test/utils_selection_tests.py +++ b/test/utils_selection_tests.py @@ -19,17 +19,18 @@ import os import sys from test import PyMongoTestCase +from test.utils import MockPool sys.path[0:0] = [""] from test import unittest from test.pymongo_mocks import DummyMonitor -from test.utils import MockPool, parse_read_preference from test.utils_selection_tests_shared import ( get_addresses, get_topology_type_name, make_server_description, ) +from test.utils_shared import parse_read_preference from bson import json_util from pymongo.common import HEARTBEAT_FREQUENCY diff --git a/test/utils_shared.py b/test/utils_shared.py new file mode 100644 index 0000000000..2c52445968 --- /dev/null +++ b/test/utils_shared.py @@ -0,0 +1,705 @@ +# Copyright 2012-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared utilities for testing pymongo""" +from __future__ import annotations + +import asyncio +import contextlib +import copy +import functools +import random +import re +import shutil +import sys +import threading +import unittest +import warnings +from asyncio import iscoroutinefunction +from collections import abc, defaultdict +from functools import partial +from test import client_context +from test.asynchronous.utils import async_wait_until +from test.utils import wait_until +from typing import List + +from bson.objectid import ObjectId +from pymongo import monitoring, operations, read_preferences +from pymongo.cursor_shared import CursorType +from pymongo.errors import OperationFailure +from pymongo.helpers_shared import _SENSITIVE_COMMANDS +from pymongo.lock import _async_create_lock, _create_lock +from pymongo.monitoring import ( + ConnectionCheckedInEvent, + ConnectionCheckedOutEvent, + ConnectionCheckOutFailedEvent, + ConnectionCheckOutStartedEvent, + ConnectionClosedEvent, + ConnectionCreatedEvent, + ConnectionReadyEvent, + PoolClearedEvent, + PoolClosedEvent, + PoolCreatedEvent, + PoolReadyEvent, +) +from pymongo.read_concern import ReadConcern +from pymongo.server_type import SERVER_TYPE +from pymongo.synchronous.collection import ReturnDocument +from pymongo.synchronous.pool import _CancellationContext, _PoolGeneration +from pymongo.write_concern import WriteConcern + +IMPOSSIBLE_WRITE_CONCERN = WriteConcern(w=50) + + +class BaseListener: + def __init__(self): + self.events = [] + + def reset(self): + self.events = [] + + def add_event(self, event): + self.events.append(event) + + def event_count(self, event_type): + return len(self.events_by_type(event_type)) + + def events_by_type(self, event_type): + """Return the matching events by event class. + + event_type can be a single class or a tuple of classes. + """ + return self.matching(lambda e: isinstance(e, event_type)) + + def matching(self, matcher): + """Return the matching events.""" + return [event for event in self.events[:] if matcher(event)] + + def wait_for_event(self, event, count): + """Wait for a number of events to be published, or fail.""" + wait_until(lambda: self.event_count(event) >= count, f"find {count} {event} event(s)") + + async def async_wait_for_event(self, event, count): + """Wait for a number of events to be published, or fail.""" + await async_wait_until( + lambda: self.event_count(event) >= count, f"find {count} {event} event(s)" + ) + + +class CMAPListener(BaseListener, monitoring.ConnectionPoolListener): + def connection_created(self, event): + assert isinstance(event, ConnectionCreatedEvent) + self.add_event(event) + + def connection_ready(self, event): + assert isinstance(event, ConnectionReadyEvent) + self.add_event(event) + + def connection_closed(self, event): + assert isinstance(event, ConnectionClosedEvent) + self.add_event(event) + + def connection_check_out_started(self, event): + assert isinstance(event, ConnectionCheckOutStartedEvent) + self.add_event(event) + + def connection_check_out_failed(self, event): + assert isinstance(event, ConnectionCheckOutFailedEvent) + self.add_event(event) + + def connection_checked_out(self, event): + assert isinstance(event, ConnectionCheckedOutEvent) + self.add_event(event) + + def connection_checked_in(self, event): + assert isinstance(event, ConnectionCheckedInEvent) + self.add_event(event) + + def pool_created(self, event): + assert isinstance(event, PoolCreatedEvent) + self.add_event(event) + + def pool_ready(self, event): + assert isinstance(event, PoolReadyEvent) + self.add_event(event) + + def pool_cleared(self, event): + assert isinstance(event, PoolClearedEvent) + self.add_event(event) + + def pool_closed(self, event): + assert isinstance(event, PoolClosedEvent) + self.add_event(event) + + +class EventListener(BaseListener, monitoring.CommandListener): + def __init__(self): + super().__init__() + self.results = defaultdict(list) + + @property + def started_events(self) -> List[monitoring.CommandStartedEvent]: + return self.results["started"] + + @property + def succeeded_events(self) -> List[monitoring.CommandSucceededEvent]: + return self.results["succeeded"] + + @property + def failed_events(self) -> List[monitoring.CommandFailedEvent]: + return self.results["failed"] + + def started(self, event: monitoring.CommandStartedEvent) -> None: + self.started_events.append(event) + self.add_event(event) + + def succeeded(self, event: monitoring.CommandSucceededEvent) -> None: + self.succeeded_events.append(event) + self.add_event(event) + + def failed(self, event: monitoring.CommandFailedEvent) -> None: + self.failed_events.append(event) + self.add_event(event) + + def started_command_names(self) -> List[str]: + """Return list of command names started.""" + return [event.command_name for event in self.started_events] + + def reset(self) -> None: + """Reset the state of this listener.""" + self.results.clear() + super().reset() + + +class TopologyEventListener(monitoring.TopologyListener): + def __init__(self): + self.results = defaultdict(list) + + def closed(self, event): + self.results["closed"].append(event) + + def description_changed(self, event): + self.results["description_changed"].append(event) + + def opened(self, event): + self.results["opened"].append(event) + + def reset(self): + """Reset the state of this listener.""" + self.results.clear() + + +class AllowListEventListener(EventListener): + def __init__(self, *commands): + self.commands = set(commands) + super().__init__() + + def started(self, event): + if event.command_name in self.commands: + super().started(event) + + def succeeded(self, event): + if event.command_name in self.commands: + super().succeeded(event) + + def failed(self, event): + if event.command_name in self.commands: + super().failed(event) + + +class OvertCommandListener(EventListener): + """A CommandListener that ignores sensitive commands.""" + + ignore_list_collections = False + + def started(self, event): + if event.command_name.lower() not in _SENSITIVE_COMMANDS: + super().started(event) + + def succeeded(self, event): + if event.command_name.lower() not in _SENSITIVE_COMMANDS: + super().succeeded(event) + + def failed(self, event): + if event.command_name.lower() not in _SENSITIVE_COMMANDS: + super().failed(event) + + +class _ServerEventListener: + """Listens to all events.""" + + def __init__(self): + self.results = [] + + def opened(self, event): + self.results.append(event) + + def description_changed(self, event): + self.results.append(event) + + def closed(self, event): + self.results.append(event) + + def matching(self, matcher): + """Return the matching events.""" + results = self.results[:] + return [event for event in results if matcher(event)] + + def reset(self): + self.results = [] + + +class ServerEventListener(_ServerEventListener, monitoring.ServerListener): + """Listens to Server events.""" + + +class ServerAndTopologyEventListener( # type: ignore[misc] + ServerEventListener, monitoring.TopologyListener +): + """Listens to Server and Topology events.""" + + +class HeartbeatEventListener(BaseListener, monitoring.ServerHeartbeatListener): + """Listens to only server heartbeat events.""" + + def started(self, event): + self.add_event(event) + + def succeeded(self, event): + self.add_event(event) + + def failed(self, event): + self.add_event(event) + + +class HeartbeatEventsListListener(HeartbeatEventListener): + """Listens to only server heartbeat events and publishes them to a provided list.""" + + def __init__(self, events): + super().__init__() + self.event_list = events + + def started(self, event): + self.add_event(event) + self.event_list.append("serverHeartbeatStartedEvent") + + def succeeded(self, event): + self.add_event(event) + self.event_list.append("serverHeartbeatSucceededEvent") + + def failed(self, event): + self.add_event(event) + self.event_list.append("serverHeartbeatFailedEvent") + + +class ScenarioDict(dict): + """Dict that returns {} for any unknown key, recursively.""" + + def __init__(self, data): + def convert(v): + if isinstance(v, abc.Mapping): + return ScenarioDict(v) + if isinstance(v, (str, bytes)): + return v + if isinstance(v, abc.Sequence): + return [convert(item) for item in v] + return v + + dict.__init__(self, [(k, convert(v)) for k, v in data.items()]) + + def __getitem__(self, item): + try: + return dict.__getitem__(self, item) + except KeyError: + # Unlike a defaultdict, don't set the key, just return a dict. + return ScenarioDict({}) + + +class CompareType: + """Class that compares equal to any object of the given type(s).""" + + def __init__(self, types): + self.types = types + + def __eq__(self, other): + return isinstance(other, self.types) + + +class FunctionCallRecorder: + """Utility class to wrap a callable and record its invocations.""" + + def __init__(self, function): + self._function = function + self._call_list = [] + + def __call__(self, *args, **kwargs): + self._call_list.append((args, kwargs)) + if iscoroutinefunction(self._function): + return self._function(*args, **kwargs) + else: + return self._function(*args, **kwargs) + + def reset(self): + """Wipes the call list.""" + self._call_list = [] + + def call_list(self): + """Returns a copy of the call list.""" + return self._call_list[:] + + @property + def call_count(self): + """Returns the number of times the function has been called.""" + return len(self._call_list) + + +def one(s): + """Get one element of a set""" + return next(iter(s)) + + +def oid_generated_on_process(oid): + """Makes a determination as to whether the given ObjectId was generated + by the current process, based on the 5-byte random number in the ObjectId. + """ + return ObjectId._random() == oid.binary[4:9] + + +def delay(sec): + return """function() { sleep(%f * 1000); return true; }""" % sec + + +def camel_to_snake(camel): + # Regex to convert CamelCase to snake_case. + snake = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", camel) + return re.sub("([a-z0-9])([A-Z])", r"\1_\2", snake).lower() + + +def camel_to_upper_camel(camel): + return camel[0].upper() + camel[1:] + + +def camel_to_snake_args(arguments): + for arg_name in list(arguments): + c2s = camel_to_snake(arg_name) + arguments[c2s] = arguments.pop(arg_name) + return arguments + + +def snake_to_camel(snake): + # Regex to convert snake_case to lowerCamelCase. + return re.sub(r"_([a-z])", lambda m: m.group(1).upper(), snake) + + +def parse_collection_options(opts): + if "readPreference" in opts: + opts["read_preference"] = parse_read_preference(opts.pop("readPreference")) + + if "writeConcern" in opts: + opts["write_concern"] = WriteConcern(**dict(opts.pop("writeConcern"))) + + if "readConcern" in opts: + opts["read_concern"] = ReadConcern(**dict(opts.pop("readConcern"))) + + if "timeoutMS" in opts: + opts["timeout"] = int(opts.pop("timeoutMS")) / 1000.0 + return opts + + +@contextlib.contextmanager +def _ignore_deprecations(): + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + yield + + +def ignore_deprecations(wrapped=None): + """A context manager or a decorator.""" + if wrapped: + if iscoroutinefunction(wrapped): + + @functools.wraps(wrapped) + async def wrapper(*args, **kwargs): + with _ignore_deprecations(): + return await wrapped(*args, **kwargs) + else: + + @functools.wraps(wrapped) + def wrapper(*args, **kwargs): + with _ignore_deprecations(): + return wrapped(*args, **kwargs) + + return wrapper + + else: + return _ignore_deprecations() + + +class DeprecationFilter: + def __init__(self, action="ignore"): + """Start filtering deprecations.""" + self.warn_context = warnings.catch_warnings() + self.warn_context.__enter__() + warnings.simplefilter(action, DeprecationWarning) + + def stop(self): + """Stop filtering deprecations.""" + self.warn_context.__exit__() # type: ignore + self.warn_context = None # type: ignore + + +# Constants for run_threads and lazy_client_trial. +NTRIALS = 5 +NTHREADS = 10 + + +def run_threads(collection, target): + """Run a target function in many threads. + + target is a function taking a Collection and an integer. + """ + threads = [] + for i in range(NTHREADS): + bound_target = partial(target, collection, i) + threads.append(threading.Thread(target=bound_target)) + + for t in threads: + t.start() + + for t in threads: + t.join(60) + assert not t.is_alive() + + +@contextlib.contextmanager +def frequent_thread_switches(): + """Make concurrency bugs more likely to manifest.""" + interval = sys.getswitchinterval() + sys.setswitchinterval(1e-6) + + try: + yield + finally: + sys.setswitchinterval(interval) + + +def lazy_client_trial(reset, target, test, get_client): + """Test concurrent operations on a lazily-connecting client. + + `reset` takes a collection and resets it for the next trial. + + `target` takes a lazily-connecting collection and an index from + 0 to NTHREADS, and performs some operation, e.g. an insert. + + `test` takes the lazily-connecting collection and asserts a + post-condition to prove `target` succeeded. + """ + collection = client_context.client.pymongo_test.test + + with frequent_thread_switches(): + for _i in range(NTRIALS): + reset(collection) + lazy_client = get_client() + lazy_collection = lazy_client.pymongo_test.test + run_threads(lazy_collection, target) + test(lazy_collection) + + +def gevent_monkey_patched(): + """Check if gevent's monkey patching is active.""" + try: + import socket + + import gevent.socket # type:ignore[import] + + return socket.socket is gevent.socket.socket + except ImportError: + return False + + +def eventlet_monkey_patched(): + """Check if eventlet's monkey patching is active.""" + import threading + + return threading.current_thread.__module__ == "eventlet.green.threading" + + +def is_greenthread_patched(): + return gevent_monkey_patched() or eventlet_monkey_patched() + + +def parse_read_preference(pref): + # Make first letter lowercase to match read_pref's modes. + mode_string = pref.get("mode", "primary") + mode_string = mode_string[:1].lower() + mode_string[1:] + mode = read_preferences.read_pref_mode_from_name(mode_string) + max_staleness = pref.get("maxStalenessSeconds", -1) + tag_sets = pref.get("tagSets") or pref.get("tag_sets") + return read_preferences.make_read_preference( + mode, tag_sets=tag_sets, max_staleness=max_staleness + ) + + +def server_name_to_type(name): + """Convert a ServerType name to the corresponding value. For SDAM tests.""" + # Special case, some tests in the spec include the PossiblePrimary + # type, but only single-threaded drivers need that type. We call + # possible primaries Unknown. + if name == "PossiblePrimary": + return SERVER_TYPE.Unknown + return getattr(SERVER_TYPE, name) + + +def cat_files(dest, *sources): + """Cat multiple files into dest.""" + with open(dest, "wb") as fdst: + for src in sources: + with open(src, "rb") as fsrc: + shutil.copyfileobj(fsrc, fdst) + + +@contextlib.contextmanager +def assertion_context(msg): + """A context manager that adds info to an assertion failure.""" + try: + yield + except AssertionError as exc: + raise AssertionError(f"{msg}: {exc}") + + +def parse_spec_options(opts): + if "readPreference" in opts: + opts["read_preference"] = parse_read_preference(opts.pop("readPreference")) + + if "writeConcern" in opts: + w_opts = opts.pop("writeConcern") + if "journal" in w_opts: + w_opts["j"] = w_opts.pop("journal") + if "wtimeoutMS" in w_opts: + w_opts["wtimeout"] = w_opts.pop("wtimeoutMS") + opts["write_concern"] = WriteConcern(**dict(w_opts)) + + if "readConcern" in opts: + opts["read_concern"] = ReadConcern(**dict(opts.pop("readConcern"))) + + if "timeoutMS" in opts: + assert isinstance(opts["timeoutMS"], int) + opts["timeout"] = int(opts.pop("timeoutMS")) / 1000.0 + + if "maxTimeMS" in opts: + opts["max_time_ms"] = opts.pop("maxTimeMS") + + if "maxCommitTimeMS" in opts: + opts["max_commit_time_ms"] = opts.pop("maxCommitTimeMS") + + return dict(opts) + + +def prepare_spec_arguments(spec, arguments, opname, entity_map, with_txn_callback): + for arg_name in list(arguments): + c2s = camel_to_snake(arg_name) + # Named "key" instead not fieldName. + if arg_name == "fieldName": + arguments["key"] = arguments.pop(arg_name) + # Aggregate uses "batchSize", while find uses batch_size. + elif (arg_name == "batchSize" or arg_name == "allowDiskUse") and opname == "aggregate": + continue + elif arg_name == "timeoutMode": + raise unittest.SkipTest("PyMongo does not support timeoutMode") + # Requires boolean returnDocument. + elif arg_name == "returnDocument": + arguments[c2s] = getattr(ReturnDocument, arguments.pop(arg_name).upper()) + elif "bulk_write" in opname and (c2s == "requests" or c2s == "models"): + # Parse each request into a bulk write model. + requests = [] + for request in arguments[c2s]: + if "name" in request: + # CRUD v2 format + bulk_model = camel_to_upper_camel(request["name"]) + bulk_class = getattr(operations, bulk_model) + bulk_arguments = camel_to_snake_args(request["arguments"]) + else: + # Unified test format + bulk_model, spec = next(iter(request.items())) + bulk_class = getattr(operations, camel_to_upper_camel(bulk_model)) + bulk_arguments = camel_to_snake_args(spec) + requests.append(bulk_class(**dict(bulk_arguments))) + arguments[c2s] = requests + elif arg_name == "session": + arguments["session"] = entity_map[arguments["session"]] + elif opname == "open_download_stream" and arg_name == "id": + arguments["file_id"] = arguments.pop(arg_name) + elif opname not in ("find", "find_one") and c2s == "max_time_ms": + # find is the only method that accepts snake_case max_time_ms. + # All other methods take kwargs which must use the server's + # camelCase maxTimeMS. See PYTHON-1855. + arguments["maxTimeMS"] = arguments.pop("max_time_ms") + elif opname == "with_transaction" and arg_name == "callback": + if "operations" in arguments[arg_name]: + # CRUD v2 format + callback_ops = arguments[arg_name]["operations"] + else: + # Unified test format + callback_ops = arguments[arg_name] + arguments["callback"] = lambda _: with_txn_callback(copy.deepcopy(callback_ops)) + elif opname == "drop_collection" and arg_name == "collection": + arguments["name_or_collection"] = arguments.pop(arg_name) + elif opname == "create_collection": + if arg_name == "collection": + arguments["name"] = arguments.pop(arg_name) + arguments["check_exists"] = False + # Any other arguments to create_collection are passed through + # **kwargs. + elif opname == "create_index" and arg_name == "keys": + arguments["keys"] = list(arguments.pop(arg_name).items()) + elif opname == "drop_index" and arg_name == "name": + arguments["index_or_name"] = arguments.pop(arg_name) + elif opname == "rename" and arg_name == "to": + arguments["new_name"] = arguments.pop(arg_name) + elif opname == "rename" and arg_name == "dropTarget": + arguments["dropTarget"] = arguments.pop(arg_name) + elif arg_name == "cursorType": + cursor_type = arguments.pop(arg_name) + if cursor_type == "tailable": + arguments["cursor_type"] = CursorType.TAILABLE + elif cursor_type == "tailableAwait": + arguments["cursor_type"] = CursorType.TAILABLE + else: + raise AssertionError(f"Unsupported cursorType: {cursor_type}") + else: + arguments[c2s] = arguments.pop(arg_name) + + +def create_async_event(): + return asyncio.Event() + + +def create_event(): + return threading.Event() + + +def async_create_barrier(n_tasks: int): + return asyncio.Barrier(n_tasks) + + +def create_barrier(n_tasks: int, timeout: float | None = None): + return threading.Barrier(n_tasks, timeout=timeout) + + +async def async_barrier_wait(barrier, timeout: float | None = None): + await asyncio.wait_for(barrier.wait(), timeout=timeout) + + +def barrier_wait(barrier, timeout: float | None = None): + barrier.wait(timeout=timeout) diff --git a/test/utils_spec_runner.py b/test/utils_spec_runner.py index fe0ba6eb44..580e7cc120 100644 --- a/test/utils_spec_runner.py +++ b/test/utils_spec_runner.py @@ -24,7 +24,7 @@ from collections import abc from test import IntegrationTest, client_context, client_knobs from test.helpers import ConcurrentRunner -from test.utils import ( +from test.utils_shared import ( CMAPListener, CompareType, EventListener, diff --git a/tools/synchro.py b/tools/synchro.py index 42d5694f47..e65270733e 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -259,6 +259,7 @@ def async_only_test(f: str) -> bool: "test_versioned_api_integration.py", "unified_format.py", "utils_selection_tests.py", + "utils.py", ]