From ead780a793f6187bcf8b25a3a27dffce48653074 Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Fri, 7 Mar 2025 10:02:28 -0800 Subject: [PATCH 01/56] WIP (not cleaned up) --- pymongo/asynchronous/mongo_client.py | 158 ++++++++++++++++++++++++--- pymongo/synchronous/mongo_client.py | 158 ++++++++++++++++++++++++--- pymongo/uri_parser.py | 130 +++++++++++++++++++--- test/asynchronous/test_client.py | 3 + test/test_client.py | 3 + 5 files changed, 404 insertions(+), 48 deletions(-) diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index 37be9a194c..28f881cea0 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -742,13 +742,15 @@ def __init__( **kwargs, } - if host is None: - host = self.HOST - if isinstance(host, str): - host = [host] - if port is None: - port = self.PORT - if not isinstance(port, int): + self._host = host + self._port = port + if self._host is None: + self._host = self.HOST + if isinstance(self._host, str): + self._host = [self._host] + if self._port is None: + self._port = self.PORT + if not isinstance(self._port, int): raise TypeError(f"port must be an instance of int, not {type(port)}") # _pool_class, _monitor_class, and _condition_class are for deep @@ -769,26 +771,19 @@ def __init__( fqdn = None srv_service_name = keyword_opts.get("srvservicename") srv_max_hosts = keyword_opts.get("srvmaxhosts") - if len([h for h in host if "/" in h]) > 1: + if len([h for h in self._host if "/" in h]) > 1: raise ConfigurationError("host must not contain multiple MongoDB URIs") - for entity in host: + for entity in self._host: # A hostname can only include a-z, 0-9, '-' and '.'. If we find a '/' # it must be a URI, # https://en.wikipedia.org/wiki/Hostname#Restrictions_on_valid_host_names if "/" in entity: - # Determine connection timeout from kwargs. - timeout = keyword_opts.get("connecttimeoutms") - if timeout is not None: - timeout = common.validate_timeout_or_none_or_zero( - keyword_opts.cased_key("connecttimeoutms"), timeout - ) res = uri_parser.parse_uri( entity, port, validate=True, warn=True, normalize=False, - connect_timeout=timeout, srv_service_name=srv_service_name, srv_max_hosts=srv_max_hosts, ) @@ -799,7 +794,7 @@ def __init__( opts = res["options"] fqdn = res["fqdn"] else: - seeds.update(uri_parser.split_hosts(entity, port)) + seeds.update(uri_parser.split_hosts(entity, self._port)) if not seeds: raise ConfigurationError("need to specify at least one host") @@ -895,6 +890,134 @@ def __init__( # This will be used later if we fork. AsyncMongoClient._clients[self._topology._topology_id] = self + self._for_resolve_uri = { + "username": username, + "password": password, + "srv_service_name": srv_service_name, + "srv_max_hosts": srv_max_hosts, + "fqdn": fqdn, + "pool_class": pool_class, + "monitor_class": monitor_class, + "condition_class": condition_class, + } + + def _resolve_uri(self): + keyword_opts = common._CaseInsensitiveDictionary(self._init_kwargs) + for i in [ + "_pool_class", + "_monitor_class", + "_condition_class", + "host", + "port", + "type_registry", + ]: + keyword_opts.pop(i, None) + seeds = set() + opts = common._CaseInsensitiveDictionary() + srv_service_name = keyword_opts.get("srvservicename") + srv_max_hosts = keyword_opts.get("srvmaxhosts") + for entity in self._host: + # A hostname can only include a-z, 0-9, '-' and '.'. If we find a '/' + # it must be a URI, + # https://en.wikipedia.org/wiki/Hostname#Restrictions_on_valid_host_names + if "/" in entity: + # Determine connection timeout from kwargs. + timeout = keyword_opts.get("connecttimeoutms") + if timeout is not None: + timeout = common.validate_timeout_or_none_or_zero( + keyword_opts.cased_key("connecttimeoutms"), timeout + ) + res = uri_parser.parse_uri_lookups( + entity, + self._port, + validate=True, + warn=True, + normalize=False, + connect_timeout=timeout, + srv_service_name=srv_service_name, + srv_max_hosts=srv_max_hosts, + ) + seeds.update(res["nodelist"]) + opts = res["options"] + else: + seeds.update(uri_parser.split_hosts(entity, self._port)) + + if not seeds: + raise ConfigurationError("need to specify at least one host") + + for hostname in [node[0] for node in seeds]: + if _detect_external_db(hostname): + break + + # Add options with named keyword arguments to the parsed kwarg options. + tz_aware = keyword_opts["tz_aware"] + connect = keyword_opts["connect"] + if tz_aware is None: + tz_aware = opts.get("tz_aware", False) + if connect is None: + # Default to connect=True unless on a FaaS system, which might use fork. + from pymongo.pool_options import _is_faas + + connect = opts.get("connect", not _is_faas()) + keyword_opts["tz_aware"] = tz_aware + keyword_opts["connect"] = connect + + # Handle deprecated options in kwarg options. + keyword_opts = _handle_option_deprecations(keyword_opts) + # Validate kwarg options. + keyword_opts = common._CaseInsensitiveDictionary( + dict(common.validate(keyword_opts.cased_key(k), v) for k, v in keyword_opts.items()) + ) + + # Override connection string options with kwarg options. + opts.update(keyword_opts) + + if srv_service_name is None: + srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME) + + srv_max_hosts = srv_max_hosts or opts.get("srvmaxhosts") + # Handle security-option conflicts in combined options. + opts = _handle_security_options(opts) + # Normalize combined options. + opts = _normalize_options(opts) + _check_options(seeds, opts) + + # Username and password passed as kwargs override user info in URI. + username = opts.get("username", self._for_resolve_uri["username"]) + password = opts.get("password", self._for_resolve_uri["password"]) + self._options = ClientOptions( + username, password, self._default_database_name, opts, _IS_SYNC + ) + + self._event_listeners = self._options.pool_options._event_listeners + super().__init__( + self._options.codec_options, + self._options.read_preference, + self._options.write_concern, + self._options.read_concern, + ) + + self._topology_settings = TopologySettings( + seeds=seeds, + replica_set_name=self._options.replica_set_name, + pool_class=self._for_resolve_uri["pool_class"], + pool_options=self._options.pool_options, + monitor_class=self._for_resolve_uri["monitor_class"], + condition_class=self._for_resolve_uri["condition_class"], + local_threshold_ms=self._options.local_threshold_ms, + server_selection_timeout=self._options.server_selection_timeout, + server_selector=self._options.server_selector, + heartbeat_frequency=self._options.heartbeat_frequency, + fqdn=self._for_resolve_uri["fqdn"], + direct_connection=self._options.direct_connection, + load_balanced=self._options.load_balanced, + srv_service_name=srv_service_name, + srv_max_hosts=srv_max_hosts, + server_monitoring_mode=self._options.server_monitoring_mode, + ) + + self._topology = Topology(self._topology_settings) + async def aconnect(self) -> None: """Explicitly connect to MongoDB asynchronously instead of on the first operation.""" await self._get_topology() @@ -1582,6 +1705,7 @@ async def _get_topology(self) -> Topology: launches the connection process in the background. """ if not self._opened: + self._resolve_uri() await self._topology.open() async with self._lock: self._kill_cursors_executor.open() diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 373deabd4e..bbf2ce88f8 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -740,13 +740,15 @@ def __init__( **kwargs, } - if host is None: - host = self.HOST - if isinstance(host, str): - host = [host] - if port is None: - port = self.PORT - if not isinstance(port, int): + self._host = host + self._port = port + if self._host is None: + self._host = self.HOST + if isinstance(self._host, str): + self._host = [self._host] + if self._port is None: + self._port = self.PORT + if not isinstance(self._port, int): raise TypeError(f"port must be an instance of int, not {type(port)}") # _pool_class, _monitor_class, and _condition_class are for deep @@ -767,26 +769,19 @@ def __init__( fqdn = None srv_service_name = keyword_opts.get("srvservicename") srv_max_hosts = keyword_opts.get("srvmaxhosts") - if len([h for h in host if "/" in h]) > 1: + if len([h for h in self._host if "/" in h]) > 1: raise ConfigurationError("host must not contain multiple MongoDB URIs") - for entity in host: + for entity in self._host: # A hostname can only include a-z, 0-9, '-' and '.'. If we find a '/' # it must be a URI, # https://en.wikipedia.org/wiki/Hostname#Restrictions_on_valid_host_names if "/" in entity: - # Determine connection timeout from kwargs. - timeout = keyword_opts.get("connecttimeoutms") - if timeout is not None: - timeout = common.validate_timeout_or_none_or_zero( - keyword_opts.cased_key("connecttimeoutms"), timeout - ) res = uri_parser.parse_uri( entity, port, validate=True, warn=True, normalize=False, - connect_timeout=timeout, srv_service_name=srv_service_name, srv_max_hosts=srv_max_hosts, ) @@ -797,7 +792,7 @@ def __init__( opts = res["options"] fqdn = res["fqdn"] else: - seeds.update(uri_parser.split_hosts(entity, port)) + seeds.update(uri_parser.split_hosts(entity, self._port)) if not seeds: raise ConfigurationError("need to specify at least one host") @@ -893,6 +888,134 @@ def __init__( # This will be used later if we fork. MongoClient._clients[self._topology._topology_id] = self + self._for_resolve_uri = { + "username": username, + "password": password, + "srv_service_name": srv_service_name, + "srv_max_hosts": srv_max_hosts, + "fqdn": fqdn, + "pool_class": pool_class, + "monitor_class": monitor_class, + "condition_class": condition_class, + } + + def _resolve_uri(self): + keyword_opts = common._CaseInsensitiveDictionary(self._init_kwargs) + for i in [ + "_pool_class", + "_monitor_class", + "_condition_class", + "host", + "port", + "type_registry", + ]: + keyword_opts.pop(i, None) + seeds = set() + opts = common._CaseInsensitiveDictionary() + srv_service_name = keyword_opts.get("srvservicename") + srv_max_hosts = keyword_opts.get("srvmaxhosts") + for entity in self._host: + # A hostname can only include a-z, 0-9, '-' and '.'. If we find a '/' + # it must be a URI, + # https://en.wikipedia.org/wiki/Hostname#Restrictions_on_valid_host_names + if "/" in entity: + # Determine connection timeout from kwargs. + timeout = keyword_opts.get("connecttimeoutms") + if timeout is not None: + timeout = common.validate_timeout_or_none_or_zero( + keyword_opts.cased_key("connecttimeoutms"), timeout + ) + res = uri_parser.parse_uri_lookups( + entity, + self._port, + validate=True, + warn=True, + normalize=False, + connect_timeout=timeout, + srv_service_name=srv_service_name, + srv_max_hosts=srv_max_hosts, + ) + seeds.update(res["nodelist"]) + opts = res["options"] + else: + seeds.update(uri_parser.split_hosts(entity, self._port)) + + if not seeds: + raise ConfigurationError("need to specify at least one host") + + for hostname in [node[0] for node in seeds]: + if _detect_external_db(hostname): + break + + # Add options with named keyword arguments to the parsed kwarg options. + tz_aware = keyword_opts["tz_aware"] + connect = keyword_opts["connect"] + if tz_aware is None: + tz_aware = opts.get("tz_aware", False) + if connect is None: + # Default to connect=True unless on a FaaS system, which might use fork. + from pymongo.pool_options import _is_faas + + connect = opts.get("connect", not _is_faas()) + keyword_opts["tz_aware"] = tz_aware + keyword_opts["connect"] = connect + + # Handle deprecated options in kwarg options. + keyword_opts = _handle_option_deprecations(keyword_opts) + # Validate kwarg options. + keyword_opts = common._CaseInsensitiveDictionary( + dict(common.validate(keyword_opts.cased_key(k), v) for k, v in keyword_opts.items()) + ) + + # Override connection string options with kwarg options. + opts.update(keyword_opts) + + if srv_service_name is None: + srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME) + + srv_max_hosts = srv_max_hosts or opts.get("srvmaxhosts") + # Handle security-option conflicts in combined options. + opts = _handle_security_options(opts) + # Normalize combined options. + opts = _normalize_options(opts) + _check_options(seeds, opts) + + # Username and password passed as kwargs override user info in URI. + username = opts.get("username", self._for_resolve_uri["username"]) + password = opts.get("password", self._for_resolve_uri["password"]) + self._options = ClientOptions( + username, password, self._default_database_name, opts, _IS_SYNC + ) + + self._event_listeners = self._options.pool_options._event_listeners + super().__init__( + self._options.codec_options, + self._options.read_preference, + self._options.write_concern, + self._options.read_concern, + ) + + self._topology_settings = TopologySettings( + seeds=seeds, + replica_set_name=self._options.replica_set_name, + pool_class=self._for_resolve_uri["pool_class"], + pool_options=self._options.pool_options, + monitor_class=self._for_resolve_uri["monitor_class"], + condition_class=self._for_resolve_uri["condition_class"], + local_threshold_ms=self._options.local_threshold_ms, + server_selection_timeout=self._options.server_selection_timeout, + server_selector=self._options.server_selector, + heartbeat_frequency=self._options.heartbeat_frequency, + fqdn=self._for_resolve_uri["fqdn"], + direct_connection=self._options.direct_connection, + load_balanced=self._options.load_balanced, + srv_service_name=srv_service_name, + srv_max_hosts=srv_max_hosts, + server_monitoring_mode=self._options.server_monitoring_mode, + ) + + self._topology = Topology(self._topology_settings) + def _connect(self) -> None: """Explicitly connect to MongoDB synchronously instead of on the first operation.""" self._get_topology() @@ -1576,6 +1699,7 @@ def _get_topology(self) -> Topology: launches the connection process in the background. """ if not self._opened: + self._resolve_uri() self._topology.open() with self._lock: self._kill_cursors_executor.open() diff --git a/pymongo/uri_parser.py b/pymongo/uri_parser.py index 8f56ae4093..10d83489a4 100644 --- a/pymongo/uri_parser.py +++ b/pymongo/uri_parser.py @@ -427,7 +427,6 @@ def parse_uri( validate: bool = True, warn: bool = False, normalize: bool = True, - connect_timeout: Optional[float] = None, srv_service_name: Optional[str] = None, srv_max_hosts: Optional[int] = None, ) -> dict[str, Any]: @@ -550,6 +549,122 @@ def parse_uri( fqdn, port = nodes[0] if port is not None: raise InvalidURI(f"{SRV_SCHEME} URIs must not include a port number") + elif not is_srv and options.get("srvServiceName") is not None: + raise ConfigurationError( + "The srvServiceName option is only allowed with 'mongodb+srv://' URIs" + ) + elif not is_srv and srv_max_hosts: + raise ConfigurationError( + "The srvMaxHosts option is only allowed with 'mongodb+srv://' URIs" + ) + else: + nodes = split_hosts(hosts, default_port=default_port) + + _check_options(nodes, options) + + return { + "nodelist": nodes, + "username": user, + "password": passwd, + "database": dbase, + "collection": collection, + "options": options, + "fqdn": fqdn, + } + + +def parse_uri_lookups( + uri: str, + default_port: Optional[int] = DEFAULT_PORT, + validate: bool = True, + warn: bool = False, + normalize: bool = True, + connect_timeout: Optional[float] = None, + srv_service_name: Optional[str] = None, + srv_max_hosts: Optional[int] = None, +) -> dict[str, Any]: + """Parse and validate a MongoDB URI. + + Returns a dict of the form:: + + { + 'nodelist': , + 'username': or None, + 'password': or None, + 'database': or None, + 'collection': or None, + 'options': , + 'fqdn': or None + } + + If the URI scheme is "mongodb+srv://" DNS SRV and TXT lookups will be done + to build nodelist and options. + + :param uri: The MongoDB URI to parse. + :param default_port: The port number to use when one wasn't specified + for a host in the URI. + :param validate: If ``True`` (the default), validate and + normalize all options. Default: ``True``. + :param warn: When validating, if ``True`` then will warn + the user then ignore any invalid options or values. If ``False``, + validation will error when options are unsupported or values are + invalid. Default: ``False``. + :param normalize: If ``True``, convert names of URI options + to their internally-used names. Default: ``True``. + :param connect_timeout: The maximum time in milliseconds to + wait for a response from the DNS server. + :param srv_service_name: A custom SRV service name + + .. versionchanged:: 4.6 + The delimiting slash (``/``) between hosts and connection options is now optional. + For example, "mongodb://example.com?tls=true" is now a valid URI. + + .. versionchanged:: 4.0 + To better follow RFC 3986, unquoted percent signs ("%") are no longer + supported. + + .. versionchanged:: 3.9 + Added the ``normalize`` parameter. + + .. versionchanged:: 3.6 + Added support for mongodb+srv:// URIs. + + .. versionchanged:: 3.5 + Return the original value of the ``readPreference`` MongoDB URI option + instead of the validated read preference mode. + + .. versionchanged:: 3.1 + ``warn`` added so invalid options can be ignored. + """ + if uri.startswith(SCHEME): + is_srv = False + scheme_free = uri[SCHEME_LEN:] + else: + is_srv = True + scheme_free = uri[SRV_SCHEME_LEN:] + + options = _CaseInsensitiveDictionary() + + host_plus_db_part, _, opts = scheme_free.partition("?") + if "/" in host_plus_db_part: + host_part, _, _ = host_plus_db_part.partition("/") + else: + host_part = host_plus_db_part + + if opts: + options.update(split_options(opts, validate, warn, normalize)) + if srv_service_name is None: + srv_service_name = options.get("srvServiceName", SRV_SERVICE_NAME) + if "@" in host_part: + _, _, hosts = host_part.rpartition("@") + else: + hosts = host_part + + hosts = unquote_plus(hosts) + srv_max_hosts = srv_max_hosts or options.get("srvMaxHosts") + if is_srv: + nodes = split_hosts(hosts, default_port=None) + fqdn, port = nodes[0] # Use the connection timeout. connectTimeoutMS passed as a keyword # argument overrides the same option passed in the connection string. @@ -572,14 +687,6 @@ def parse_uri( raise InvalidURI("You cannot specify replicaSet with srvMaxHosts") if "tls" not in options and "ssl" not in options: options["tls"] = True if validate else "true" - elif not is_srv and options.get("srvServiceName") is not None: - raise ConfigurationError( - "The srvServiceName option is only allowed with 'mongodb+srv://' URIs" - ) - elif not is_srv and srv_max_hosts: - raise ConfigurationError( - "The srvMaxHosts option is only allowed with 'mongodb+srv://' URIs" - ) else: nodes = split_hosts(hosts, default_port=default_port) @@ -587,12 +694,7 @@ def parse_uri( return { "nodelist": nodes, - "username": user, - "password": passwd, - "database": dbase, - "collection": collection, "options": options, - "fqdn": fqdn, } diff --git a/test/asynchronous/test_client.py b/test/asynchronous/test_client.py index 744a170be2..bc112feedb 100644 --- a/test/asynchronous/test_client.py +++ b/test/asynchronous/test_client.py @@ -1893,12 +1893,15 @@ async def test_service_name_from_kwargs(self): async def test_srv_max_hosts_kwarg(self): client = self.simple_client("mongodb+srv://test1.test.build.10gen.cc/") + await client.aconnect() self.assertGreater(len(client.topology_description.server_descriptions()), 1) client = self.simple_client("mongodb+srv://test1.test.build.10gen.cc/", srvmaxhosts=1) + await client.aconnect() self.assertEqual(len(client.topology_description.server_descriptions()), 1) client = self.simple_client( "mongodb+srv://test1.test.build.10gen.cc/?srvMaxHosts=1", srvmaxhosts=2 ) + await client.aconnect() self.assertEqual(len(client.topology_description.server_descriptions()), 2) @unittest.skipIf( diff --git a/test/test_client.py b/test/test_client.py index cdc7691c28..4003e830a0 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -1850,12 +1850,15 @@ def test_service_name_from_kwargs(self): def test_srv_max_hosts_kwarg(self): client = self.simple_client("mongodb+srv://test1.test.build.10gen.cc/") + client._connect() self.assertGreater(len(client.topology_description.server_descriptions()), 1) client = self.simple_client("mongodb+srv://test1.test.build.10gen.cc/", srvmaxhosts=1) + client._connect() self.assertEqual(len(client.topology_description.server_descriptions()), 1) client = self.simple_client( "mongodb+srv://test1.test.build.10gen.cc/?srvMaxHosts=1", srvmaxhosts=2 ) + client._connect() self.assertEqual(len(client.topology_description.server_descriptions()), 2) @unittest.skipIf( From 79c09eae1dd23c3f1b969e58466d533da36a0dc4 Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Mon, 10 Mar 2025 09:29:40 -0700 Subject: [PATCH 02/56] this might be broken? unsure.... --- pymongo/asynchronous/mongo_client.py | 34 +++++++++++--------- pymongo/synchronous/mongo_client.py | 34 +++++++++++--------- pymongo/uri_parser.py | 3 -- test/asynchronous/test_dns.py | 46 ++++++++++++---------------- test/test_dns.py | 46 ++++++++++++---------------- 5 files changed, 80 insertions(+), 83 deletions(-) diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index 28f881cea0..7401840a03 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -784,7 +784,6 @@ def __init__( validate=True, warn=True, normalize=False, - srv_service_name=srv_service_name, srv_max_hosts=srv_max_hosts, ) seeds.update(res["nodelist"]) @@ -836,8 +835,6 @@ def __init__( _check_options(seeds, opts) # Username and password passed as kwargs override user info in URI. - username = opts.get("username", username) - password = opts.get("password", password) self._options = options = ClientOptions(username, password, dbase, opts, _IS_SYNC) self._default_database_name = dbase @@ -875,6 +872,15 @@ def __init__( self._closed = False self._init_background() + self._for_resolve_uri = { + "username": username, + "password": password, + "dbase": dbase, + "fqdn": fqdn, + "pool_class": pool_class, + "monitor_class": monitor_class, + "condition_class": condition_class, + } if _IS_SYNC and connect: self._get_topology() # type: ignore[unused-coroutine] @@ -890,16 +896,16 @@ def __init__( # This will be used later if we fork. AsyncMongoClient._clients[self._topology._topology_id] = self - self._for_resolve_uri = { - "username": username, - "password": password, - "srv_service_name": srv_service_name, - "srv_max_hosts": srv_max_hosts, - "fqdn": fqdn, - "pool_class": pool_class, - "monitor_class": monitor_class, - "condition_class": condition_class, - } + # self._for_resolve_uri = { + # "username": username, + # "password": password, + # "srv_service_name": srv_service_name, + # "srv_max_hosts": srv_max_hosts, + # "fqdn": fqdn, + # "pool_class": pool_class, + # "monitor_class": monitor_class, + # "condition_class": condition_class, + # } def _resolve_uri(self): keyword_opts = common._CaseInsensitiveDictionary(self._init_kwargs) @@ -986,7 +992,7 @@ def _resolve_uri(self): username = opts.get("username", self._for_resolve_uri["username"]) password = opts.get("password", self._for_resolve_uri["password"]) self._options = ClientOptions( - username, password, self._default_database_name, opts, _IS_SYNC + username, password, self._for_resolve_uri["dbase"], opts, _IS_SYNC ) self._event_listeners = self._options.pool_options._event_listeners diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index bbf2ce88f8..7dcaacdcd5 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -782,7 +782,6 @@ def __init__( validate=True, warn=True, normalize=False, - srv_service_name=srv_service_name, srv_max_hosts=srv_max_hosts, ) seeds.update(res["nodelist"]) @@ -834,8 +833,6 @@ def __init__( _check_options(seeds, opts) # Username and password passed as kwargs override user info in URI. - username = opts.get("username", username) - password = opts.get("password", password) self._options = options = ClientOptions(username, password, dbase, opts, _IS_SYNC) self._default_database_name = dbase @@ -873,6 +870,15 @@ def __init__( self._closed = False self._init_background() + self._for_resolve_uri = { + "username": username, + "password": password, + "dbase": dbase, + "fqdn": fqdn, + "pool_class": pool_class, + "monitor_class": monitor_class, + "condition_class": condition_class, + } if _IS_SYNC and connect: self._get_topology() # type: ignore[unused-coroutine] @@ -888,16 +894,16 @@ def __init__( # This will be used later if we fork. MongoClient._clients[self._topology._topology_id] = self - self._for_resolve_uri = { - "username": username, - "password": password, - "srv_service_name": srv_service_name, - "srv_max_hosts": srv_max_hosts, - "fqdn": fqdn, - "pool_class": pool_class, - "monitor_class": monitor_class, - "condition_class": condition_class, - } + # self._for_resolve_uri = { + # "username": username, + # "password": password, + # "srv_service_name": srv_service_name, + # "srv_max_hosts": srv_max_hosts, + # "fqdn": fqdn, + # "pool_class": pool_class, + # "monitor_class": monitor_class, + # "condition_class": condition_class, + # } def _resolve_uri(self): keyword_opts = common._CaseInsensitiveDictionary(self._init_kwargs) @@ -984,7 +990,7 @@ def _resolve_uri(self): username = opts.get("username", self._for_resolve_uri["username"]) password = opts.get("password", self._for_resolve_uri["password"]) self._options = ClientOptions( - username, password, self._default_database_name, opts, _IS_SYNC + username, password, self._for_resolve_uri["dbase"], opts, _IS_SYNC ) self._event_listeners = self._options.pool_options._event_listeners diff --git a/pymongo/uri_parser.py b/pymongo/uri_parser.py index 10d83489a4..a431dc4291 100644 --- a/pymongo/uri_parser.py +++ b/pymongo/uri_parser.py @@ -427,7 +427,6 @@ def parse_uri( validate: bool = True, warn: bool = False, normalize: bool = True, - srv_service_name: Optional[str] = None, srv_max_hosts: Optional[int] = None, ) -> dict[str, Any]: """Parse and validate a MongoDB URI. @@ -526,8 +525,6 @@ def parse_uri( if opts: options.update(split_options(opts, validate, warn, normalize)) - if srv_service_name is None: - srv_service_name = options.get("srvServiceName", SRV_SERVICE_NAME) if "@" in host_part: userinfo, _, hosts = host_part.rpartition("@") user, passwd = parse_userinfo(userinfo) diff --git a/test/asynchronous/test_dns.py b/test/asynchronous/test_dns.py index e24e0fb5ce..f160500794 100644 --- a/test/asynchronous/test_dns.py +++ b/test/asynchronous/test_dns.py @@ -33,7 +33,7 @@ from pymongo.common import validate_read_preference_tags from pymongo.errors import ConfigurationError -from pymongo.uri_parser import parse_uri, split_hosts +from pymongo.uri_parser import parse_uri, parse_uri_lookups, split_hosts _IS_SYNC = False @@ -110,6 +110,7 @@ async def run_test(self): if seeds or num_seeds: result = parse_uri(uri, validate=True) + result.update(parse_uri_lookups(uri, validate=True)) if seeds is not None: self.assertEqual(sorted(result["nodelist"]), sorted(seeds)) if num_seeds is not None: @@ -141,13 +142,16 @@ async def run_test(self): # Our test certs don't support the SRV hosts used in these # tests. copts["tlsAllowInvalidHostnames"] = True - + print(uri) + print(copts) client = self.simple_client(uri, **copts) + await client.aconnect() if client._options.connect: await client.aconnect() if num_seeds is not None: self.assertEqual(len(client._topology_settings.seeds), num_seeds) if hosts is not None: + print(client.nodes) await async_wait_until( lambda: hosts == client.nodes, "match test hosts to client nodes" ) @@ -162,6 +166,7 @@ async def run_test(self): else: try: parse_uri(uri) + parse_uri_lookups(uri) except (ConfigurationError, ValueError): pass else: @@ -185,35 +190,24 @@ def create_tests(cls): class TestParsingErrors(AsyncPyMongoTestCase): async def test_invalid_host(self): - self.assertRaisesRegex( - ConfigurationError, - "Invalid URI host: mongodb is not", - self.simple_client, - "mongodb+srv://mongodb", - ) - self.assertRaisesRegex( - ConfigurationError, - "Invalid URI host: mongodb.com is not", - self.simple_client, - "mongodb+srv://mongodb.com", - ) - self.assertRaisesRegex( - ConfigurationError, - "Invalid URI host: an IP address is not", - self.simple_client, - "mongodb+srv://127.0.0.1", - ) - self.assertRaisesRegex( - ConfigurationError, - "Invalid URI host: an IP address is not", - self.simple_client, - "mongodb+srv://[::1]", - ) + with self.assertRaisesRegex(ConfigurationError, "Invalid URI host: mongodb is not"): + client = self.simple_client("mongodb+srv://mongodb") + await client.aconnect() + with self.assertRaisesRegex(ConfigurationError, "Invalid URI host: mongodb.com is not"): + client = self.simple_client("mongodb+srv://mongodb.com") + await client.aconnect() + with self.assertRaisesRegex(ConfigurationError, "Invalid URI host: an IP address is not"): + client = self.simple_client("mongodb+srv://127.0.0.1") + await client.aconnect() + with self.assertRaisesRegex(ConfigurationError, "Invalid URI host: an IP address is not"): + client = self.simple_client("mongodb+srv://[::1]") + await client.aconnect() class IsolatedAsyncioTestCaseInsensitive(AsyncIntegrationTest): async def test_connect_case_insensitive(self): client = self.simple_client("mongodb+srv://TEST1.TEST.BUILD.10GEN.cc/") + await client.aconnect() self.assertGreater(len(client.topology_description.server_descriptions()), 1) diff --git a/test/test_dns.py b/test/test_dns.py index 6f4736fd5e..0a6408b402 100644 --- a/test/test_dns.py +++ b/test/test_dns.py @@ -33,7 +33,7 @@ from pymongo.common import validate_read_preference_tags from pymongo.errors import ConfigurationError -from pymongo.uri_parser import parse_uri, split_hosts +from pymongo.uri_parser import parse_uri, parse_uri_lookups, split_hosts _IS_SYNC = True @@ -110,6 +110,7 @@ def run_test(self): if seeds or num_seeds: result = parse_uri(uri, validate=True) + result.update(parse_uri_lookups(uri, validate=True)) if seeds is not None: self.assertEqual(sorted(result["nodelist"]), sorted(seeds)) if num_seeds is not None: @@ -141,13 +142,16 @@ def run_test(self): # Our test certs don't support the SRV hosts used in these # tests. copts["tlsAllowInvalidHostnames"] = True - + print(uri) + print(copts) client = self.simple_client(uri, **copts) + client._connect() if client._options.connect: client._connect() if num_seeds is not None: self.assertEqual(len(client._topology_settings.seeds), num_seeds) if hosts is not None: + print(client.nodes) wait_until(lambda: hosts == client.nodes, "match test hosts to client nodes") if num_hosts is not None: wait_until( @@ -160,6 +164,7 @@ def run_test(self): else: try: parse_uri(uri) + parse_uri_lookups(uri) except (ConfigurationError, ValueError): pass else: @@ -183,35 +188,24 @@ def create_tests(cls): class TestParsingErrors(PyMongoTestCase): def test_invalid_host(self): - self.assertRaisesRegex( - ConfigurationError, - "Invalid URI host: mongodb is not", - self.simple_client, - "mongodb+srv://mongodb", - ) - self.assertRaisesRegex( - ConfigurationError, - "Invalid URI host: mongodb.com is not", - self.simple_client, - "mongodb+srv://mongodb.com", - ) - self.assertRaisesRegex( - ConfigurationError, - "Invalid URI host: an IP address is not", - self.simple_client, - "mongodb+srv://127.0.0.1", - ) - self.assertRaisesRegex( - ConfigurationError, - "Invalid URI host: an IP address is not", - self.simple_client, - "mongodb+srv://[::1]", - ) + with self.assertRaisesRegex(ConfigurationError, "Invalid URI host: mongodb is not"): + client = self.simple_client("mongodb+srv://mongodb") + client._connect() + with self.assertRaisesRegex(ConfigurationError, "Invalid URI host: mongodb.com is not"): + client = self.simple_client("mongodb+srv://mongodb.com") + client._connect() + with self.assertRaisesRegex(ConfigurationError, "Invalid URI host: an IP address is not"): + client = self.simple_client("mongodb+srv://127.0.0.1") + client._connect() + with self.assertRaisesRegex(ConfigurationError, "Invalid URI host: an IP address is not"): + client = self.simple_client("mongodb+srv://[::1]") + client._connect() class TestCaseInsensitive(IntegrationTest): def test_connect_case_insensitive(self): client = self.simple_client("mongodb+srv://TEST1.TEST.BUILD.10GEN.cc/") + client._connect() self.assertGreater(len(client.topology_description.server_descriptions()), 1) From 7d771cb75377304061ebb0349137e4fe7244088d Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Mon, 10 Mar 2025 16:53:01 -0700 Subject: [PATCH 03/56] keep parse_uri as is and have it call two different functions instead --- pymongo/asynchronous/mongo_client.py | 4 ++-- pymongo/synchronous/mongo_client.py | 4 ++-- pymongo/uri_parser.py | 28 +++++++++++++++++++++++++++- test/asynchronous/test_dns.py | 8 +------- test/test_dns.py | 8 +------- 5 files changed, 33 insertions(+), 19 deletions(-) diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index bc7d1f9b55..047f39af34 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -778,7 +778,7 @@ def __init__( # it must be a URI, # https://en.wikipedia.org/wiki/Hostname#Restrictions_on_valid_host_names if "/" in entity: - res = uri_parser.parse_uri( + res = uri_parser._validate_uri( entity, port, validate=True, @@ -933,7 +933,7 @@ def _resolve_uri(self): timeout = common.validate_timeout_or_none_or_zero( keyword_opts.cased_key("connecttimeoutms"), timeout ) - res = uri_parser.parse_uri_lookups( + res = uri_parser._lookup_uri( entity, self._port, validate=True, diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index e522744171..d4565815bd 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -776,7 +776,7 @@ def __init__( # it must be a URI, # https://en.wikipedia.org/wiki/Hostname#Restrictions_on_valid_host_names if "/" in entity: - res = uri_parser.parse_uri( + res = uri_parser._validate_uri( entity, port, validate=True, @@ -931,7 +931,7 @@ def _resolve_uri(self): timeout = common.validate_timeout_or_none_or_zero( keyword_opts.cased_key("connecttimeoutms"), timeout ) - res = uri_parser.parse_uri_lookups( + res = uri_parser._lookup_uri( entity, self._port, validate=True, diff --git a/pymongo/uri_parser.py b/pymongo/uri_parser.py index a431dc4291..b6f105a6f3 100644 --- a/pymongo/uri_parser.py +++ b/pymongo/uri_parser.py @@ -427,6 +427,8 @@ def parse_uri( validate: bool = True, warn: bool = False, normalize: bool = True, + connect_timeout: Optional[float] = None, + srv_service_name: Optional[str] = None, srv_max_hosts: Optional[int] = None, ) -> dict[str, Any]: """Parse and validate a MongoDB URI. @@ -482,6 +484,30 @@ def parse_uri( .. versionchanged:: 3.1 ``warn`` added so invalid options can be ignored. """ + result = _validate_uri(uri, default_port, validate, warn, normalize, srv_max_hosts) + result.update( + _lookup_uri( + uri, + default_port, + validate, + warn, + normalize, + connect_timeout, + srv_service_name, + srv_max_hosts, + ) + ) + return result + + +def _validate_uri( + uri: str, + default_port: Optional[int] = DEFAULT_PORT, + validate: bool = True, + warn: bool = False, + normalize: bool = True, + srv_max_hosts: Optional[int] = None, +): if uri.startswith(SCHEME): is_srv = False scheme_free = uri[SCHEME_LEN:] @@ -570,7 +596,7 @@ def parse_uri( } -def parse_uri_lookups( +def _lookup_uri( uri: str, default_port: Optional[int] = DEFAULT_PORT, validate: bool = True, diff --git a/test/asynchronous/test_dns.py b/test/asynchronous/test_dns.py index f160500794..65e4454a1e 100644 --- a/test/asynchronous/test_dns.py +++ b/test/asynchronous/test_dns.py @@ -33,7 +33,7 @@ from pymongo.common import validate_read_preference_tags from pymongo.errors import ConfigurationError -from pymongo.uri_parser import parse_uri, parse_uri_lookups, split_hosts +from pymongo.uri_parser import parse_uri, split_hosts _IS_SYNC = False @@ -110,7 +110,6 @@ async def run_test(self): if seeds or num_seeds: result = parse_uri(uri, validate=True) - result.update(parse_uri_lookups(uri, validate=True)) if seeds is not None: self.assertEqual(sorted(result["nodelist"]), sorted(seeds)) if num_seeds is not None: @@ -142,16 +141,12 @@ async def run_test(self): # Our test certs don't support the SRV hosts used in these # tests. copts["tlsAllowInvalidHostnames"] = True - print(uri) - print(copts) client = self.simple_client(uri, **copts) - await client.aconnect() if client._options.connect: await client.aconnect() if num_seeds is not None: self.assertEqual(len(client._topology_settings.seeds), num_seeds) if hosts is not None: - print(client.nodes) await async_wait_until( lambda: hosts == client.nodes, "match test hosts to client nodes" ) @@ -166,7 +161,6 @@ async def run_test(self): else: try: parse_uri(uri) - parse_uri_lookups(uri) except (ConfigurationError, ValueError): pass else: diff --git a/test/test_dns.py b/test/test_dns.py index 0a6408b402..8bb8dd5e07 100644 --- a/test/test_dns.py +++ b/test/test_dns.py @@ -33,7 +33,7 @@ from pymongo.common import validate_read_preference_tags from pymongo.errors import ConfigurationError -from pymongo.uri_parser import parse_uri, parse_uri_lookups, split_hosts +from pymongo.uri_parser import parse_uri, split_hosts _IS_SYNC = True @@ -110,7 +110,6 @@ def run_test(self): if seeds or num_seeds: result = parse_uri(uri, validate=True) - result.update(parse_uri_lookups(uri, validate=True)) if seeds is not None: self.assertEqual(sorted(result["nodelist"]), sorted(seeds)) if num_seeds is not None: @@ -142,16 +141,12 @@ def run_test(self): # Our test certs don't support the SRV hosts used in these # tests. copts["tlsAllowInvalidHostnames"] = True - print(uri) - print(copts) client = self.simple_client(uri, **copts) - client._connect() if client._options.connect: client._connect() if num_seeds is not None: self.assertEqual(len(client._topology_settings.seeds), num_seeds) if hosts is not None: - print(client.nodes) wait_until(lambda: hosts == client.nodes, "match test hosts to client nodes") if num_hosts is not None: wait_until( @@ -164,7 +159,6 @@ def run_test(self): else: try: parse_uri(uri) - parse_uri_lookups(uri) except (ConfigurationError, ValueError): pass else: From 0f64689a78a0e178f621d673b760db26d85fb77e Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Mon, 10 Mar 2025 16:55:34 -0700 Subject: [PATCH 04/56] cleanup --- pymongo/asynchronous/mongo_client.py | 11 ----------- pymongo/synchronous/mongo_client.py | 11 ----------- 2 files changed, 22 deletions(-) diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index 047f39af34..1a9e7d3514 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -896,17 +896,6 @@ def __init__( # This will be used later if we fork. AsyncMongoClient._clients[self._topology._topology_id] = self - # self._for_resolve_uri = { - # "username": username, - # "password": password, - # "srv_service_name": srv_service_name, - # "srv_max_hosts": srv_max_hosts, - # "fqdn": fqdn, - # "pool_class": pool_class, - # "monitor_class": monitor_class, - # "condition_class": condition_class, - # } - def _resolve_uri(self): keyword_opts = common._CaseInsensitiveDictionary(self._init_kwargs) for i in [ diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index d4565815bd..5f58354fcf 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -894,17 +894,6 @@ def __init__( # This will be used later if we fork. MongoClient._clients[self._topology._topology_id] = self - # self._for_resolve_uri = { - # "username": username, - # "password": password, - # "srv_service_name": srv_service_name, - # "srv_max_hosts": srv_max_hosts, - # "fqdn": fqdn, - # "pool_class": pool_class, - # "monitor_class": monitor_class, - # "condition_class": condition_class, - # } - def _resolve_uri(self): keyword_opts = common._CaseInsensitiveDictionary(self._init_kwargs) for i in [ From ed50141524760b916b1e84978d2a37b9aa27f4ff Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Mon, 10 Mar 2025 17:33:13 -0700 Subject: [PATCH 05/56] some refactoring to reduce code duplication --- pymongo/asynchronous/mongo_client.py | 105 +++++++++++++-------------- pymongo/synchronous/mongo_client.py | 105 +++++++++++++-------------- 2 files changed, 102 insertions(+), 108 deletions(-) diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index 1a9e7d3514..8b77575df1 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -762,6 +762,7 @@ def __init__( # Parse options passed as kwargs. keyword_opts = common._CaseInsensitiveDictionary(kwargs) keyword_opts["document_class"] = doc_class + self._resolve_uri_info = {"keyword_opts": keyword_opts} seeds = set() username = None @@ -814,25 +815,13 @@ def __init__( keyword_opts["tz_aware"] = tz_aware keyword_opts["connect"] = connect - # Handle deprecated options in kwarg options. - keyword_opts = _handle_option_deprecations(keyword_opts) - # Validate kwarg options. - keyword_opts = common._CaseInsensitiveDictionary( - dict(common.validate(keyword_opts.cased_key(k), v) for k, v in keyword_opts.items()) - ) - - # Override connection string options with kwarg options. - opts.update(keyword_opts) + opts = self._validate_kwargs_and_update_opts(keyword_opts, opts) if srv_service_name is None: srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME) srv_max_hosts = srv_max_hosts or opts.get("srvmaxhosts") - # Handle security-option conflicts in combined options. - opts = _handle_security_options(opts) - # Normalize combined options. - opts = _normalize_options(opts) - _check_options(seeds, opts) + opts = self._normalize_and_validate_options(opts, seeds) # Username and password passed as kwargs override user info in URI. self._options = options = ClientOptions(username, password, dbase, opts, _IS_SYNC) @@ -872,15 +861,17 @@ def __init__( self._closed = False self._init_background() - self._for_resolve_uri = { - "username": username, - "password": password, - "dbase": dbase, - "fqdn": fqdn, - "pool_class": pool_class, - "monitor_class": monitor_class, - "condition_class": condition_class, - } + self._resolve_uri_info.update( + { + "username": username, + "password": password, + "dbase": dbase, + "fqdn": fqdn, + "pool_class": pool_class, + "monitor_class": monitor_class, + "condition_class": condition_class, + } + ) if _IS_SYNC and connect: self._get_topology() # type: ignore[unused-coroutine] @@ -896,17 +887,16 @@ def __init__( # This will be used later if we fork. AsyncMongoClient._clients[self._topology._topology_id] = self + def _normalize_and_validate_options(self, opts, seeds): + # Handle security-option conflicts in combined options. + opts = _handle_security_options(opts) + # Normalize combined options. + opts = _normalize_options(opts) + _check_options(seeds, opts) + return opts + def _resolve_uri(self): - keyword_opts = common._CaseInsensitiveDictionary(self._init_kwargs) - for i in [ - "_pool_class", - "_monitor_class", - "_condition_class", - "host", - "port", - "type_registry", - ]: - keyword_opts.pop(i, None) + keyword_opts = self._resolve_uri_info["keyword_opts"] seeds = set() opts = common._CaseInsensitiveDictionary() srv_service_name = keyword_opts.get("srvservicename") @@ -957,31 +947,19 @@ def _resolve_uri(self): keyword_opts["tz_aware"] = tz_aware keyword_opts["connect"] = connect - # Handle deprecated options in kwarg options. - keyword_opts = _handle_option_deprecations(keyword_opts) - # Validate kwarg options. - keyword_opts = common._CaseInsensitiveDictionary( - dict(common.validate(keyword_opts.cased_key(k), v) for k, v in keyword_opts.items()) - ) - - # Override connection string options with kwarg options. - opts.update(keyword_opts) + opts = self._validate_kwargs_and_update_opts(keyword_opts, opts) if srv_service_name is None: srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME) srv_max_hosts = srv_max_hosts or opts.get("srvmaxhosts") - # Handle security-option conflicts in combined options. - opts = _handle_security_options(opts) - # Normalize combined options. - opts = _normalize_options(opts) - _check_options(seeds, opts) + opts = self._normalize_and_validate_opts(opts, seeds) # Username and password passed as kwargs override user info in URI. - username = opts.get("username", self._for_resolve_uri["username"]) - password = opts.get("password", self._for_resolve_uri["password"]) + username = opts.get("username", self._resolve_uri_info["username"]) + password = opts.get("password", self._resolve_uri_info["password"]) self._options = ClientOptions( - username, password, self._for_resolve_uri["dbase"], opts, _IS_SYNC + username, password, self._resolve_uri_info["dbase"], opts, _IS_SYNC ) self._event_listeners = self._options.pool_options._event_listeners @@ -995,15 +973,15 @@ def _resolve_uri(self): self._topology_settings = TopologySettings( seeds=seeds, replica_set_name=self._options.replica_set_name, - pool_class=self._for_resolve_uri["pool_class"], + pool_class=self._resolve_uri_info["pool_class"], pool_options=self._options.pool_options, - monitor_class=self._for_resolve_uri["monitor_class"], - condition_class=self._for_resolve_uri["condition_class"], + monitor_class=self._resolve_uri_info["monitor_class"], + condition_class=self._resolve_uri_info["condition_class"], local_threshold_ms=self._options.local_threshold_ms, server_selection_timeout=self._options.server_selection_timeout, server_selector=self._options.server_selector, heartbeat_frequency=self._options.heartbeat_frequency, - fqdn=self._for_resolve_uri["fqdn"], + fqdn=self._resolve_uri_info["fqdn"], direct_connection=self._options.direct_connection, load_balanced=self._options.load_balanced, srv_service_name=srv_service_name, @@ -1013,6 +991,25 @@ def _resolve_uri(self): self._topology = Topology(self._topology_settings) + def _normalize_and_validate_opts(self, opts, seeds): + # Handle security-option conflicts in combined options. + opts = _handle_security_options(opts) + # Normalize combined options. + opts = _normalize_options(opts) + _check_options(seeds, opts) + return opts + + def _validate_kwargs_and_update_opts(self, keyword_opts, opts): + # Handle deprecated options in kwarg options. + keyword_opts = _handle_option_deprecations(keyword_opts) + # Validate kwarg options. + keyword_opts = common._CaseInsensitiveDictionary( + dict(common.validate(keyword_opts.cased_key(k), v) for k, v in keyword_opts.items()) + ) + # Override connection string options with kwarg options. + opts.update(keyword_opts) + return opts + async def aconnect(self) -> None: """Explicitly connect to MongoDB asynchronously instead of on the first operation.""" await self._get_topology() diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 5f58354fcf..4615277aeb 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -760,6 +760,7 @@ def __init__( # Parse options passed as kwargs. keyword_opts = common._CaseInsensitiveDictionary(kwargs) keyword_opts["document_class"] = doc_class + self._resolve_uri_info = {"keyword_opts": keyword_opts} seeds = set() username = None @@ -812,25 +813,13 @@ def __init__( keyword_opts["tz_aware"] = tz_aware keyword_opts["connect"] = connect - # Handle deprecated options in kwarg options. - keyword_opts = _handle_option_deprecations(keyword_opts) - # Validate kwarg options. - keyword_opts = common._CaseInsensitiveDictionary( - dict(common.validate(keyword_opts.cased_key(k), v) for k, v in keyword_opts.items()) - ) - - # Override connection string options with kwarg options. - opts.update(keyword_opts) + opts = self._validate_kwargs_and_update_opts(keyword_opts, opts) if srv_service_name is None: srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME) srv_max_hosts = srv_max_hosts or opts.get("srvmaxhosts") - # Handle security-option conflicts in combined options. - opts = _handle_security_options(opts) - # Normalize combined options. - opts = _normalize_options(opts) - _check_options(seeds, opts) + opts = self._normalize_and_validate_options(opts, seeds) # Username and password passed as kwargs override user info in URI. self._options = options = ClientOptions(username, password, dbase, opts, _IS_SYNC) @@ -870,15 +859,17 @@ def __init__( self._closed = False self._init_background() - self._for_resolve_uri = { - "username": username, - "password": password, - "dbase": dbase, - "fqdn": fqdn, - "pool_class": pool_class, - "monitor_class": monitor_class, - "condition_class": condition_class, - } + self._resolve_uri_info.update( + { + "username": username, + "password": password, + "dbase": dbase, + "fqdn": fqdn, + "pool_class": pool_class, + "monitor_class": monitor_class, + "condition_class": condition_class, + } + ) if _IS_SYNC and connect: self._get_topology() # type: ignore[unused-coroutine] @@ -894,17 +885,16 @@ def __init__( # This will be used later if we fork. MongoClient._clients[self._topology._topology_id] = self + def _normalize_and_validate_options(self, opts, seeds): + # Handle security-option conflicts in combined options. + opts = _handle_security_options(opts) + # Normalize combined options. + opts = _normalize_options(opts) + _check_options(seeds, opts) + return opts + def _resolve_uri(self): - keyword_opts = common._CaseInsensitiveDictionary(self._init_kwargs) - for i in [ - "_pool_class", - "_monitor_class", - "_condition_class", - "host", - "port", - "type_registry", - ]: - keyword_opts.pop(i, None) + keyword_opts = self._resolve_uri_info["keyword_opts"] seeds = set() opts = common._CaseInsensitiveDictionary() srv_service_name = keyword_opts.get("srvservicename") @@ -955,31 +945,19 @@ def _resolve_uri(self): keyword_opts["tz_aware"] = tz_aware keyword_opts["connect"] = connect - # Handle deprecated options in kwarg options. - keyword_opts = _handle_option_deprecations(keyword_opts) - # Validate kwarg options. - keyword_opts = common._CaseInsensitiveDictionary( - dict(common.validate(keyword_opts.cased_key(k), v) for k, v in keyword_opts.items()) - ) - - # Override connection string options with kwarg options. - opts.update(keyword_opts) + opts = self._validate_kwargs_and_update_opts(keyword_opts, opts) if srv_service_name is None: srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME) srv_max_hosts = srv_max_hosts or opts.get("srvmaxhosts") - # Handle security-option conflicts in combined options. - opts = _handle_security_options(opts) - # Normalize combined options. - opts = _normalize_options(opts) - _check_options(seeds, opts) + opts = self._normalize_and_validate_opts(opts, seeds) # Username and password passed as kwargs override user info in URI. - username = opts.get("username", self._for_resolve_uri["username"]) - password = opts.get("password", self._for_resolve_uri["password"]) + username = opts.get("username", self._resolve_uri_info["username"]) + password = opts.get("password", self._resolve_uri_info["password"]) self._options = ClientOptions( - username, password, self._for_resolve_uri["dbase"], opts, _IS_SYNC + username, password, self._resolve_uri_info["dbase"], opts, _IS_SYNC ) self._event_listeners = self._options.pool_options._event_listeners @@ -993,15 +971,15 @@ def _resolve_uri(self): self._topology_settings = TopologySettings( seeds=seeds, replica_set_name=self._options.replica_set_name, - pool_class=self._for_resolve_uri["pool_class"], + pool_class=self._resolve_uri_info["pool_class"], pool_options=self._options.pool_options, - monitor_class=self._for_resolve_uri["monitor_class"], - condition_class=self._for_resolve_uri["condition_class"], + monitor_class=self._resolve_uri_info["monitor_class"], + condition_class=self._resolve_uri_info["condition_class"], local_threshold_ms=self._options.local_threshold_ms, server_selection_timeout=self._options.server_selection_timeout, server_selector=self._options.server_selector, heartbeat_frequency=self._options.heartbeat_frequency, - fqdn=self._for_resolve_uri["fqdn"], + fqdn=self._resolve_uri_info["fqdn"], direct_connection=self._options.direct_connection, load_balanced=self._options.load_balanced, srv_service_name=srv_service_name, @@ -1011,6 +989,25 @@ def _resolve_uri(self): self._topology = Topology(self._topology_settings) + def _normalize_and_validate_opts(self, opts, seeds): + # Handle security-option conflicts in combined options. + opts = _handle_security_options(opts) + # Normalize combined options. + opts = _normalize_options(opts) + _check_options(seeds, opts) + return opts + + def _validate_kwargs_and_update_opts(self, keyword_opts, opts): + # Handle deprecated options in kwarg options. + keyword_opts = _handle_option_deprecations(keyword_opts) + # Validate kwarg options. + keyword_opts = common._CaseInsensitiveDictionary( + dict(common.validate(keyword_opts.cased_key(k), v) for k, v in keyword_opts.items()) + ) + # Override connection string options with kwarg options. + opts.update(keyword_opts) + return opts + def _connect(self) -> None: """Explicitly connect to MongoDB synchronously instead of on the first operation.""" self._get_topology() From ed25867689c7dea0cf77230bec7fe0e9dd8374c4 Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Mon, 10 Mar 2025 18:28:39 -0700 Subject: [PATCH 06/56] fix typing --- pymongo/asynchronous/mongo_client.py | 42 ++++++++++++++-------------- pymongo/synchronous/mongo_client.py | 42 ++++++++++++++-------------- pymongo/uri_parser.py | 2 +- 3 files changed, 43 insertions(+), 43 deletions(-) diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index 8b77575df1..ff4de399b8 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -742,16 +742,16 @@ def __init__( **kwargs, } + if host is None: + host = self.HOST + if isinstance(host, str): + host = [host] + if port is None: + port = self.PORT + if not isinstance(port, int): + raise TypeError(f"port must be an instance of int, not {type(port)}") self._host = host self._port = port - if self._host is None: - self._host = self.HOST - if isinstance(self._host, str): - self._host = [self._host] - if self._port is None: - self._port = self.PORT - if not isinstance(self._port, int): - raise TypeError(f"port must be an instance of int, not {type(port)}") # _pool_class, _monitor_class, and _condition_class are for deep # customization of PyMongo, e.g. Motor. @@ -762,7 +762,7 @@ def __init__( # Parse options passed as kwargs. keyword_opts = common._CaseInsensitiveDictionary(kwargs) keyword_opts["document_class"] = doc_class - self._resolve_uri_info = {"keyword_opts": keyword_opts} + self._resolve_uri_info: dict[str, Any] = {"keyword_opts": keyword_opts} seeds = set() username = None @@ -824,6 +824,8 @@ def __init__( opts = self._normalize_and_validate_options(opts, seeds) # Username and password passed as kwargs override user info in URI. + username = opts.get("username", username) + password = opts.get("password", password) self._options = options = ClientOptions(username, password, dbase, opts, _IS_SYNC) self._default_database_name = dbase @@ -887,15 +889,7 @@ def __init__( # This will be used later if we fork. AsyncMongoClient._clients[self._topology._topology_id] = self - def _normalize_and_validate_options(self, opts, seeds): - # Handle security-option conflicts in combined options. - opts = _handle_security_options(opts) - # Normalize combined options. - opts = _normalize_options(opts) - _check_options(seeds, opts) - return opts - - def _resolve_uri(self): + def _resolve_uri(self) -> None: keyword_opts = self._resolve_uri_info["keyword_opts"] seeds = set() opts = common._CaseInsensitiveDictionary() @@ -953,7 +947,7 @@ def _resolve_uri(self): srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME) srv_max_hosts = srv_max_hosts or opts.get("srvmaxhosts") - opts = self._normalize_and_validate_opts(opts, seeds) + opts = self._normalize_and_validate_options(opts, seeds) # Username and password passed as kwargs override user info in URI. username = opts.get("username", self._resolve_uri_info["username"]) @@ -991,7 +985,9 @@ def _resolve_uri(self): self._topology = Topology(self._topology_settings) - def _normalize_and_validate_opts(self, opts, seeds): + def _normalize_and_validate_options( + self, opts: common._CaseInsensitiveDictionary, seeds: set[tuple[str, int | None]] + ) -> common._CaseInsensitiveDictionary: # Handle security-option conflicts in combined options. opts = _handle_security_options(opts) # Normalize combined options. @@ -999,7 +995,11 @@ def _normalize_and_validate_opts(self, opts, seeds): _check_options(seeds, opts) return opts - def _validate_kwargs_and_update_opts(self, keyword_opts, opts): + def _validate_kwargs_and_update_opts( + self, + keyword_opts: common._CaseInsensitiveDictionary, + opts: common._CaseInsensitiveDictionary, + ) -> common._CaseInsensitiveDictionary: # Handle deprecated options in kwarg options. keyword_opts = _handle_option_deprecations(keyword_opts) # Validate kwarg options. diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 4615277aeb..c8b05445c9 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -740,16 +740,16 @@ def __init__( **kwargs, } + if host is None: + host = self.HOST + if isinstance(host, str): + host = [host] + if port is None: + port = self.PORT + if not isinstance(port, int): + raise TypeError(f"port must be an instance of int, not {type(port)}") self._host = host self._port = port - if self._host is None: - self._host = self.HOST - if isinstance(self._host, str): - self._host = [self._host] - if self._port is None: - self._port = self.PORT - if not isinstance(self._port, int): - raise TypeError(f"port must be an instance of int, not {type(port)}") # _pool_class, _monitor_class, and _condition_class are for deep # customization of PyMongo, e.g. Motor. @@ -760,7 +760,7 @@ def __init__( # Parse options passed as kwargs. keyword_opts = common._CaseInsensitiveDictionary(kwargs) keyword_opts["document_class"] = doc_class - self._resolve_uri_info = {"keyword_opts": keyword_opts} + self._resolve_uri_info: dict[str, Any] = {"keyword_opts": keyword_opts} seeds = set() username = None @@ -822,6 +822,8 @@ def __init__( opts = self._normalize_and_validate_options(opts, seeds) # Username and password passed as kwargs override user info in URI. + username = opts.get("username", username) + password = opts.get("password", password) self._options = options = ClientOptions(username, password, dbase, opts, _IS_SYNC) self._default_database_name = dbase @@ -885,15 +887,7 @@ def __init__( # This will be used later if we fork. MongoClient._clients[self._topology._topology_id] = self - def _normalize_and_validate_options(self, opts, seeds): - # Handle security-option conflicts in combined options. - opts = _handle_security_options(opts) - # Normalize combined options. - opts = _normalize_options(opts) - _check_options(seeds, opts) - return opts - - def _resolve_uri(self): + def _resolve_uri(self) -> None: keyword_opts = self._resolve_uri_info["keyword_opts"] seeds = set() opts = common._CaseInsensitiveDictionary() @@ -951,7 +945,7 @@ def _resolve_uri(self): srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME) srv_max_hosts = srv_max_hosts or opts.get("srvmaxhosts") - opts = self._normalize_and_validate_opts(opts, seeds) + opts = self._normalize_and_validate_options(opts, seeds) # Username and password passed as kwargs override user info in URI. username = opts.get("username", self._resolve_uri_info["username"]) @@ -989,7 +983,9 @@ def _resolve_uri(self): self._topology = Topology(self._topology_settings) - def _normalize_and_validate_opts(self, opts, seeds): + def _normalize_and_validate_options( + self, opts: common._CaseInsensitiveDictionary, seeds: set[tuple[str, int | None]] + ) -> common._CaseInsensitiveDictionary: # Handle security-option conflicts in combined options. opts = _handle_security_options(opts) # Normalize combined options. @@ -997,7 +993,11 @@ def _normalize_and_validate_opts(self, opts, seeds): _check_options(seeds, opts) return opts - def _validate_kwargs_and_update_opts(self, keyword_opts, opts): + def _validate_kwargs_and_update_opts( + self, + keyword_opts: common._CaseInsensitiveDictionary, + opts: common._CaseInsensitiveDictionary, + ) -> common._CaseInsensitiveDictionary: # Handle deprecated options in kwarg options. keyword_opts = _handle_option_deprecations(keyword_opts) # Validate kwarg options. diff --git a/pymongo/uri_parser.py b/pymongo/uri_parser.py index b6f105a6f3..5f109c663c 100644 --- a/pymongo/uri_parser.py +++ b/pymongo/uri_parser.py @@ -507,7 +507,7 @@ def _validate_uri( warn: bool = False, normalize: bool = True, srv_max_hosts: Optional[int] = None, -): +) -> dict[str, Any]: if uri.startswith(SCHEME): is_srv = False scheme_free = uri[SCHEME_LEN:] From 8d48f44e1774ce3a69d66e60f169bedce9551c92 Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Mon, 10 Mar 2025 18:44:51 -0700 Subject: [PATCH 07/56] remove copied doc string --- pymongo/uri_parser.py | 53 ------------------------------------------- 1 file changed, 53 deletions(-) diff --git a/pymongo/uri_parser.py b/pymongo/uri_parser.py index 5f109c663c..34f7d8dca1 100644 --- a/pymongo/uri_parser.py +++ b/pymongo/uri_parser.py @@ -606,59 +606,6 @@ def _lookup_uri( srv_service_name: Optional[str] = None, srv_max_hosts: Optional[int] = None, ) -> dict[str, Any]: - """Parse and validate a MongoDB URI. - - Returns a dict of the form:: - - { - 'nodelist': , - 'username': or None, - 'password': or None, - 'database': or None, - 'collection': or None, - 'options': , - 'fqdn': or None - } - - If the URI scheme is "mongodb+srv://" DNS SRV and TXT lookups will be done - to build nodelist and options. - - :param uri: The MongoDB URI to parse. - :param default_port: The port number to use when one wasn't specified - for a host in the URI. - :param validate: If ``True`` (the default), validate and - normalize all options. Default: ``True``. - :param warn: When validating, if ``True`` then will warn - the user then ignore any invalid options or values. If ``False``, - validation will error when options are unsupported or values are - invalid. Default: ``False``. - :param normalize: If ``True``, convert names of URI options - to their internally-used names. Default: ``True``. - :param connect_timeout: The maximum time in milliseconds to - wait for a response from the DNS server. - :param srv_service_name: A custom SRV service name - - .. versionchanged:: 4.6 - The delimiting slash (``/``) between hosts and connection options is now optional. - For example, "mongodb://example.com?tls=true" is now a valid URI. - - .. versionchanged:: 4.0 - To better follow RFC 3986, unquoted percent signs ("%") are no longer - supported. - - .. versionchanged:: 3.9 - Added the ``normalize`` parameter. - - .. versionchanged:: 3.6 - Added support for mongodb+srv:// URIs. - - .. versionchanged:: 3.5 - Return the original value of the ``readPreference`` MongoDB URI option - instead of the validated read preference mode. - - .. versionchanged:: 3.1 - ``warn`` added so invalid options can be ignored. - """ if uri.startswith(SCHEME): is_srv = False scheme_free = uri[SCHEME_LEN:] From 1a3efed957958897f7f159aa3100c83d5a3c8882 Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Tue, 11 Mar 2025 08:52:50 -0700 Subject: [PATCH 08/56] move init_background to only be called upon client connection --- pymongo/asynchronous/mongo_client.py | 51 +++++++++++++++------------- pymongo/synchronous/mongo_client.py | 51 +++++++++++++++------------- test/asynchronous/test_client.py | 8 ++--- test/test_client.py | 8 ++--- 4 files changed, 60 insertions(+), 58 deletions(-) diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index ff4de399b8..0aebd40199 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -859,9 +859,10 @@ def __init__( server_monitoring_mode=options.server_monitoring_mode, ) + self._topology = Topology(self._topology_settings) + self._opened = False self._closed = False - self._init_background() self._resolve_uri_info.update( { @@ -878,17 +879,8 @@ def __init__( self._get_topology() # type: ignore[unused-coroutine] self._encrypter = None - if self._options.auto_encryption_opts: - from pymongo.asynchronous.encryption import _Encrypter - - self._encrypter = _Encrypter(self, self._options.auto_encryption_opts) self._timeout = self._options.timeout - if _HAS_REGISTER_AT_FORK: - # Add this client to the list of weakly referenced items. - # This will be used later if we fork. - AsyncMongoClient._clients[self._topology._topology_id] = self - def _resolve_uri(self) -> None: keyword_opts = self._resolve_uri_info["keyword_opts"] seeds = set() @@ -985,6 +977,17 @@ def _resolve_uri(self) -> None: self._topology = Topology(self._topology_settings) + if self._options.auto_encryption_opts: + from pymongo.asynchronous.encryption import _Encrypter + + self._encrypter = _Encrypter(self, self._options.auto_encryption_opts) + self._timeout = self._options.timeout + + if _HAS_REGISTER_AT_FORK: + # Add this client to the list of weakly referenced items. + # This will be used later if we fork. + AsyncMongoClient._clients[self._topology._topology_id] = self + def _normalize_and_validate_options( self, opts: common._CaseInsensitiveDictionary, seeds: set[tuple[str, int | None]] ) -> common._CaseInsensitiveDictionary: @@ -1672,19 +1675,20 @@ async def close(self) -> None: await self._end_sessions(session_ids) # Stop the periodic task thread and then send pending killCursor # requests before closing the topology. - self._kill_cursors_executor.close() - await self._process_kill_cursors() - await self._topology.close() - if self._encrypter: - # TODO: PYTHON-1921 Encrypted MongoClients cannot be re-opened. - await self._encrypter.close() - self._closed = True - if not _IS_SYNC: - await asyncio.gather( - self._topology.cleanup_monitors(), # type: ignore[func-returns-value] - self._kill_cursors_executor.join(), # type: ignore[func-returns-value] - return_exceptions=True, - ) + if self._opened: + self._kill_cursors_executor.close() + await self._process_kill_cursors() + await self._topology.close() + if self._encrypter: + # TODO: PYTHON-1921 Encrypted MongoClients cannot be re-opened. + await self._encrypter.close() + self._closed = True + if not _IS_SYNC: + await asyncio.gather( + self._topology.cleanup_monitors(), # type: ignore[func-returns-value] + self._kill_cursors_executor.join(), # type: ignore[func-returns-value] + return_exceptions=True, + ) if not _IS_SYNC: # Add support for contextlib.aclosing. @@ -1698,6 +1702,7 @@ async def _get_topology(self) -> Topology: """ if not self._opened: self._resolve_uri() + self._init_background() await self._topology.open() async with self._lock: self._kill_cursors_executor.open() diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index c8b05445c9..0ec78a3acd 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -857,9 +857,10 @@ def __init__( server_monitoring_mode=options.server_monitoring_mode, ) + self._topology = Topology(self._topology_settings) + self._opened = False self._closed = False - self._init_background() self._resolve_uri_info.update( { @@ -876,17 +877,8 @@ def __init__( self._get_topology() # type: ignore[unused-coroutine] self._encrypter = None - if self._options.auto_encryption_opts: - from pymongo.synchronous.encryption import _Encrypter - - self._encrypter = _Encrypter(self, self._options.auto_encryption_opts) self._timeout = self._options.timeout - if _HAS_REGISTER_AT_FORK: - # Add this client to the list of weakly referenced items. - # This will be used later if we fork. - MongoClient._clients[self._topology._topology_id] = self - def _resolve_uri(self) -> None: keyword_opts = self._resolve_uri_info["keyword_opts"] seeds = set() @@ -983,6 +975,17 @@ def _resolve_uri(self) -> None: self._topology = Topology(self._topology_settings) + if self._options.auto_encryption_opts: + from pymongo.synchronous.encryption import _Encrypter + + self._encrypter = _Encrypter(self, self._options.auto_encryption_opts) + self._timeout = self._options.timeout + + if _HAS_REGISTER_AT_FORK: + # Add this client to the list of weakly referenced items. + # This will be used later if we fork. + MongoClient._clients[self._topology._topology_id] = self + def _normalize_and_validate_options( self, opts: common._CaseInsensitiveDictionary, seeds: set[tuple[str, int | None]] ) -> common._CaseInsensitiveDictionary: @@ -1666,19 +1669,20 @@ def close(self) -> None: self._end_sessions(session_ids) # Stop the periodic task thread and then send pending killCursor # requests before closing the topology. - self._kill_cursors_executor.close() - self._process_kill_cursors() - self._topology.close() - if self._encrypter: - # TODO: PYTHON-1921 Encrypted MongoClients cannot be re-opened. - self._encrypter.close() - self._closed = True - if not _IS_SYNC: - asyncio.gather( - self._topology.cleanup_monitors(), # type: ignore[func-returns-value] - self._kill_cursors_executor.join(), # type: ignore[func-returns-value] - return_exceptions=True, - ) + if self._opened: + self._kill_cursors_executor.close() + self._process_kill_cursors() + self._topology.close() + if self._encrypter: + # TODO: PYTHON-1921 Encrypted MongoClients cannot be re-opened. + self._encrypter.close() + self._closed = True + if not _IS_SYNC: + asyncio.gather( + self._topology.cleanup_monitors(), # type: ignore[func-returns-value] + self._kill_cursors_executor.join(), # type: ignore[func-returns-value] + return_exceptions=True, + ) if not _IS_SYNC: # Add support for contextlib.closing. @@ -1692,6 +1696,7 @@ def _get_topology(self) -> Topology: """ if not self._opened: self._resolve_uri() + self._init_background() self._topology.open() with self._lock: self._kill_cursors_executor.open() diff --git a/test/asynchronous/test_client.py b/test/asynchronous/test_client.py index 5777868288..fed8ffc531 100644 --- a/test/asynchronous/test_client.py +++ b/test/asynchronous/test_client.py @@ -1058,12 +1058,8 @@ async def test_uri_connect_option(self): self.assertFalse(client._topology._opened) # Ensure kill cursors thread has not been started. - if _IS_SYNC: - kc_thread = client._kill_cursors_executor._thread - self.assertFalse(kc_thread and kc_thread.is_alive()) - else: - kc_task = client._kill_cursors_executor._task - self.assertFalse(kc_task and not kc_task.done()) + # _kill_cursors_executor is initialized upon client connection + self.assertFalse(hasattr(client, "_kill_cursors_executor")) # Using the client should open topology and start the thread. await client.admin.command("ping") self.assertTrue(client._topology._opened) diff --git a/test/test_client.py b/test/test_client.py index 6d5759822a..da1c0fe584 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -1031,12 +1031,8 @@ def test_uri_connect_option(self): self.assertFalse(client._topology._opened) # Ensure kill cursors thread has not been started. - if _IS_SYNC: - kc_thread = client._kill_cursors_executor._thread - self.assertFalse(kc_thread and kc_thread.is_alive()) - else: - kc_task = client._kill_cursors_executor._task - self.assertFalse(kc_task and not kc_task.done()) + # _kill_cursors_executor is initialized upon client connection + self.assertFalse(hasattr(client, "_kill_cursors_executor")) # Using the client should open topology and start the thread. client.admin.command("ping") self.assertTrue(client._topology._opened) From d94743be7eb2a225d60541094ab3ff588fd8db20 Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Tue, 11 Mar 2025 10:53:37 -0700 Subject: [PATCH 09/56] only define topology after uri resolution --- pymongo/asynchronous/mongo_client.py | 68 ++++++++++++++++------------ pymongo/synchronous/mongo_client.py | 68 ++++++++++++++++------------ test/asynchronous/test_client.py | 55 ++++++++++++---------- test/asynchronous/unified_format.py | 6 ++- test/test_client.py | 53 ++++++++++++---------- test/unified_format.py | 6 ++- 6 files changed, 147 insertions(+), 109 deletions(-) diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index 0aebd40199..cc25f79537 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -128,6 +128,7 @@ from pymongo.asynchronous.bulk import _AsyncBulk from pymongo.asynchronous.client_session import AsyncClientSession, _ServerSession from pymongo.asynchronous.cursor import _ConnectionManager + from pymongo.asynchronous.encryption import _Encrypter from pymongo.asynchronous.pool import AsyncConnection from pymongo.asynchronous.server import Server from pymongo.read_concern import ReadConcern @@ -840,26 +841,26 @@ def __init__( options.read_concern, ) - self._topology_settings = TopologySettings( - seeds=seeds, - replica_set_name=options.replica_set_name, - pool_class=pool_class, - pool_options=options.pool_options, - monitor_class=monitor_class, - condition_class=condition_class, - local_threshold_ms=options.local_threshold_ms, - server_selection_timeout=options.server_selection_timeout, - server_selector=options.server_selector, - heartbeat_frequency=options.heartbeat_frequency, - fqdn=fqdn, - direct_connection=options.direct_connection, - load_balanced=options.load_balanced, - srv_service_name=srv_service_name, - srv_max_hosts=srv_max_hosts, - server_monitoring_mode=options.server_monitoring_mode, - ) - - self._topology = Topology(self._topology_settings) + # self._topology_settings = TopologySettings( + # seeds=seeds, + # replica_set_name=options.replica_set_name, + # pool_class=pool_class, + # pool_options=options.pool_options, + # monitor_class=monitor_class, + # condition_class=condition_class, + # local_threshold_ms=options.local_threshold_ms, + # server_selection_timeout=options.server_selection_timeout, + # server_selector=options.server_selector, + # heartbeat_frequency=options.heartbeat_frequency, + # fqdn=fqdn, + # direct_connection=options.direct_connection, + # load_balanced=options.load_balanced, + # srv_service_name=srv_service_name, + # srv_max_hosts=srv_max_hosts, + # server_monitoring_mode=options.server_monitoring_mode, + # ) + # + # self._topology = Topology(self._topology_settings) self._opened = False self._closed = False @@ -878,7 +879,7 @@ def __init__( if _IS_SYNC and connect: self._get_topology() # type: ignore[unused-coroutine] - self._encrypter = None + self._encrypter: Optional[_Encrypter] = None self._timeout = self._options.timeout def _resolve_uri(self) -> None: @@ -1018,7 +1019,6 @@ async def aconnect(self) -> None: await self._get_topology() def _init_background(self, old_pid: Optional[int] = None) -> None: - self._topology = Topology(self._topology_settings) # Seed the topology with the old one's pid so we can detect clients # that are opened before a fork and used after. self._topology._pid = old_pid @@ -1235,14 +1235,20 @@ def options(self) -> ClientOptions: def __eq__(self, other: Any) -> bool: if isinstance(other, self.__class__): - return self._topology == other._topology + if hasattr(self, "_topology"): + return self._topology == other._topology + else: + raise InvalidOperation("Cannot perform operation until client is connected") return NotImplemented def __ne__(self, other: Any) -> bool: return not self == other def __hash__(self) -> int: - return hash(self._topology) + if hasattr(self, "_topology"): + return hash(self._topology) + else: + raise InvalidOperation("Cannot perform operation until client is connected") def _repr_helper(self) -> str: def option_repr(option: str, value: Any) -> str: @@ -1278,7 +1284,9 @@ def option_repr(option: str, value: Any) -> str: return ", ".join(options) def __repr__(self) -> str: - return f"{type(self).__name__}({self._repr_helper()})" + if hasattr(self, "_topology"): + return f"{type(self).__name__}({self._repr_helper()})" + raise InvalidOperation("Cannot perform operation until client is connected") def __getattr__(self, name: str) -> database.AsyncDatabase[_DocumentType]: """Get a database by name. @@ -1670,12 +1678,12 @@ async def close(self) -> None: .. versionchanged:: 3.6 End all server sessions created by this client. """ - session_ids = self._topology.pop_all_sessions() - if session_ids: - await self._end_sessions(session_ids) - # Stop the periodic task thread and then send pending killCursor - # requests before closing the topology. if self._opened: + session_ids = self._topology.pop_all_sessions() + if session_ids: + await self._end_sessions(session_ids) + # Stop the periodic task thread and then send pending killCursor + # requests before closing the topology. self._kill_cursors_executor.close() await self._process_kill_cursors() await self._topology.close() diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 0ec78a3acd..05b3a8460c 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -130,6 +130,7 @@ from pymongo.synchronous.bulk import _Bulk from pymongo.synchronous.client_session import ClientSession, _ServerSession from pymongo.synchronous.cursor import _ConnectionManager + from pymongo.synchronous.encryption import _Encrypter from pymongo.synchronous.pool import Connection from pymongo.synchronous.server import Server @@ -838,26 +839,26 @@ def __init__( options.read_concern, ) - self._topology_settings = TopologySettings( - seeds=seeds, - replica_set_name=options.replica_set_name, - pool_class=pool_class, - pool_options=options.pool_options, - monitor_class=monitor_class, - condition_class=condition_class, - local_threshold_ms=options.local_threshold_ms, - server_selection_timeout=options.server_selection_timeout, - server_selector=options.server_selector, - heartbeat_frequency=options.heartbeat_frequency, - fqdn=fqdn, - direct_connection=options.direct_connection, - load_balanced=options.load_balanced, - srv_service_name=srv_service_name, - srv_max_hosts=srv_max_hosts, - server_monitoring_mode=options.server_monitoring_mode, - ) - - self._topology = Topology(self._topology_settings) + # self._topology_settings = TopologySettings( + # seeds=seeds, + # replica_set_name=options.replica_set_name, + # pool_class=pool_class, + # pool_options=options.pool_options, + # monitor_class=monitor_class, + # condition_class=condition_class, + # local_threshold_ms=options.local_threshold_ms, + # server_selection_timeout=options.server_selection_timeout, + # server_selector=options.server_selector, + # heartbeat_frequency=options.heartbeat_frequency, + # fqdn=fqdn, + # direct_connection=options.direct_connection, + # load_balanced=options.load_balanced, + # srv_service_name=srv_service_name, + # srv_max_hosts=srv_max_hosts, + # server_monitoring_mode=options.server_monitoring_mode, + # ) + # + # self._topology = Topology(self._topology_settings) self._opened = False self._closed = False @@ -876,7 +877,7 @@ def __init__( if _IS_SYNC and connect: self._get_topology() # type: ignore[unused-coroutine] - self._encrypter = None + self._encrypter: Optional[_Encrypter] = None self._timeout = self._options.timeout def _resolve_uri(self) -> None: @@ -1016,7 +1017,6 @@ def _connect(self) -> None: self._get_topology() def _init_background(self, old_pid: Optional[int] = None) -> None: - self._topology = Topology(self._topology_settings) # Seed the topology with the old one's pid so we can detect clients # that are opened before a fork and used after. self._topology._pid = old_pid @@ -1233,14 +1233,20 @@ def options(self) -> ClientOptions: def __eq__(self, other: Any) -> bool: if isinstance(other, self.__class__): - return self._topology == other._topology + if hasattr(self, "_topology"): + return self._topology == other._topology + else: + raise InvalidOperation("Cannot perform operation until client is connected") return NotImplemented def __ne__(self, other: Any) -> bool: return not self == other def __hash__(self) -> int: - return hash(self._topology) + if hasattr(self, "_topology"): + return hash(self._topology) + else: + raise InvalidOperation("Cannot perform operation until client is connected") def _repr_helper(self) -> str: def option_repr(option: str, value: Any) -> str: @@ -1276,7 +1282,9 @@ def option_repr(option: str, value: Any) -> str: return ", ".join(options) def __repr__(self) -> str: - return f"{type(self).__name__}({self._repr_helper()})" + if hasattr(self, "_topology"): + return f"{type(self).__name__}({self._repr_helper()})" + raise InvalidOperation("Cannot perform operation until client is connected") def __getattr__(self, name: str) -> database.Database[_DocumentType]: """Get a database by name. @@ -1664,12 +1672,12 @@ def close(self) -> None: .. versionchanged:: 3.6 End all server sessions created by this client. """ - session_ids = self._topology.pop_all_sessions() - if session_ids: - self._end_sessions(session_ids) - # Stop the periodic task thread and then send pending killCursor - # requests before closing the topology. if self._opened: + session_ids = self._topology.pop_all_sessions() + if session_ids: + self._end_sessions(session_ids) + # Stop the periodic task thread and then send pending killCursor + # requests before closing the topology. self._kill_cursors_executor.close() self._process_kill_cursors() self._topology.close() diff --git a/test/asynchronous/test_client.py b/test/asynchronous/test_client.py index fed8ffc531..7fe83be5df 100644 --- a/test/asynchronous/test_client.py +++ b/test/asynchronous/test_client.py @@ -132,7 +132,7 @@ class AsyncClientUnitTest(AsyncUnitTest): async def asyncSetUp(self) -> None: self.client = await self.async_rs_or_single_client( - connect=False, serverSelectionTimeoutMS=100 + connect=True, serverSelectionTimeoutMS=100 ) @pytest.fixture(autouse=True) @@ -258,7 +258,7 @@ async def test_get_default_database(self): c = await self.async_rs_or_single_client( "mongodb://%s:%d/foo" % (await async_client_context.host, await async_client_context.port), - connect=False, + connect=True, ) self.assertEqual(AsyncDatabase(c, "foo"), c.get_default_database()) # Test that default doesn't override the URI value. @@ -274,7 +274,7 @@ async def test_get_default_database(self): c = await self.async_rs_or_single_client( "mongodb://%s:%d/" % (await async_client_context.host, await async_client_context.port), - connect=False, + connect=True, ) self.assertEqual(AsyncDatabase(c, "foo"), c.get_default_database("foo")) @@ -292,14 +292,14 @@ async def test_get_default_database_with_authsource(self): await async_client_context.host, await async_client_context.port, ) - c = await self.async_rs_or_single_client(uri, connect=False) + c = await self.async_rs_or_single_client(uri, connect=True) self.assertEqual(AsyncDatabase(c, "foo"), c.get_default_database()) async def test_get_database_default(self): c = await self.async_rs_or_single_client( "mongodb://%s:%d/foo" % (await async_client_context.host, await async_client_context.port), - connect=False, + connect=True, ) self.assertEqual(AsyncDatabase(c, "foo"), c.get_database()) @@ -317,7 +317,7 @@ async def test_get_database_default_with_authsource(self): await async_client_context.host, await async_client_context.port, ) - c = await self.async_rs_or_single_client(uri, connect=False) + c = await self.async_rs_or_single_client(uri, connect=True) self.assertEqual(AsyncDatabase(c, "foo"), c.get_database()) async def test_primary_read_pref_with_tags(self): @@ -817,16 +817,18 @@ async def test_init_disconnected(self): self.assertIsInstance(await c.is_primary, bool) c = await self.async_rs_or_single_client(connect=False) self.assertIsInstance(await c.is_mongos, bool) - c = await self.async_rs_or_single_client(connect=False) + c = await self.async_rs_or_single_client(connect=True) self.assertIsInstance(c.options.pool_options.max_pool_size, int) self.assertIsInstance(c.nodes, frozenset) c = await self.async_rs_or_single_client(connect=False) self.assertEqual(c.codec_options, CodecOptions()) c = await self.async_rs_or_single_client(connect=False) + await c.aconnect() self.assertFalse(await c.primary) self.assertFalse(await c.secondaries) c = await self.async_rs_or_single_client(connect=False) + await c.aconnect() self.assertIsInstance(c.topology_description, TopologyDescription) self.assertEqual(c.topology_description, c._topology._description) if async_client_context.is_rs: @@ -848,32 +850,36 @@ async def test_init_disconnected_with_auth(self): async def test_equality(self): seed = "{}:{}".format(*list(self.client._topology_settings.seeds)[0]) - c = await self.async_rs_or_single_client(seed, connect=False) + c = await self.async_rs_or_single_client(seed, connect=True) self.assertEqual(async_client_context.client, c) # Explicitly test inequality self.assertFalse(async_client_context.client != c) - c = await self.async_rs_or_single_client("invalid.com", connect=False) + c = await self.async_rs_or_single_client("invalid.com", connect=True) self.assertNotEqual(async_client_context.client, c) self.assertTrue(async_client_context.client != c) c1 = self.simple_client("a", connect=False) + await c1.aconnect() c2 = self.simple_client("b", connect=False) + await c2.aconnect() # Seeds differ: self.assertNotEqual(c1, c2) c1 = self.simple_client(["a", "b", "c"], connect=False) + await c1.aconnect() c2 = self.simple_client(["c", "a", "b"], connect=False) + await c2.aconnect() # Same seeds but out of order still compares equal: self.assertEqual(c1, c2) async def test_hashable(self): seed = "{}:{}".format(*list(self.client._topology_settings.seeds)[0]) - c = await self.async_rs_or_single_client(seed, connect=False) + c = await self.async_rs_or_single_client(seed, connect=True) self.assertIn(c, {async_client_context.client}) - c = await self.async_rs_or_single_client("invalid.com", connect=False) + c = await self.async_rs_or_single_client("invalid.com", connect=True) self.assertNotIn(c, {async_client_context.client}) async def test_host_w_port(self): @@ -897,7 +903,7 @@ async def test_repr(self): connect=False, document_class=SON, ) - + await client.aconnect() the_repr = repr(client) self.assertIn("AsyncMongoClient(host=", the_repr) self.assertIn("document_class=bson.son.SON, tz_aware=False, connect=False, ", the_repr) @@ -905,8 +911,9 @@ async def test_repr(self): self.assertIn("replicaset='replset'", the_repr) self.assertIn("w=1", the_repr) self.assertIn("wtimeoutms=100", the_repr) - + await client.close() async with eval(the_repr) as client_two: + await client_two.aconnect() self.assertEqual(client_two, client) client = self.simple_client( @@ -918,6 +925,7 @@ async def test_repr(self): wtimeoutms=100, connect=False, ) + await client.aconnect() the_repr = repr(client) self.assertIn("AsyncMongoClient(host=", the_repr) self.assertIn("document_class=dict, tz_aware=False, connect=False, ", the_repr) @@ -928,6 +936,7 @@ async def test_repr(self): self.assertIn("wtimeoutms=100", the_repr) async with eval(the_repr) as client_two: + await client_two.aconnect() self.assertEqual(client_two, client) async def test_getters(self): @@ -1053,9 +1062,7 @@ async def test_close_stops_kill_cursors_thread(self): self.assertTrue(client._kill_cursors_executor._stopped) async def test_uri_connect_option(self): - # Ensure that topology is not opened if connect=False. client = await self.async_rs_client(connect=False) - self.assertFalse(client._topology._opened) # Ensure kill cursors thread has not been started. # _kill_cursors_executor is initialized upon client connection @@ -1070,13 +1077,6 @@ async def test_uri_connect_option(self): kc_task = client._kill_cursors_executor._task self.assertTrue(kc_task and not kc_task.done()) - async def test_close_does_not_open_servers(self): - client = await self.async_rs_client(connect=False) - topology = client._topology - self.assertEqual(topology._servers, {}) - await client.close() - self.assertEqual(topology._servers, {}) - async def test_close_closes_sockets(self): client = await self.async_rs_client() await client.test.test.find_one() @@ -1620,10 +1620,11 @@ def init(self, *args): finally: ServerHeartbeatStartedEvent.__init__ = old_init # type: ignore - def test_small_heartbeat_frequency_ms(self): + async def test_small_heartbeat_frequency_ms(self): uri = "mongodb://example/?heartbeatFrequencyMS=499" with self.assertRaises(ConfigurationError) as context: - AsyncMongoClient(uri) + client = AsyncMongoClient(uri) + await client.aconnect() self.assertIn("heartbeatFrequencyMS", str(context.exception)) @@ -1896,19 +1897,25 @@ async def test_service_name_from_kwargs(self): srvServiceName="customname", connect=False, ) + await client.aconnect() self.assertEqual(client._topology_settings.srv_service_name, "customname") + await client.close() client = AsyncMongoClient( "mongodb+srv://user:password@test22.test.build.10gen.cc" "/?srvServiceName=shouldbeoverriden", srvServiceName="customname", connect=False, ) + await client.aconnect() self.assertEqual(client._topology_settings.srv_service_name, "customname") + await client.close() client = AsyncMongoClient( "mongodb+srv://user:password@test22.test.build.10gen.cc/?srvServiceName=customname", connect=False, ) + await client.aconnect() self.assertEqual(client._topology_settings.srv_service_name, "customname") + await client.close() async def test_srv_max_hosts_kwarg(self): client = self.simple_client("mongodb+srv://test1.test.build.10gen.cc/") diff --git a/test/asynchronous/unified_format.py b/test/asynchronous/unified_format.py index d4c3d40d20..d1351d1822 100644 --- a/test/asynchronous/unified_format.py +++ b/test/asynchronous/unified_format.py @@ -1290,6 +1290,7 @@ def check_events(self, spec): ) for idx, expected_event in enumerate(events): + print(expected_event, actual_events[idx]) self.match_evaluator.match_event(expected_event, actual_events[idx]) if has_server_connection_id: @@ -1344,6 +1345,8 @@ def format_logs(log_list): ignore_logs = client.get("ignoreMessages", []) if ignore_logs: actual_logs = self.process_ignore_messages(ignore_logs, actual_logs) + print(f"{ignore_logs=}") + print(f"{actual_logs=}") if client.get("ignoreExtraMessages", False): actual_logs = actual_logs[: len(client["messages"])] @@ -1354,7 +1357,8 @@ def format_logs(log_list): ) for expected_msg, actual_msg in zip(client["messages"], actual_logs): expected_data, actual_data = expected_msg.pop("data"), actual_msg.pop("data") - + print(f"{expected_data=}") + print(f"{actual_data=}") if "failureIsRedacted" in expected_msg: self.assertIn("failure", actual_data) should_redact = expected_msg.pop("failureIsRedacted") diff --git a/test/test_client.py b/test/test_client.py index da1c0fe584..9952b50fb2 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -130,7 +130,7 @@ class ClientUnitTest(UnitTest): client: MongoClient def setUp(self) -> None: - self.client = self.rs_or_single_client(connect=False, serverSelectionTimeoutMS=100) + self.client = self.rs_or_single_client(connect=True, serverSelectionTimeoutMS=100) @pytest.fixture(autouse=True) def inject_fixtures(self, caplog): @@ -254,7 +254,7 @@ def test_iteration(self): def test_get_default_database(self): c = self.rs_or_single_client( "mongodb://%s:%d/foo" % (client_context.host, client_context.port), - connect=False, + connect=True, ) self.assertEqual(Database(c, "foo"), c.get_default_database()) # Test that default doesn't override the URI value. @@ -270,7 +270,7 @@ def test_get_default_database(self): c = self.rs_or_single_client( "mongodb://%s:%d/" % (client_context.host, client_context.port), - connect=False, + connect=True, ) self.assertEqual(Database(c, "foo"), c.get_default_database("foo")) @@ -288,13 +288,13 @@ def test_get_default_database_with_authsource(self): client_context.host, client_context.port, ) - c = self.rs_or_single_client(uri, connect=False) + c = self.rs_or_single_client(uri, connect=True) self.assertEqual(Database(c, "foo"), c.get_default_database()) def test_get_database_default(self): c = self.rs_or_single_client( "mongodb://%s:%d/foo" % (client_context.host, client_context.port), - connect=False, + connect=True, ) self.assertEqual(Database(c, "foo"), c.get_database()) @@ -312,7 +312,7 @@ def test_get_database_default_with_authsource(self): client_context.host, client_context.port, ) - c = self.rs_or_single_client(uri, connect=False) + c = self.rs_or_single_client(uri, connect=True) self.assertEqual(Database(c, "foo"), c.get_database()) def test_primary_read_pref_with_tags(self): @@ -792,16 +792,18 @@ def test_init_disconnected(self): self.assertIsInstance(c.is_primary, bool) c = self.rs_or_single_client(connect=False) self.assertIsInstance(c.is_mongos, bool) - c = self.rs_or_single_client(connect=False) + c = self.rs_or_single_client(connect=True) self.assertIsInstance(c.options.pool_options.max_pool_size, int) self.assertIsInstance(c.nodes, frozenset) c = self.rs_or_single_client(connect=False) self.assertEqual(c.codec_options, CodecOptions()) c = self.rs_or_single_client(connect=False) + c._connect() self.assertFalse(c.primary) self.assertFalse(c.secondaries) c = self.rs_or_single_client(connect=False) + c._connect() self.assertIsInstance(c.topology_description, TopologyDescription) self.assertEqual(c.topology_description, c._topology._description) if client_context.is_rs: @@ -823,32 +825,36 @@ def test_init_disconnected_with_auth(self): def test_equality(self): seed = "{}:{}".format(*list(self.client._topology_settings.seeds)[0]) - c = self.rs_or_single_client(seed, connect=False) + c = self.rs_or_single_client(seed, connect=True) self.assertEqual(client_context.client, c) # Explicitly test inequality self.assertFalse(client_context.client != c) - c = self.rs_or_single_client("invalid.com", connect=False) + c = self.rs_or_single_client("invalid.com", connect=True) self.assertNotEqual(client_context.client, c) self.assertTrue(client_context.client != c) c1 = self.simple_client("a", connect=False) + c1._connect() c2 = self.simple_client("b", connect=False) + c2._connect() # Seeds differ: self.assertNotEqual(c1, c2) c1 = self.simple_client(["a", "b", "c"], connect=False) + c1._connect() c2 = self.simple_client(["c", "a", "b"], connect=False) + c2._connect() # Same seeds but out of order still compares equal: self.assertEqual(c1, c2) def test_hashable(self): seed = "{}:{}".format(*list(self.client._topology_settings.seeds)[0]) - c = self.rs_or_single_client(seed, connect=False) + c = self.rs_or_single_client(seed, connect=True) self.assertIn(c, {client_context.client}) - c = self.rs_or_single_client("invalid.com", connect=False) + c = self.rs_or_single_client("invalid.com", connect=True) self.assertNotIn(c, {client_context.client}) def test_host_w_port(self): @@ -872,7 +878,7 @@ def test_repr(self): connect=False, document_class=SON, ) - + client._connect() the_repr = repr(client) self.assertIn("MongoClient(host=", the_repr) self.assertIn("document_class=bson.son.SON, tz_aware=False, connect=False, ", the_repr) @@ -880,8 +886,9 @@ def test_repr(self): self.assertIn("replicaset='replset'", the_repr) self.assertIn("w=1", the_repr) self.assertIn("wtimeoutms=100", the_repr) - + client.close() with eval(the_repr) as client_two: + client_two._connect() self.assertEqual(client_two, client) client = self.simple_client( @@ -893,6 +900,7 @@ def test_repr(self): wtimeoutms=100, connect=False, ) + client._connect() the_repr = repr(client) self.assertIn("MongoClient(host=", the_repr) self.assertIn("document_class=dict, tz_aware=False, connect=False, ", the_repr) @@ -903,6 +911,7 @@ def test_repr(self): self.assertIn("wtimeoutms=100", the_repr) with eval(the_repr) as client_two: + client_two._connect() self.assertEqual(client_two, client) def test_getters(self): @@ -1026,9 +1035,7 @@ def test_close_stops_kill_cursors_thread(self): self.assertTrue(client._kill_cursors_executor._stopped) def test_uri_connect_option(self): - # Ensure that topology is not opened if connect=False. client = self.rs_client(connect=False) - self.assertFalse(client._topology._opened) # Ensure kill cursors thread has not been started. # _kill_cursors_executor is initialized upon client connection @@ -1043,13 +1050,6 @@ def test_uri_connect_option(self): kc_task = client._kill_cursors_executor._task self.assertTrue(kc_task and not kc_task.done()) - def test_close_does_not_open_servers(self): - client = self.rs_client(connect=False) - topology = client._topology - self.assertEqual(topology._servers, {}) - client.close() - self.assertEqual(topology._servers, {}) - def test_close_closes_sockets(self): client = self.rs_client() client.test.test.find_one() @@ -1580,7 +1580,8 @@ def init(self, *args): def test_small_heartbeat_frequency_ms(self): uri = "mongodb://example/?heartbeatFrequencyMS=499" with self.assertRaises(ConfigurationError) as context: - MongoClient(uri) + client = MongoClient(uri) + client._connect() self.assertIn("heartbeatFrequencyMS", str(context.exception)) @@ -1853,19 +1854,25 @@ def test_service_name_from_kwargs(self): srvServiceName="customname", connect=False, ) + client._connect() self.assertEqual(client._topology_settings.srv_service_name, "customname") + client.close() client = MongoClient( "mongodb+srv://user:password@test22.test.build.10gen.cc" "/?srvServiceName=shouldbeoverriden", srvServiceName="customname", connect=False, ) + client._connect() self.assertEqual(client._topology_settings.srv_service_name, "customname") + client.close() client = MongoClient( "mongodb+srv://user:password@test22.test.build.10gen.cc/?srvServiceName=customname", connect=False, ) + client._connect() self.assertEqual(client._topology_settings.srv_service_name, "customname") + client.close() def test_srv_max_hosts_kwarg(self): client = self.simple_client("mongodb+srv://test1.test.build.10gen.cc/") diff --git a/test/unified_format.py b/test/unified_format.py index 293fbd97ca..39dca0bd86 100644 --- a/test/unified_format.py +++ b/test/unified_format.py @@ -1277,6 +1277,7 @@ def check_events(self, spec): ) for idx, expected_event in enumerate(events): + print(expected_event, actual_events[idx]) self.match_evaluator.match_event(expected_event, actual_events[idx]) if has_server_connection_id: @@ -1331,6 +1332,8 @@ def format_logs(log_list): ignore_logs = client.get("ignoreMessages", []) if ignore_logs: actual_logs = self.process_ignore_messages(ignore_logs, actual_logs) + print(f"{ignore_logs=}") + print(f"{actual_logs=}") if client.get("ignoreExtraMessages", False): actual_logs = actual_logs[: len(client["messages"])] @@ -1341,7 +1344,8 @@ def format_logs(log_list): ) for expected_msg, actual_msg in zip(client["messages"], actual_logs): expected_data, actual_data = expected_msg.pop("data"), actual_msg.pop("data") - + print(f"{expected_data=}") + print(f"{actual_data=}") if "failureIsRedacted" in expected_msg: self.assertIn("failure", actual_data) should_redact = expected_msg.pop("failureIsRedacted") From ad20606105dd3aa91aa3257d8f37a177228c20df Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Tue, 11 Mar 2025 16:45:35 -0700 Subject: [PATCH 10/56] okay turns out it was *too* lazy HAHA --- pymongo/asynchronous/mongo_client.py | 72 ++++++++++++++-------------- pymongo/asynchronous/topology.py | 2 - pymongo/synchronous/mongo_client.py | 72 ++++++++++++++-------------- pymongo/synchronous/topology.py | 2 - pymongo/uri_parser.py | 5 +- 5 files changed, 77 insertions(+), 76 deletions(-) diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index cc25f79537..29684df421 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -763,9 +763,10 @@ def __init__( # Parse options passed as kwargs. keyword_opts = common._CaseInsensitiveDictionary(kwargs) keyword_opts["document_class"] = doc_class - self._resolve_uri_info: dict[str, Any] = {"keyword_opts": keyword_opts} + self._resolve_srv_info: dict[str, Any] = {"keyword_opts": keyword_opts} seeds = set() + is_srv = False username = None password = None dbase = None @@ -789,6 +790,7 @@ def __init__( srv_max_hosts=srv_max_hosts, ) seeds.update(res["nodelist"]) + is_srv = res["is_srv"] or is_srv username = res["username"] or username password = res["password"] or password dbase = res["database"] or dbase @@ -840,33 +842,33 @@ def __init__( options.write_concern, options.read_concern, ) - - # self._topology_settings = TopologySettings( - # seeds=seeds, - # replica_set_name=options.replica_set_name, - # pool_class=pool_class, - # pool_options=options.pool_options, - # monitor_class=monitor_class, - # condition_class=condition_class, - # local_threshold_ms=options.local_threshold_ms, - # server_selection_timeout=options.server_selection_timeout, - # server_selector=options.server_selector, - # heartbeat_frequency=options.heartbeat_frequency, - # fqdn=fqdn, - # direct_connection=options.direct_connection, - # load_balanced=options.load_balanced, - # srv_service_name=srv_service_name, - # srv_max_hosts=srv_max_hosts, - # server_monitoring_mode=options.server_monitoring_mode, - # ) - # - # self._topology = Topology(self._topology_settings) + if not is_srv: + self._topology_settings = TopologySettings( + seeds=seeds, + replica_set_name=options.replica_set_name, + pool_class=pool_class, + pool_options=options.pool_options, + monitor_class=monitor_class, + condition_class=condition_class, + local_threshold_ms=options.local_threshold_ms, + server_selection_timeout=options.server_selection_timeout, + server_selector=options.server_selector, + heartbeat_frequency=options.heartbeat_frequency, + fqdn=fqdn, + direct_connection=options.direct_connection, + load_balanced=options.load_balanced, + srv_service_name=srv_service_name, + srv_max_hosts=srv_max_hosts, + server_monitoring_mode=options.server_monitoring_mode, + ) + self._topology = Topology(self._topology_settings) self._opened = False self._closed = False - self._resolve_uri_info.update( + self._resolve_srv_info.update( { + "is_srv": is_srv, "username": username, "password": password, "dbase": dbase, @@ -882,8 +884,8 @@ def __init__( self._encrypter: Optional[_Encrypter] = None self._timeout = self._options.timeout - def _resolve_uri(self) -> None: - keyword_opts = self._resolve_uri_info["keyword_opts"] + def _resolve_srv(self) -> None: + keyword_opts = self._resolve_srv_info["keyword_opts"] seeds = set() opts = common._CaseInsensitiveDictionary() srv_service_name = keyword_opts.get("srvservicename") @@ -899,7 +901,7 @@ def _resolve_uri(self) -> None: timeout = common.validate_timeout_or_none_or_zero( keyword_opts.cased_key("connecttimeoutms"), timeout ) - res = uri_parser._lookup_uri( + res = uri_parser._parse_srv( entity, self._port, validate=True, @@ -943,10 +945,10 @@ def _resolve_uri(self) -> None: opts = self._normalize_and_validate_options(opts, seeds) # Username and password passed as kwargs override user info in URI. - username = opts.get("username", self._resolve_uri_info["username"]) - password = opts.get("password", self._resolve_uri_info["password"]) + username = opts.get("username", self._resolve_srv_info["username"]) + password = opts.get("password", self._resolve_srv_info["password"]) self._options = ClientOptions( - username, password, self._resolve_uri_info["dbase"], opts, _IS_SYNC + username, password, self._resolve_srv_info["dbase"], opts, _IS_SYNC ) self._event_listeners = self._options.pool_options._event_listeners @@ -960,22 +962,21 @@ def _resolve_uri(self) -> None: self._topology_settings = TopologySettings( seeds=seeds, replica_set_name=self._options.replica_set_name, - pool_class=self._resolve_uri_info["pool_class"], + pool_class=self._resolve_srv_info["pool_class"], pool_options=self._options.pool_options, - monitor_class=self._resolve_uri_info["monitor_class"], - condition_class=self._resolve_uri_info["condition_class"], + monitor_class=self._resolve_srv_info["monitor_class"], + condition_class=self._resolve_srv_info["condition_class"], local_threshold_ms=self._options.local_threshold_ms, server_selection_timeout=self._options.server_selection_timeout, server_selector=self._options.server_selector, heartbeat_frequency=self._options.heartbeat_frequency, - fqdn=self._resolve_uri_info["fqdn"], + fqdn=self._resolve_srv_info["fqdn"], direct_connection=self._options.direct_connection, load_balanced=self._options.load_balanced, srv_service_name=srv_service_name, srv_max_hosts=srv_max_hosts, server_monitoring_mode=self._options.server_monitoring_mode, ) - self._topology = Topology(self._topology_settings) if self._options.auto_encryption_opts: @@ -1709,7 +1710,8 @@ async def _get_topology(self) -> Topology: launches the connection process in the background. """ if not self._opened: - self._resolve_uri() + if self._resolve_srv_info["is_srv"]: + self._resolve_srv() self._init_background() await self._topology.open() async with self._lock: diff --git a/pymongo/asynchronous/topology.py b/pymongo/asynchronous/topology.py index 76f0fb6cde..f3c024a905 100644 --- a/pymongo/asynchronous/topology.py +++ b/pymongo/asynchronous/topology.py @@ -567,7 +567,6 @@ async def _process_srv_update(self, seedlist: list[tuple[str, Any]]) -> None: if td_old.topology_type not in SRV_POLLING_TOPOLOGIES: return self._description = _updated_topology_description_srv_polling(self._description, seedlist) - await self._update_servers() if self._publish_tp: @@ -983,7 +982,6 @@ def _create_pool_for_monitor(self, address: _Address) -> Pool: pause_enabled=False, server_api=options.server_api, ) - return self._settings.pool_class( address, monitor_pool_options, handshake=False, client_id=self._topology_id ) diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 05b3a8460c..2ae224a634 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -761,9 +761,10 @@ def __init__( # Parse options passed as kwargs. keyword_opts = common._CaseInsensitiveDictionary(kwargs) keyword_opts["document_class"] = doc_class - self._resolve_uri_info: dict[str, Any] = {"keyword_opts": keyword_opts} + self._resolve_srv_info: dict[str, Any] = {"keyword_opts": keyword_opts} seeds = set() + is_srv = False username = None password = None dbase = None @@ -787,6 +788,7 @@ def __init__( srv_max_hosts=srv_max_hosts, ) seeds.update(res["nodelist"]) + is_srv = res["is_srv"] or is_srv username = res["username"] or username password = res["password"] or password dbase = res["database"] or dbase @@ -838,33 +840,33 @@ def __init__( options.write_concern, options.read_concern, ) - - # self._topology_settings = TopologySettings( - # seeds=seeds, - # replica_set_name=options.replica_set_name, - # pool_class=pool_class, - # pool_options=options.pool_options, - # monitor_class=monitor_class, - # condition_class=condition_class, - # local_threshold_ms=options.local_threshold_ms, - # server_selection_timeout=options.server_selection_timeout, - # server_selector=options.server_selector, - # heartbeat_frequency=options.heartbeat_frequency, - # fqdn=fqdn, - # direct_connection=options.direct_connection, - # load_balanced=options.load_balanced, - # srv_service_name=srv_service_name, - # srv_max_hosts=srv_max_hosts, - # server_monitoring_mode=options.server_monitoring_mode, - # ) - # - # self._topology = Topology(self._topology_settings) + if not is_srv: + self._topology_settings = TopologySettings( + seeds=seeds, + replica_set_name=options.replica_set_name, + pool_class=pool_class, + pool_options=options.pool_options, + monitor_class=monitor_class, + condition_class=condition_class, + local_threshold_ms=options.local_threshold_ms, + server_selection_timeout=options.server_selection_timeout, + server_selector=options.server_selector, + heartbeat_frequency=options.heartbeat_frequency, + fqdn=fqdn, + direct_connection=options.direct_connection, + load_balanced=options.load_balanced, + srv_service_name=srv_service_name, + srv_max_hosts=srv_max_hosts, + server_monitoring_mode=options.server_monitoring_mode, + ) + self._topology = Topology(self._topology_settings) self._opened = False self._closed = False - self._resolve_uri_info.update( + self._resolve_srv_info.update( { + "is_srv": is_srv, "username": username, "password": password, "dbase": dbase, @@ -880,8 +882,8 @@ def __init__( self._encrypter: Optional[_Encrypter] = None self._timeout = self._options.timeout - def _resolve_uri(self) -> None: - keyword_opts = self._resolve_uri_info["keyword_opts"] + def _resolve_srv(self) -> None: + keyword_opts = self._resolve_srv_info["keyword_opts"] seeds = set() opts = common._CaseInsensitiveDictionary() srv_service_name = keyword_opts.get("srvservicename") @@ -897,7 +899,7 @@ def _resolve_uri(self) -> None: timeout = common.validate_timeout_or_none_or_zero( keyword_opts.cased_key("connecttimeoutms"), timeout ) - res = uri_parser._lookup_uri( + res = uri_parser._parse_srv( entity, self._port, validate=True, @@ -941,10 +943,10 @@ def _resolve_uri(self) -> None: opts = self._normalize_and_validate_options(opts, seeds) # Username and password passed as kwargs override user info in URI. - username = opts.get("username", self._resolve_uri_info["username"]) - password = opts.get("password", self._resolve_uri_info["password"]) + username = opts.get("username", self._resolve_srv_info["username"]) + password = opts.get("password", self._resolve_srv_info["password"]) self._options = ClientOptions( - username, password, self._resolve_uri_info["dbase"], opts, _IS_SYNC + username, password, self._resolve_srv_info["dbase"], opts, _IS_SYNC ) self._event_listeners = self._options.pool_options._event_listeners @@ -958,22 +960,21 @@ def _resolve_uri(self) -> None: self._topology_settings = TopologySettings( seeds=seeds, replica_set_name=self._options.replica_set_name, - pool_class=self._resolve_uri_info["pool_class"], + pool_class=self._resolve_srv_info["pool_class"], pool_options=self._options.pool_options, - monitor_class=self._resolve_uri_info["monitor_class"], - condition_class=self._resolve_uri_info["condition_class"], + monitor_class=self._resolve_srv_info["monitor_class"], + condition_class=self._resolve_srv_info["condition_class"], local_threshold_ms=self._options.local_threshold_ms, server_selection_timeout=self._options.server_selection_timeout, server_selector=self._options.server_selector, heartbeat_frequency=self._options.heartbeat_frequency, - fqdn=self._resolve_uri_info["fqdn"], + fqdn=self._resolve_srv_info["fqdn"], direct_connection=self._options.direct_connection, load_balanced=self._options.load_balanced, srv_service_name=srv_service_name, srv_max_hosts=srv_max_hosts, server_monitoring_mode=self._options.server_monitoring_mode, ) - self._topology = Topology(self._topology_settings) if self._options.auto_encryption_opts: @@ -1703,7 +1704,8 @@ def _get_topology(self) -> Topology: launches the connection process in the background. """ if not self._opened: - self._resolve_uri() + if self._resolve_srv_info["is_srv"]: + self._resolve_srv() self._init_background() self._topology.open() with self._lock: diff --git a/pymongo/synchronous/topology.py b/pymongo/synchronous/topology.py index ea0edae919..c95ec90f29 100644 --- a/pymongo/synchronous/topology.py +++ b/pymongo/synchronous/topology.py @@ -567,7 +567,6 @@ def _process_srv_update(self, seedlist: list[tuple[str, Any]]) -> None: if td_old.topology_type not in SRV_POLLING_TOPOLOGIES: return self._description = _updated_topology_description_srv_polling(self._description, seedlist) - self._update_servers() if self._publish_tp: @@ -981,7 +980,6 @@ def _create_pool_for_monitor(self, address: _Address) -> Pool: pause_enabled=False, server_api=options.server_api, ) - return self._settings.pool_class( address, monitor_pool_options, handshake=False, client_id=self._topology_id ) diff --git a/pymongo/uri_parser.py b/pymongo/uri_parser.py index 34f7d8dca1..b4913218d4 100644 --- a/pymongo/uri_parser.py +++ b/pymongo/uri_parser.py @@ -486,7 +486,7 @@ def parse_uri( """ result = _validate_uri(uri, default_port, validate, warn, normalize, srv_max_hosts) result.update( - _lookup_uri( + _parse_srv( uri, default_port, validate, @@ -586,6 +586,7 @@ def _validate_uri( _check_options(nodes, options) return { + "is_srv": is_srv, "nodelist": nodes, "username": user, "password": passwd, @@ -596,7 +597,7 @@ def _validate_uri( } -def _lookup_uri( +def _parse_srv( uri: str, default_port: Optional[int] = DEFAULT_PORT, validate: bool = True, From dfa0639977e8592aa51d55f0d2322f361784b579 Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Tue, 11 Mar 2025 16:56:09 -0700 Subject: [PATCH 11/56] cleanup --- pymongo/asynchronous/topology.py | 2 ++ pymongo/synchronous/topology.py | 2 ++ test/asynchronous/test_client.py | 5 +---- test/asynchronous/test_dns.py | 1 + test/asynchronous/unified_format.py | 5 ----- test/test_client.py | 5 +---- test/test_dns.py | 1 + test/unified_format.py | 5 ----- 8 files changed, 8 insertions(+), 18 deletions(-) diff --git a/pymongo/asynchronous/topology.py b/pymongo/asynchronous/topology.py index f3c024a905..76f0fb6cde 100644 --- a/pymongo/asynchronous/topology.py +++ b/pymongo/asynchronous/topology.py @@ -567,6 +567,7 @@ async def _process_srv_update(self, seedlist: list[tuple[str, Any]]) -> None: if td_old.topology_type not in SRV_POLLING_TOPOLOGIES: return self._description = _updated_topology_description_srv_polling(self._description, seedlist) + await self._update_servers() if self._publish_tp: @@ -982,6 +983,7 @@ def _create_pool_for_monitor(self, address: _Address) -> Pool: pause_enabled=False, server_api=options.server_api, ) + return self._settings.pool_class( address, monitor_pool_options, handshake=False, client_id=self._topology_id ) diff --git a/pymongo/synchronous/topology.py b/pymongo/synchronous/topology.py index c95ec90f29..ea0edae919 100644 --- a/pymongo/synchronous/topology.py +++ b/pymongo/synchronous/topology.py @@ -567,6 +567,7 @@ def _process_srv_update(self, seedlist: list[tuple[str, Any]]) -> None: if td_old.topology_type not in SRV_POLLING_TOPOLOGIES: return self._description = _updated_topology_description_srv_polling(self._description, seedlist) + self._update_servers() if self._publish_tp: @@ -980,6 +981,7 @@ def _create_pool_for_monitor(self, address: _Address) -> Pool: pause_enabled=False, server_api=options.server_api, ) + return self._settings.pool_class( address, monitor_pool_options, handshake=False, client_id=self._topology_id ) diff --git a/test/asynchronous/test_client.py b/test/asynchronous/test_client.py index 7fe83be5df..a26c2b7c69 100644 --- a/test/asynchronous/test_client.py +++ b/test/asynchronous/test_client.py @@ -903,7 +903,6 @@ async def test_repr(self): connect=False, document_class=SON, ) - await client.aconnect() the_repr = repr(client) self.assertIn("AsyncMongoClient(host=", the_repr) self.assertIn("document_class=bson.son.SON, tz_aware=False, connect=False, ", the_repr) @@ -925,7 +924,6 @@ async def test_repr(self): wtimeoutms=100, connect=False, ) - await client.aconnect() the_repr = repr(client) self.assertIn("AsyncMongoClient(host=", the_repr) self.assertIn("document_class=dict, tz_aware=False, connect=False, ", the_repr) @@ -1623,8 +1621,7 @@ def init(self, *args): async def test_small_heartbeat_frequency_ms(self): uri = "mongodb://example/?heartbeatFrequencyMS=499" with self.assertRaises(ConfigurationError) as context: - client = AsyncMongoClient(uri) - await client.aconnect() + AsyncMongoClient(uri) self.assertIn("heartbeatFrequencyMS", str(context.exception)) diff --git a/test/asynchronous/test_dns.py b/test/asynchronous/test_dns.py index 65e4454a1e..d929877d6d 100644 --- a/test/asynchronous/test_dns.py +++ b/test/asynchronous/test_dns.py @@ -141,6 +141,7 @@ async def run_test(self): # Our test certs don't support the SRV hosts used in these # tests. copts["tlsAllowInvalidHostnames"] = True + client = self.simple_client(uri, **copts) if client._options.connect: await client.aconnect() diff --git a/test/asynchronous/unified_format.py b/test/asynchronous/unified_format.py index d1351d1822..1547432cab 100644 --- a/test/asynchronous/unified_format.py +++ b/test/asynchronous/unified_format.py @@ -1290,7 +1290,6 @@ def check_events(self, spec): ) for idx, expected_event in enumerate(events): - print(expected_event, actual_events[idx]) self.match_evaluator.match_event(expected_event, actual_events[idx]) if has_server_connection_id: @@ -1345,8 +1344,6 @@ def format_logs(log_list): ignore_logs = client.get("ignoreMessages", []) if ignore_logs: actual_logs = self.process_ignore_messages(ignore_logs, actual_logs) - print(f"{ignore_logs=}") - print(f"{actual_logs=}") if client.get("ignoreExtraMessages", False): actual_logs = actual_logs[: len(client["messages"])] @@ -1357,8 +1354,6 @@ def format_logs(log_list): ) for expected_msg, actual_msg in zip(client["messages"], actual_logs): expected_data, actual_data = expected_msg.pop("data"), actual_msg.pop("data") - print(f"{expected_data=}") - print(f"{actual_data=}") if "failureIsRedacted" in expected_msg: self.assertIn("failure", actual_data) should_redact = expected_msg.pop("failureIsRedacted") diff --git a/test/test_client.py b/test/test_client.py index 9952b50fb2..19acfcc162 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -878,7 +878,6 @@ def test_repr(self): connect=False, document_class=SON, ) - client._connect() the_repr = repr(client) self.assertIn("MongoClient(host=", the_repr) self.assertIn("document_class=bson.son.SON, tz_aware=False, connect=False, ", the_repr) @@ -900,7 +899,6 @@ def test_repr(self): wtimeoutms=100, connect=False, ) - client._connect() the_repr = repr(client) self.assertIn("MongoClient(host=", the_repr) self.assertIn("document_class=dict, tz_aware=False, connect=False, ", the_repr) @@ -1580,8 +1578,7 @@ def init(self, *args): def test_small_heartbeat_frequency_ms(self): uri = "mongodb://example/?heartbeatFrequencyMS=499" with self.assertRaises(ConfigurationError) as context: - client = MongoClient(uri) - client._connect() + MongoClient(uri) self.assertIn("heartbeatFrequencyMS", str(context.exception)) diff --git a/test/test_dns.py b/test/test_dns.py index 8bb8dd5e07..df460e698f 100644 --- a/test/test_dns.py +++ b/test/test_dns.py @@ -141,6 +141,7 @@ def run_test(self): # Our test certs don't support the SRV hosts used in these # tests. copts["tlsAllowInvalidHostnames"] = True + client = self.simple_client(uri, **copts) if client._options.connect: client._connect() diff --git a/test/unified_format.py b/test/unified_format.py index 39dca0bd86..19d2cfa0ee 100644 --- a/test/unified_format.py +++ b/test/unified_format.py @@ -1277,7 +1277,6 @@ def check_events(self, spec): ) for idx, expected_event in enumerate(events): - print(expected_event, actual_events[idx]) self.match_evaluator.match_event(expected_event, actual_events[idx]) if has_server_connection_id: @@ -1332,8 +1331,6 @@ def format_logs(log_list): ignore_logs = client.get("ignoreMessages", []) if ignore_logs: actual_logs = self.process_ignore_messages(ignore_logs, actual_logs) - print(f"{ignore_logs=}") - print(f"{actual_logs=}") if client.get("ignoreExtraMessages", False): actual_logs = actual_logs[: len(client["messages"])] @@ -1344,8 +1341,6 @@ def format_logs(log_list): ) for expected_msg, actual_msg in zip(client["messages"], actual_logs): expected_data, actual_data = expected_msg.pop("data"), actual_msg.pop("data") - print(f"{expected_data=}") - print(f"{actual_data=}") if "failureIsRedacted" in expected_msg: self.assertIn("failure", actual_data) should_redact = expected_msg.pop("failureIsRedacted") From 57edcbcb45725d12ed8a0f57dd6b4db14d9dda5a Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Tue, 11 Mar 2025 17:02:25 -0700 Subject: [PATCH 12/56] more cleanup --- test/asynchronous/test_client.py | 24 ++++++++++++------------ test/asynchronous/unified_format.py | 1 + test/test_client.py | 22 +++++++++++----------- test/unified_format.py | 1 + 4 files changed, 25 insertions(+), 23 deletions(-) diff --git a/test/asynchronous/test_client.py b/test/asynchronous/test_client.py index a26c2b7c69..3634b8a797 100644 --- a/test/asynchronous/test_client.py +++ b/test/asynchronous/test_client.py @@ -132,7 +132,7 @@ class AsyncClientUnitTest(AsyncUnitTest): async def asyncSetUp(self) -> None: self.client = await self.async_rs_or_single_client( - connect=True, serverSelectionTimeoutMS=100 + connect=False, serverSelectionTimeoutMS=100 ) @pytest.fixture(autouse=True) @@ -258,7 +258,7 @@ async def test_get_default_database(self): c = await self.async_rs_or_single_client( "mongodb://%s:%d/foo" % (await async_client_context.host, await async_client_context.port), - connect=True, + connect=False, ) self.assertEqual(AsyncDatabase(c, "foo"), c.get_default_database()) # Test that default doesn't override the URI value. @@ -274,7 +274,7 @@ async def test_get_default_database(self): c = await self.async_rs_or_single_client( "mongodb://%s:%d/" % (await async_client_context.host, await async_client_context.port), - connect=True, + connect=False, ) self.assertEqual(AsyncDatabase(c, "foo"), c.get_default_database("foo")) @@ -292,14 +292,14 @@ async def test_get_default_database_with_authsource(self): await async_client_context.host, await async_client_context.port, ) - c = await self.async_rs_or_single_client(uri, connect=True) + c = await self.async_rs_or_single_client(uri, connect=False) self.assertEqual(AsyncDatabase(c, "foo"), c.get_default_database()) async def test_get_database_default(self): c = await self.async_rs_or_single_client( "mongodb://%s:%d/foo" % (await async_client_context.host, await async_client_context.port), - connect=True, + connect=False, ) self.assertEqual(AsyncDatabase(c, "foo"), c.get_database()) @@ -317,7 +317,7 @@ async def test_get_database_default_with_authsource(self): await async_client_context.host, await async_client_context.port, ) - c = await self.async_rs_or_single_client(uri, connect=True) + c = await self.async_rs_or_single_client(uri, connect=False) self.assertEqual(AsyncDatabase(c, "foo"), c.get_database()) async def test_primary_read_pref_with_tags(self): @@ -817,7 +817,7 @@ async def test_init_disconnected(self): self.assertIsInstance(await c.is_primary, bool) c = await self.async_rs_or_single_client(connect=False) self.assertIsInstance(await c.is_mongos, bool) - c = await self.async_rs_or_single_client(connect=True) + c = await self.async_rs_or_single_client(connect=False) self.assertIsInstance(c.options.pool_options.max_pool_size, int) self.assertIsInstance(c.nodes, frozenset) @@ -850,12 +850,12 @@ async def test_init_disconnected_with_auth(self): async def test_equality(self): seed = "{}:{}".format(*list(self.client._topology_settings.seeds)[0]) - c = await self.async_rs_or_single_client(seed, connect=True) + c = await self.async_rs_or_single_client(seed, connect=False) self.assertEqual(async_client_context.client, c) # Explicitly test inequality self.assertFalse(async_client_context.client != c) - c = await self.async_rs_or_single_client("invalid.com", connect=True) + c = await self.async_rs_or_single_client("invalid.com", connect=False) self.assertNotEqual(async_client_context.client, c) self.assertTrue(async_client_context.client != c) @@ -877,9 +877,9 @@ async def test_equality(self): async def test_hashable(self): seed = "{}:{}".format(*list(self.client._topology_settings.seeds)[0]) - c = await self.async_rs_or_single_client(seed, connect=True) + c = await self.async_rs_or_single_client(seed, connect=False) self.assertIn(c, {async_client_context.client}) - c = await self.async_rs_or_single_client("invalid.com", connect=True) + c = await self.async_rs_or_single_client("invalid.com", connect=False) self.assertNotIn(c, {async_client_context.client}) async def test_host_w_port(self): @@ -1618,7 +1618,7 @@ def init(self, *args): finally: ServerHeartbeatStartedEvent.__init__ = old_init # type: ignore - async def test_small_heartbeat_frequency_ms(self): + def test_small_heartbeat_frequency_ms(self): uri = "mongodb://example/?heartbeatFrequencyMS=499" with self.assertRaises(ConfigurationError) as context: AsyncMongoClient(uri) diff --git a/test/asynchronous/unified_format.py b/test/asynchronous/unified_format.py index 1547432cab..d4c3d40d20 100644 --- a/test/asynchronous/unified_format.py +++ b/test/asynchronous/unified_format.py @@ -1354,6 +1354,7 @@ def format_logs(log_list): ) for expected_msg, actual_msg in zip(client["messages"], actual_logs): expected_data, actual_data = expected_msg.pop("data"), actual_msg.pop("data") + if "failureIsRedacted" in expected_msg: self.assertIn("failure", actual_data) should_redact = expected_msg.pop("failureIsRedacted") diff --git a/test/test_client.py b/test/test_client.py index 19acfcc162..452a01aff9 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -130,7 +130,7 @@ class ClientUnitTest(UnitTest): client: MongoClient def setUp(self) -> None: - self.client = self.rs_or_single_client(connect=True, serverSelectionTimeoutMS=100) + self.client = self.rs_or_single_client(connect=False, serverSelectionTimeoutMS=100) @pytest.fixture(autouse=True) def inject_fixtures(self, caplog): @@ -254,7 +254,7 @@ def test_iteration(self): def test_get_default_database(self): c = self.rs_or_single_client( "mongodb://%s:%d/foo" % (client_context.host, client_context.port), - connect=True, + connect=False, ) self.assertEqual(Database(c, "foo"), c.get_default_database()) # Test that default doesn't override the URI value. @@ -270,7 +270,7 @@ def test_get_default_database(self): c = self.rs_or_single_client( "mongodb://%s:%d/" % (client_context.host, client_context.port), - connect=True, + connect=False, ) self.assertEqual(Database(c, "foo"), c.get_default_database("foo")) @@ -288,13 +288,13 @@ def test_get_default_database_with_authsource(self): client_context.host, client_context.port, ) - c = self.rs_or_single_client(uri, connect=True) + c = self.rs_or_single_client(uri, connect=False) self.assertEqual(Database(c, "foo"), c.get_default_database()) def test_get_database_default(self): c = self.rs_or_single_client( "mongodb://%s:%d/foo" % (client_context.host, client_context.port), - connect=True, + connect=False, ) self.assertEqual(Database(c, "foo"), c.get_database()) @@ -312,7 +312,7 @@ def test_get_database_default_with_authsource(self): client_context.host, client_context.port, ) - c = self.rs_or_single_client(uri, connect=True) + c = self.rs_or_single_client(uri, connect=False) self.assertEqual(Database(c, "foo"), c.get_database()) def test_primary_read_pref_with_tags(self): @@ -792,7 +792,7 @@ def test_init_disconnected(self): self.assertIsInstance(c.is_primary, bool) c = self.rs_or_single_client(connect=False) self.assertIsInstance(c.is_mongos, bool) - c = self.rs_or_single_client(connect=True) + c = self.rs_or_single_client(connect=False) self.assertIsInstance(c.options.pool_options.max_pool_size, int) self.assertIsInstance(c.nodes, frozenset) @@ -825,12 +825,12 @@ def test_init_disconnected_with_auth(self): def test_equality(self): seed = "{}:{}".format(*list(self.client._topology_settings.seeds)[0]) - c = self.rs_or_single_client(seed, connect=True) + c = self.rs_or_single_client(seed, connect=False) self.assertEqual(client_context.client, c) # Explicitly test inequality self.assertFalse(client_context.client != c) - c = self.rs_or_single_client("invalid.com", connect=True) + c = self.rs_or_single_client("invalid.com", connect=False) self.assertNotEqual(client_context.client, c) self.assertTrue(client_context.client != c) @@ -852,9 +852,9 @@ def test_equality(self): def test_hashable(self): seed = "{}:{}".format(*list(self.client._topology_settings.seeds)[0]) - c = self.rs_or_single_client(seed, connect=True) + c = self.rs_or_single_client(seed, connect=False) self.assertIn(c, {client_context.client}) - c = self.rs_or_single_client("invalid.com", connect=True) + c = self.rs_or_single_client("invalid.com", connect=False) self.assertNotIn(c, {client_context.client}) def test_host_w_port(self): diff --git a/test/unified_format.py b/test/unified_format.py index 19d2cfa0ee..293fbd97ca 100644 --- a/test/unified_format.py +++ b/test/unified_format.py @@ -1341,6 +1341,7 @@ def format_logs(log_list): ) for expected_msg, actual_msg in zip(client["messages"], actual_logs): expected_data, actual_data = expected_msg.pop("data"), actual_msg.pop("data") + if "failureIsRedacted" in expected_msg: self.assertIn("failure", actual_data) should_redact = expected_msg.pop("failureIsRedacted") From 58a58a053e6f9cc539cb1d8008048fe398c8c85e Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Tue, 11 Mar 2025 17:06:59 -0700 Subject: [PATCH 13/56] more cleanup --- test/asynchronous/test_client.py | 9 --------- test/test_client.py | 9 --------- 2 files changed, 18 deletions(-) diff --git a/test/asynchronous/test_client.py b/test/asynchronous/test_client.py index 3634b8a797..ebd504c6cb 100644 --- a/test/asynchronous/test_client.py +++ b/test/asynchronous/test_client.py @@ -824,11 +824,9 @@ async def test_init_disconnected(self): c = await self.async_rs_or_single_client(connect=False) self.assertEqual(c.codec_options, CodecOptions()) c = await self.async_rs_or_single_client(connect=False) - await c.aconnect() self.assertFalse(await c.primary) self.assertFalse(await c.secondaries) c = await self.async_rs_or_single_client(connect=False) - await c.aconnect() self.assertIsInstance(c.topology_description, TopologyDescription) self.assertEqual(c.topology_description, c._topology._description) if async_client_context.is_rs: @@ -860,17 +858,13 @@ async def test_equality(self): self.assertTrue(async_client_context.client != c) c1 = self.simple_client("a", connect=False) - await c1.aconnect() c2 = self.simple_client("b", connect=False) - await c2.aconnect() # Seeds differ: self.assertNotEqual(c1, c2) c1 = self.simple_client(["a", "b", "c"], connect=False) - await c1.aconnect() c2 = self.simple_client(["c", "a", "b"], connect=False) - await c2.aconnect() # Same seeds but out of order still compares equal: self.assertEqual(c1, c2) @@ -910,9 +904,7 @@ async def test_repr(self): self.assertIn("replicaset='replset'", the_repr) self.assertIn("w=1", the_repr) self.assertIn("wtimeoutms=100", the_repr) - await client.close() async with eval(the_repr) as client_two: - await client_two.aconnect() self.assertEqual(client_two, client) client = self.simple_client( @@ -934,7 +926,6 @@ async def test_repr(self): self.assertIn("wtimeoutms=100", the_repr) async with eval(the_repr) as client_two: - await client_two.aconnect() self.assertEqual(client_two, client) async def test_getters(self): diff --git a/test/test_client.py b/test/test_client.py index 452a01aff9..b64ed8900e 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -799,11 +799,9 @@ def test_init_disconnected(self): c = self.rs_or_single_client(connect=False) self.assertEqual(c.codec_options, CodecOptions()) c = self.rs_or_single_client(connect=False) - c._connect() self.assertFalse(c.primary) self.assertFalse(c.secondaries) c = self.rs_or_single_client(connect=False) - c._connect() self.assertIsInstance(c.topology_description, TopologyDescription) self.assertEqual(c.topology_description, c._topology._description) if client_context.is_rs: @@ -835,17 +833,13 @@ def test_equality(self): self.assertTrue(client_context.client != c) c1 = self.simple_client("a", connect=False) - c1._connect() c2 = self.simple_client("b", connect=False) - c2._connect() # Seeds differ: self.assertNotEqual(c1, c2) c1 = self.simple_client(["a", "b", "c"], connect=False) - c1._connect() c2 = self.simple_client(["c", "a", "b"], connect=False) - c2._connect() # Same seeds but out of order still compares equal: self.assertEqual(c1, c2) @@ -885,9 +879,7 @@ def test_repr(self): self.assertIn("replicaset='replset'", the_repr) self.assertIn("w=1", the_repr) self.assertIn("wtimeoutms=100", the_repr) - client.close() with eval(the_repr) as client_two: - client_two._connect() self.assertEqual(client_two, client) client = self.simple_client( @@ -909,7 +901,6 @@ def test_repr(self): self.assertIn("wtimeoutms=100", the_repr) with eval(the_repr) as client_two: - client_two._connect() self.assertEqual(client_two, client) def test_getters(self): From 35a41e90ec4d03948b36f5748106f489c7e07e1d Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Wed, 12 Mar 2025 10:58:08 -0700 Subject: [PATCH 14/56] fix fork tests --- pymongo/asynchronous/mongo_client.py | 20 ++++++++++---------- pymongo/synchronous/mongo_client.py | 20 ++++++++++---------- 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index 29684df421..acb534419f 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -861,10 +861,11 @@ def __init__( srv_max_hosts=srv_max_hosts, server_monitoring_mode=options.server_monitoring_mode, ) - self._topology = Topology(self._topology_settings) self._opened = False self._closed = False + if not is_srv: + self._init_background(first=True) self._resolve_srv_info.update( { @@ -977,7 +978,6 @@ def _resolve_srv(self) -> None: srv_max_hosts=srv_max_hosts, server_monitoring_mode=self._options.server_monitoring_mode, ) - self._topology = Topology(self._topology_settings) if self._options.auto_encryption_opts: from pymongo.asynchronous.encryption import _Encrypter @@ -985,11 +985,6 @@ def _resolve_srv(self) -> None: self._encrypter = _Encrypter(self, self._options.auto_encryption_opts) self._timeout = self._options.timeout - if _HAS_REGISTER_AT_FORK: - # Add this client to the list of weakly referenced items. - # This will be used later if we fork. - AsyncMongoClient._clients[self._topology._topology_id] = self - def _normalize_and_validate_options( self, opts: common._CaseInsensitiveDictionary, seeds: set[tuple[str, int | None]] ) -> common._CaseInsensitiveDictionary: @@ -1019,7 +1014,12 @@ async def aconnect(self) -> None: """Explicitly connect to MongoDB asynchronously instead of on the first operation.""" await self._get_topology() - def _init_background(self, old_pid: Optional[int] = None) -> None: + def _init_background(self, old_pid: Optional[int] = None, first=False) -> None: + self._topology = Topology(self._topology_settings) + if first and _HAS_REGISTER_AT_FORK: + # Add this client to the list of weakly referenced items. + # This will be used later if we fork. + AsyncMongoClient._clients[self._topology._topology_id] = self # Seed the topology with the old one's pid so we can detect clients # that are opened before a fork and used after. self._topology._pid = old_pid @@ -1679,7 +1679,7 @@ async def close(self) -> None: .. versionchanged:: 3.6 End all server sessions created by this client. """ - if self._opened: + if hasattr(self, "_topology"): session_ids = self._topology.pop_all_sessions() if session_ids: await self._end_sessions(session_ids) @@ -1712,7 +1712,7 @@ async def _get_topology(self) -> Topology: if not self._opened: if self._resolve_srv_info["is_srv"]: self._resolve_srv() - self._init_background() + self._init_background(first=True) await self._topology.open() async with self._lock: self._kill_cursors_executor.open() diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 2ae224a634..c636dfcb66 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -859,10 +859,11 @@ def __init__( srv_max_hosts=srv_max_hosts, server_monitoring_mode=options.server_monitoring_mode, ) - self._topology = Topology(self._topology_settings) self._opened = False self._closed = False + if not is_srv: + self._init_background(first=True) self._resolve_srv_info.update( { @@ -975,7 +976,6 @@ def _resolve_srv(self) -> None: srv_max_hosts=srv_max_hosts, server_monitoring_mode=self._options.server_monitoring_mode, ) - self._topology = Topology(self._topology_settings) if self._options.auto_encryption_opts: from pymongo.synchronous.encryption import _Encrypter @@ -983,11 +983,6 @@ def _resolve_srv(self) -> None: self._encrypter = _Encrypter(self, self._options.auto_encryption_opts) self._timeout = self._options.timeout - if _HAS_REGISTER_AT_FORK: - # Add this client to the list of weakly referenced items. - # This will be used later if we fork. - MongoClient._clients[self._topology._topology_id] = self - def _normalize_and_validate_options( self, opts: common._CaseInsensitiveDictionary, seeds: set[tuple[str, int | None]] ) -> common._CaseInsensitiveDictionary: @@ -1017,7 +1012,12 @@ def _connect(self) -> None: """Explicitly connect to MongoDB synchronously instead of on the first operation.""" self._get_topology() - def _init_background(self, old_pid: Optional[int] = None) -> None: + def _init_background(self, old_pid: Optional[int] = None, first=False) -> None: + self._topology = Topology(self._topology_settings) + if first and _HAS_REGISTER_AT_FORK: + # Add this client to the list of weakly referenced items. + # This will be used later if we fork. + MongoClient._clients[self._topology._topology_id] = self # Seed the topology with the old one's pid so we can detect clients # that are opened before a fork and used after. self._topology._pid = old_pid @@ -1673,7 +1673,7 @@ def close(self) -> None: .. versionchanged:: 3.6 End all server sessions created by this client. """ - if self._opened: + if hasattr(self, "_topology"): session_ids = self._topology.pop_all_sessions() if session_ids: self._end_sessions(session_ids) @@ -1706,7 +1706,7 @@ def _get_topology(self) -> Topology: if not self._opened: if self._resolve_srv_info["is_srv"]: self._resolve_srv() - self._init_background() + self._init_background(first=True) self._topology.open() with self._lock: self._kill_cursors_executor.open() From d343311e5e1bc4a138f452d42c46c0e8a787f4f1 Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Wed, 12 Mar 2025 11:01:40 -0700 Subject: [PATCH 15/56] fix typing --- pymongo/asynchronous/mongo_client.py | 2 +- pymongo/synchronous/mongo_client.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index acb534419f..b1a910c9ef 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -1014,7 +1014,7 @@ async def aconnect(self) -> None: """Explicitly connect to MongoDB asynchronously instead of on the first operation.""" await self._get_topology() - def _init_background(self, old_pid: Optional[int] = None, first=False) -> None: + def _init_background(self, old_pid: Optional[int] = None, first: bool = False) -> None: self._topology = Topology(self._topology_settings) if first and _HAS_REGISTER_AT_FORK: # Add this client to the list of weakly referenced items. diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index c636dfcb66..0d6959faf4 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -1012,7 +1012,7 @@ def _connect(self) -> None: """Explicitly connect to MongoDB synchronously instead of on the first operation.""" self._get_topology() - def _init_background(self, old_pid: Optional[int] = None, first=False) -> None: + def _init_background(self, old_pid: Optional[int] = None, first: bool = False) -> None: self._topology = Topology(self._topology_settings) if first and _HAS_REGISTER_AT_FORK: # Add this client to the list of weakly referenced items. From d03c78f2849f3c6228da2d77ebc54140defc02fd Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Wed, 12 Mar 2025 11:20:06 -0700 Subject: [PATCH 16/56] determine is_srv differently --- pymongo/asynchronous/mongo_client.py | 3 ++- pymongo/synchronous/mongo_client.py | 3 ++- pymongo/uri_parser.py | 1 - 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index b1a910c9ef..987de51ca6 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -114,6 +114,7 @@ _Pipeline, ) from pymongo.uri_parser import ( + SRV_SCHEME, _check_options, _handle_option_deprecations, _handle_security_options, @@ -789,8 +790,8 @@ def __init__( normalize=False, srv_max_hosts=srv_max_hosts, ) + is_srv = entity.startswith(SRV_SCHEME) seeds.update(res["nodelist"]) - is_srv = res["is_srv"] or is_srv username = res["username"] or username password = res["password"] or password dbase = res["database"] or dbase diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 0d6959faf4..b2b3c3187a 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -113,6 +113,7 @@ _Pipeline, ) from pymongo.uri_parser import ( + SRV_SCHEME, _check_options, _handle_option_deprecations, _handle_security_options, @@ -787,8 +788,8 @@ def __init__( normalize=False, srv_max_hosts=srv_max_hosts, ) + is_srv = entity.startswith(SRV_SCHEME) seeds.update(res["nodelist"]) - is_srv = res["is_srv"] or is_srv username = res["username"] or username password = res["password"] or password dbase = res["database"] or dbase diff --git a/pymongo/uri_parser.py b/pymongo/uri_parser.py index b4913218d4..09f974fe8a 100644 --- a/pymongo/uri_parser.py +++ b/pymongo/uri_parser.py @@ -586,7 +586,6 @@ def _validate_uri( _check_options(nodes, options) return { - "is_srv": is_srv, "nodelist": nodes, "username": user, "password": passwd, From e1d091fefd666321735f4c91c860dee5216ba661 Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Wed, 12 Mar 2025 11:47:45 -0700 Subject: [PATCH 17/56] fix test --- test/asynchronous/test_client.py | 8 ++++++-- test/test_client.py | 8 ++++++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/test/asynchronous/test_client.py b/test/asynchronous/test_client.py index ebd504c6cb..b46bfed77a 100644 --- a/test/asynchronous/test_client.py +++ b/test/asynchronous/test_client.py @@ -1054,8 +1054,12 @@ async def test_uri_connect_option(self): client = await self.async_rs_client(connect=False) # Ensure kill cursors thread has not been started. - # _kill_cursors_executor is initialized upon client connection - self.assertFalse(hasattr(client, "_kill_cursors_executor")) + if _IS_SYNC: + kc_thread = client._kill_cursors_executor._thread + self.assertFalse(kc_thread and kc_thread.is_alive()) + else: + kc_task = client._kill_cursors_executor._task + self.assertFalse(kc_task and not kc_task.done()) # Using the client should open topology and start the thread. await client.admin.command("ping") self.assertTrue(client._topology._opened) diff --git a/test/test_client.py b/test/test_client.py index b64ed8900e..487704b306 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -1027,8 +1027,12 @@ def test_uri_connect_option(self): client = self.rs_client(connect=False) # Ensure kill cursors thread has not been started. - # _kill_cursors_executor is initialized upon client connection - self.assertFalse(hasattr(client, "_kill_cursors_executor")) + if _IS_SYNC: + kc_thread = client._kill_cursors_executor._thread + self.assertFalse(kc_thread and kc_thread.is_alive()) + else: + kc_task = client._kill_cursors_executor._task + self.assertFalse(kc_task and not kc_task.done()) # Using the client should open topology and start the thread. client.admin.command("ping") self.assertTrue(client._topology._opened) From 8efd549e8a54edaf13d0cda051740b7abef1addd Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Wed, 12 Mar 2025 13:15:41 -0700 Subject: [PATCH 18/56] fix encrypter --- pymongo/asynchronous/mongo_client.py | 6 +- pymongo/synchronous/mongo_client.py | 6 +- .../test_read_write_concern_spec.py | 344 ------------------ test/test_read_write_concern_spec.py | 340 ----------------- uv.lock | 1 - 5 files changed, 10 insertions(+), 687 deletions(-) delete mode 100644 test/asynchronous/test_read_write_concern_spec.py delete mode 100644 test/test_read_write_concern_spec.py diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index 987de51ca6..046882aeb2 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -843,6 +843,7 @@ def __init__( options.write_concern, options.read_concern, ) + self._encrypter: Optional[_Encrypter] = None if not is_srv: self._topology_settings = TopologySettings( seeds=seeds, @@ -862,6 +863,10 @@ def __init__( srv_max_hosts=srv_max_hosts, server_monitoring_mode=options.server_monitoring_mode, ) + if self._options.auto_encryption_opts: + from pymongo.asynchronous.encryption import _Encrypter + + self._encrypter = _Encrypter(self, self._options.auto_encryption_opts) self._opened = False self._closed = False @@ -883,7 +888,6 @@ def __init__( if _IS_SYNC and connect: self._get_topology() # type: ignore[unused-coroutine] - self._encrypter: Optional[_Encrypter] = None self._timeout = self._options.timeout def _resolve_srv(self) -> None: diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index b2b3c3187a..3e1057bcfd 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -841,6 +841,7 @@ def __init__( options.write_concern, options.read_concern, ) + self._encrypter: Optional[_Encrypter] = None if not is_srv: self._topology_settings = TopologySettings( seeds=seeds, @@ -860,6 +861,10 @@ def __init__( srv_max_hosts=srv_max_hosts, server_monitoring_mode=options.server_monitoring_mode, ) + if self._options.auto_encryption_opts: + from pymongo.synchronous.encryption import _Encrypter + + self._encrypter = _Encrypter(self, self._options.auto_encryption_opts) self._opened = False self._closed = False @@ -881,7 +886,6 @@ def __init__( if _IS_SYNC and connect: self._get_topology() # type: ignore[unused-coroutine] - self._encrypter: Optional[_Encrypter] = None self._timeout = self._options.timeout def _resolve_srv(self) -> None: diff --git a/test/asynchronous/test_read_write_concern_spec.py b/test/asynchronous/test_read_write_concern_spec.py deleted file mode 100644 index 3fb13ba194..0000000000 --- a/test/asynchronous/test_read_write_concern_spec.py +++ /dev/null @@ -1,344 +0,0 @@ -# Copyright 2018-present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Run the read and write concern tests.""" -from __future__ import annotations - -import json -import os -import sys -import warnings -from pathlib import Path - -sys.path[0:0] = [""] - -from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest -from test.asynchronous.unified_format import generate_test_classes -from test.utils import OvertCommandListener - -from pymongo import DESCENDING -from pymongo.asynchronous.mongo_client import AsyncMongoClient -from pymongo.errors import ( - BulkWriteError, - ConfigurationError, - WriteConcernError, - WriteError, - WTimeoutError, -) -from pymongo.operations import IndexModel, InsertOne -from pymongo.read_concern import ReadConcern -from pymongo.write_concern import WriteConcern - -_IS_SYNC = False - -# Location of JSON test specifications. -if _IS_SYNC: - TEST_PATH = os.path.join(Path(__file__).resolve().parent, "read_write_concern") -else: - TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "read_write_concern") - - -class TestReadWriteConcernSpec(AsyncIntegrationTest): - async def test_omit_default_read_write_concern(self): - listener = OvertCommandListener() - # Client with default readConcern and writeConcern - client = await self.async_rs_or_single_client(event_listeners=[listener]) - collection = client.pymongo_test.collection - # Prepare for tests of find() and aggregate(). - await collection.insert_many([{} for _ in range(10)]) - self.addAsyncCleanup(collection.drop) - self.addAsyncCleanup(client.pymongo_test.collection2.drop) - # Commands MUST NOT send the default read/write concern to the server. - - async def rename_and_drop(): - # Ensure collection exists. - await collection.insert_one({}) - await collection.rename("collection2") - await client.pymongo_test.collection2.drop() - - async def insert_command_default_write_concern(): - await collection.database.command( - "insert", "collection", documents=[{}], write_concern=WriteConcern() - ) - - async def aggregate_op(): - await (await collection.aggregate([])).to_list() - - ops = [ - ("aggregate", aggregate_op), - ("find", lambda: collection.find().to_list()), - ("insert_one", lambda: collection.insert_one({})), - ("update_one", lambda: collection.update_one({}, {"$set": {"x": 1}})), - ("update_many", lambda: collection.update_many({}, {"$set": {"x": 1}})), - ("delete_one", lambda: collection.delete_one({})), - ("delete_many", lambda: collection.delete_many({})), - ("bulk_write", lambda: collection.bulk_write([InsertOne({})])), - ("rename_and_drop", rename_and_drop), - ("command", insert_command_default_write_concern), - ] - - for name, f in ops: - listener.reset() - await f() - - self.assertGreaterEqual(len(listener.started_events), 1) - for _i, event in enumerate(listener.started_events): - self.assertNotIn( - "readConcern", - event.command, - f"{name} sent default readConcern with {event.command_name}", - ) - self.assertNotIn( - "writeConcern", - event.command, - f"{name} sent default writeConcern with {event.command_name}", - ) - - async def assertWriteOpsRaise(self, write_concern, expected_exception): - wc = write_concern.document - # Set socket timeout to avoid indefinite stalls - client = await self.async_rs_or_single_client( - w=wc["w"], wTimeoutMS=wc["wtimeout"], socketTimeoutMS=30000 - ) - db = client.get_database("pymongo_test") - coll = db.test - - async def insert_command(): - await coll.database.command( - "insert", - "new_collection", - documents=[{}], - writeConcern=write_concern.document, - parse_write_concern_error=True, - ) - - ops = [ - ("insert_one", lambda: coll.insert_one({})), - ("insert_many", lambda: coll.insert_many([{}, {}])), - ("update_one", lambda: coll.update_one({}, {"$set": {"x": 1}})), - ("update_many", lambda: coll.update_many({}, {"$set": {"x": 1}})), - ("delete_one", lambda: coll.delete_one({})), - ("delete_many", lambda: coll.delete_many({})), - ("bulk_write", lambda: coll.bulk_write([InsertOne({})])), - ("command", insert_command), - ("aggregate", lambda: coll.aggregate([{"$out": "out"}])), - # SERVER-46668 Delete all the documents in the collection to - # workaround a hang in createIndexes. - ("delete_many", lambda: coll.delete_many({})), - ("create_index", lambda: coll.create_index([("a", DESCENDING)])), - ("create_indexes", lambda: coll.create_indexes([IndexModel("b")])), - ("drop_index", lambda: coll.drop_index([("a", DESCENDING)])), - ("create", lambda: db.create_collection("new")), - ("rename", lambda: coll.rename("new")), - ("drop", lambda: db.new.drop()), - ] - # SERVER-47194: dropDatabase does not respect wtimeout in 3.6. - if async_client_context.version[:2] != (3, 6): - ops.append(("drop_database", lambda: client.drop_database(db))) - - for name, f in ops: - # Ensure insert_many and bulk_write still raise BulkWriteError. - if name in ("insert_many", "bulk_write"): - expected = BulkWriteError - else: - expected = expected_exception - with self.assertRaises(expected, msg=name) as cm: - await f() - if expected == BulkWriteError: - bulk_result = cm.exception.details - assert bulk_result is not None - wc_errors = bulk_result["writeConcernErrors"] - self.assertTrue(wc_errors) - - @async_client_context.require_replica_set - async def test_raise_write_concern_error(self): - self.addAsyncCleanup(async_client_context.client.drop_database, "pymongo_test") - assert async_client_context.w is not None - await self.assertWriteOpsRaise( - WriteConcern(w=async_client_context.w + 1, wtimeout=1), WriteConcernError - ) - - @async_client_context.require_secondaries_count(1) - @async_client_context.require_test_commands - async def test_raise_wtimeout(self): - self.addAsyncCleanup(async_client_context.client.drop_database, "pymongo_test") - self.addAsyncCleanup(self.enable_replication, async_client_context.client) - # Disable replication to guarantee a wtimeout error. - await self.disable_replication(async_client_context.client) - await self.assertWriteOpsRaise( - WriteConcern(w=async_client_context.w, wtimeout=1), WTimeoutError - ) - - @async_client_context.require_failCommand_fail_point - async def test_error_includes_errInfo(self): - expected_wce = { - "code": 100, - "codeName": "UnsatisfiableWriteConcern", - "errmsg": "Not enough data-bearing nodes", - "errInfo": {"writeConcern": {"w": 2, "wtimeout": 0, "provenance": "clientSupplied"}}, - } - cause_wce = { - "configureFailPoint": "failCommand", - "mode": {"times": 2}, - "data": {"failCommands": ["insert"], "writeConcernError": expected_wce}, - } - async with self.fail_point(cause_wce): - # Write concern error on insert includes errInfo. - with self.assertRaises(WriteConcernError) as ctx: - await self.db.test.insert_one({}) - self.assertEqual(ctx.exception.details, expected_wce) - - # Test bulk_write as well. - with self.assertRaises(BulkWriteError) as ctx: - await self.db.test.bulk_write([InsertOne({})]) - expected_details = { - "writeErrors": [], - "writeConcernErrors": [expected_wce], - "nInserted": 1, - "nUpserted": 0, - "nMatched": 0, - "nModified": 0, - "nRemoved": 0, - "upserted": [], - } - self.assertEqual(ctx.exception.details, expected_details) - - @async_client_context.require_version_min(4, 9) - async def test_write_error_details_exposes_errinfo(self): - listener = OvertCommandListener() - client = await self.async_rs_or_single_client(event_listeners=[listener]) - db = client.errinfotest - self.addAsyncCleanup(client.drop_database, "errinfotest") - validator = {"x": {"$type": "string"}} - await db.create_collection("test", validator=validator) - with self.assertRaises(WriteError) as ctx: - await db.test.insert_one({"x": 1}) - self.assertEqual(ctx.exception.code, 121) - self.assertIsNotNone(ctx.exception.details) - assert ctx.exception.details is not None - self.assertIsNotNone(ctx.exception.details.get("errInfo")) - for event in listener.succeeded_events: - if event.command_name == "insert": - self.assertEqual(event.reply["writeErrors"][0], ctx.exception.details) - break - else: - self.fail("Couldn't find insert event.") - - -def normalize_write_concern(concern): - result = {} - for key in concern: - if key.lower() == "wtimeoutms": - result["wtimeout"] = concern[key] - elif key == "journal": - result["j"] = concern[key] - else: - result[key] = concern[key] - return result - - -def create_connection_string_test(test_case): - def run_test(self): - uri = test_case["uri"] - valid = test_case["valid"] - warning = test_case["warning"] - - if not valid: - if warning is False: - self.assertRaises( - (ConfigurationError, ValueError), AsyncMongoClient, uri, connect=False - ) - else: - with warnings.catch_warnings(): - warnings.simplefilter("error", UserWarning) - self.assertRaises(UserWarning, AsyncMongoClient, uri, connect=False) - else: - client = AsyncMongoClient(uri, connect=False) - if "writeConcern" in test_case: - document = client.write_concern.document - self.assertEqual(document, normalize_write_concern(test_case["writeConcern"])) - if "readConcern" in test_case: - document = client.read_concern.document - self.assertEqual(document, test_case["readConcern"]) - - return run_test - - -def create_document_test(test_case): - def run_test(self): - valid = test_case["valid"] - - if "writeConcern" in test_case: - normalized = normalize_write_concern(test_case["writeConcern"]) - if not valid: - self.assertRaises((ConfigurationError, ValueError), WriteConcern, **normalized) - else: - write_concern = WriteConcern(**normalized) - self.assertEqual(write_concern.document, test_case["writeConcernDocument"]) - self.assertEqual(write_concern.acknowledged, test_case["isAcknowledged"]) - self.assertEqual(write_concern.is_server_default, test_case["isServerDefault"]) - if "readConcern" in test_case: - # Any string for 'level' is equally valid - read_concern = ReadConcern(**test_case["readConcern"]) - self.assertEqual(read_concern.document, test_case["readConcernDocument"]) - self.assertEqual(not bool(read_concern.level), test_case["isServerDefault"]) - - return run_test - - -def create_tests(): - for dirpath, _, filenames in os.walk(TEST_PATH): - dirname = os.path.split(dirpath)[-1] - - if dirname == "operation": - # This directory is tested by TestOperations. - continue - elif dirname == "connection-string": - create_test = create_connection_string_test - else: - create_test = create_document_test - - for filename in filenames: - with open(os.path.join(dirpath, filename)) as test_stream: - test_cases = json.load(test_stream)["tests"] - - fname = os.path.splitext(filename)[0] - for test_case in test_cases: - new_test = create_test(test_case) - test_name = "test_{}_{}_{}".format( - dirname.replace("-", "_"), - fname.replace("-", "_"), - str(test_case["description"].lower().replace(" ", "_")), - ) - - new_test.__name__ = test_name - setattr(TestReadWriteConcernSpec, new_test.__name__, new_test) - - -create_tests() - - -# Generate unified tests. -# PyMongo does not support MapReduce. -globals().update( - generate_test_classes( - os.path.join(TEST_PATH, "operation"), - module=__name__, - expected_failures=["MapReduce .*"], - ) -) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/test_read_write_concern_spec.py b/test/test_read_write_concern_spec.py deleted file mode 100644 index 8543991f72..0000000000 --- a/test/test_read_write_concern_spec.py +++ /dev/null @@ -1,340 +0,0 @@ -# Copyright 2018-present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Run the read and write concern tests.""" -from __future__ import annotations - -import json -import os -import sys -import warnings -from pathlib import Path - -sys.path[0:0] = [""] - -from test import IntegrationTest, client_context, unittest -from test.unified_format import generate_test_classes -from test.utils import OvertCommandListener - -from pymongo import DESCENDING -from pymongo.errors import ( - BulkWriteError, - ConfigurationError, - WriteConcernError, - WriteError, - WTimeoutError, -) -from pymongo.operations import IndexModel, InsertOne -from pymongo.read_concern import ReadConcern -from pymongo.synchronous.mongo_client import MongoClient -from pymongo.write_concern import WriteConcern - -_IS_SYNC = True - -# Location of JSON test specifications. -if _IS_SYNC: - TEST_PATH = os.path.join(Path(__file__).resolve().parent, "read_write_concern") -else: - TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "read_write_concern") - - -class TestReadWriteConcernSpec(IntegrationTest): - def test_omit_default_read_write_concern(self): - listener = OvertCommandListener() - # Client with default readConcern and writeConcern - client = self.rs_or_single_client(event_listeners=[listener]) - collection = client.pymongo_test.collection - # Prepare for tests of find() and aggregate(). - collection.insert_many([{} for _ in range(10)]) - self.addCleanup(collection.drop) - self.addCleanup(client.pymongo_test.collection2.drop) - # Commands MUST NOT send the default read/write concern to the server. - - def rename_and_drop(): - # Ensure collection exists. - collection.insert_one({}) - collection.rename("collection2") - client.pymongo_test.collection2.drop() - - def insert_command_default_write_concern(): - collection.database.command( - "insert", "collection", documents=[{}], write_concern=WriteConcern() - ) - - def aggregate_op(): - (collection.aggregate([])).to_list() - - ops = [ - ("aggregate", aggregate_op), - ("find", lambda: collection.find().to_list()), - ("insert_one", lambda: collection.insert_one({})), - ("update_one", lambda: collection.update_one({}, {"$set": {"x": 1}})), - ("update_many", lambda: collection.update_many({}, {"$set": {"x": 1}})), - ("delete_one", lambda: collection.delete_one({})), - ("delete_many", lambda: collection.delete_many({})), - ("bulk_write", lambda: collection.bulk_write([InsertOne({})])), - ("rename_and_drop", rename_and_drop), - ("command", insert_command_default_write_concern), - ] - - for name, f in ops: - listener.reset() - f() - - self.assertGreaterEqual(len(listener.started_events), 1) - for _i, event in enumerate(listener.started_events): - self.assertNotIn( - "readConcern", - event.command, - f"{name} sent default readConcern with {event.command_name}", - ) - self.assertNotIn( - "writeConcern", - event.command, - f"{name} sent default writeConcern with {event.command_name}", - ) - - def assertWriteOpsRaise(self, write_concern, expected_exception): - wc = write_concern.document - # Set socket timeout to avoid indefinite stalls - client = self.rs_or_single_client( - w=wc["w"], wTimeoutMS=wc["wtimeout"], socketTimeoutMS=30000 - ) - db = client.get_database("pymongo_test") - coll = db.test - - def insert_command(): - coll.database.command( - "insert", - "new_collection", - documents=[{}], - writeConcern=write_concern.document, - parse_write_concern_error=True, - ) - - ops = [ - ("insert_one", lambda: coll.insert_one({})), - ("insert_many", lambda: coll.insert_many([{}, {}])), - ("update_one", lambda: coll.update_one({}, {"$set": {"x": 1}})), - ("update_many", lambda: coll.update_many({}, {"$set": {"x": 1}})), - ("delete_one", lambda: coll.delete_one({})), - ("delete_many", lambda: coll.delete_many({})), - ("bulk_write", lambda: coll.bulk_write([InsertOne({})])), - ("command", insert_command), - ("aggregate", lambda: coll.aggregate([{"$out": "out"}])), - # SERVER-46668 Delete all the documents in the collection to - # workaround a hang in createIndexes. - ("delete_many", lambda: coll.delete_many({})), - ("create_index", lambda: coll.create_index([("a", DESCENDING)])), - ("create_indexes", lambda: coll.create_indexes([IndexModel("b")])), - ("drop_index", lambda: coll.drop_index([("a", DESCENDING)])), - ("create", lambda: db.create_collection("new")), - ("rename", lambda: coll.rename("new")), - ("drop", lambda: db.new.drop()), - ] - # SERVER-47194: dropDatabase does not respect wtimeout in 3.6. - if client_context.version[:2] != (3, 6): - ops.append(("drop_database", lambda: client.drop_database(db))) - - for name, f in ops: - # Ensure insert_many and bulk_write still raise BulkWriteError. - if name in ("insert_many", "bulk_write"): - expected = BulkWriteError - else: - expected = expected_exception - with self.assertRaises(expected, msg=name) as cm: - f() - if expected == BulkWriteError: - bulk_result = cm.exception.details - assert bulk_result is not None - wc_errors = bulk_result["writeConcernErrors"] - self.assertTrue(wc_errors) - - @client_context.require_replica_set - def test_raise_write_concern_error(self): - self.addCleanup(client_context.client.drop_database, "pymongo_test") - assert client_context.w is not None - self.assertWriteOpsRaise( - WriteConcern(w=client_context.w + 1, wtimeout=1), WriteConcernError - ) - - @client_context.require_secondaries_count(1) - @client_context.require_test_commands - def test_raise_wtimeout(self): - self.addCleanup(client_context.client.drop_database, "pymongo_test") - self.addCleanup(self.enable_replication, client_context.client) - # Disable replication to guarantee a wtimeout error. - self.disable_replication(client_context.client) - self.assertWriteOpsRaise(WriteConcern(w=client_context.w, wtimeout=1), WTimeoutError) - - @client_context.require_failCommand_fail_point - def test_error_includes_errInfo(self): - expected_wce = { - "code": 100, - "codeName": "UnsatisfiableWriteConcern", - "errmsg": "Not enough data-bearing nodes", - "errInfo": {"writeConcern": {"w": 2, "wtimeout": 0, "provenance": "clientSupplied"}}, - } - cause_wce = { - "configureFailPoint": "failCommand", - "mode": {"times": 2}, - "data": {"failCommands": ["insert"], "writeConcernError": expected_wce}, - } - with self.fail_point(cause_wce): - # Write concern error on insert includes errInfo. - with self.assertRaises(WriteConcernError) as ctx: - self.db.test.insert_one({}) - self.assertEqual(ctx.exception.details, expected_wce) - - # Test bulk_write as well. - with self.assertRaises(BulkWriteError) as ctx: - self.db.test.bulk_write([InsertOne({})]) - expected_details = { - "writeErrors": [], - "writeConcernErrors": [expected_wce], - "nInserted": 1, - "nUpserted": 0, - "nMatched": 0, - "nModified": 0, - "nRemoved": 0, - "upserted": [], - } - self.assertEqual(ctx.exception.details, expected_details) - - @client_context.require_version_min(4, 9) - def test_write_error_details_exposes_errinfo(self): - listener = OvertCommandListener() - client = self.rs_or_single_client(event_listeners=[listener]) - db = client.errinfotest - self.addCleanup(client.drop_database, "errinfotest") - validator = {"x": {"$type": "string"}} - db.create_collection("test", validator=validator) - with self.assertRaises(WriteError) as ctx: - db.test.insert_one({"x": 1}) - self.assertEqual(ctx.exception.code, 121) - self.assertIsNotNone(ctx.exception.details) - assert ctx.exception.details is not None - self.assertIsNotNone(ctx.exception.details.get("errInfo")) - for event in listener.succeeded_events: - if event.command_name == "insert": - self.assertEqual(event.reply["writeErrors"][0], ctx.exception.details) - break - else: - self.fail("Couldn't find insert event.") - - -def normalize_write_concern(concern): - result = {} - for key in concern: - if key.lower() == "wtimeoutms": - result["wtimeout"] = concern[key] - elif key == "journal": - result["j"] = concern[key] - else: - result[key] = concern[key] - return result - - -def create_connection_string_test(test_case): - def run_test(self): - uri = test_case["uri"] - valid = test_case["valid"] - warning = test_case["warning"] - - if not valid: - if warning is False: - self.assertRaises((ConfigurationError, ValueError), MongoClient, uri, connect=False) - else: - with warnings.catch_warnings(): - warnings.simplefilter("error", UserWarning) - self.assertRaises(UserWarning, MongoClient, uri, connect=False) - else: - client = MongoClient(uri, connect=False) - if "writeConcern" in test_case: - document = client.write_concern.document - self.assertEqual(document, normalize_write_concern(test_case["writeConcern"])) - if "readConcern" in test_case: - document = client.read_concern.document - self.assertEqual(document, test_case["readConcern"]) - - return run_test - - -def create_document_test(test_case): - def run_test(self): - valid = test_case["valid"] - - if "writeConcern" in test_case: - normalized = normalize_write_concern(test_case["writeConcern"]) - if not valid: - self.assertRaises((ConfigurationError, ValueError), WriteConcern, **normalized) - else: - write_concern = WriteConcern(**normalized) - self.assertEqual(write_concern.document, test_case["writeConcernDocument"]) - self.assertEqual(write_concern.acknowledged, test_case["isAcknowledged"]) - self.assertEqual(write_concern.is_server_default, test_case["isServerDefault"]) - if "readConcern" in test_case: - # Any string for 'level' is equally valid - read_concern = ReadConcern(**test_case["readConcern"]) - self.assertEqual(read_concern.document, test_case["readConcernDocument"]) - self.assertEqual(not bool(read_concern.level), test_case["isServerDefault"]) - - return run_test - - -def create_tests(): - for dirpath, _, filenames in os.walk(TEST_PATH): - dirname = os.path.split(dirpath)[-1] - - if dirname == "operation": - # This directory is tested by TestOperations. - continue - elif dirname == "connection-string": - create_test = create_connection_string_test - else: - create_test = create_document_test - - for filename in filenames: - with open(os.path.join(dirpath, filename)) as test_stream: - test_cases = json.load(test_stream)["tests"] - - fname = os.path.splitext(filename)[0] - for test_case in test_cases: - new_test = create_test(test_case) - test_name = "test_{}_{}_{}".format( - dirname.replace("-", "_"), - fname.replace("-", "_"), - str(test_case["description"].lower().replace(" ", "_")), - ) - - new_test.__name__ = test_name - setattr(TestReadWriteConcernSpec, new_test.__name__, new_test) - - -create_tests() - - -# Generate unified tests. -# PyMongo does not support MapReduce. -globals().update( - generate_test_classes( - os.path.join(TEST_PATH, "operation"), - module=__name__, - expected_failures=["MapReduce .*"], - ) -) - - -if __name__ == "__main__": - unittest.main() diff --git a/uv.lock b/uv.lock index a2e951e76c..3189a88696 100644 --- a/uv.lock +++ b/uv.lock @@ -997,7 +997,6 @@ sdist = { url = "https://files.pythonhosted.org/packages/07/e9/ae44ea7d7605df9e5 [[package]] name = "pymongo" -version = "4.12.0.dev0" source = { editable = "." } dependencies = [ { name = "dnspython" }, From 4c06dec95c7a6742c447d473eb24e49cfd82a67a Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Wed, 12 Mar 2025 14:49:47 -0700 Subject: [PATCH 19/56] undoing unintended changes --- test/asynchronous/test_client.py | 4 + .../test_read_write_concern_spec.py | 344 ++++++++++++++++++ test/test_client.py | 4 + test/test_read_write_concern_spec.py | 340 +++++++++++++++++ uv.lock | 1 + 5 files changed, 693 insertions(+) create mode 100644 test/asynchronous/test_read_write_concern_spec.py create mode 100644 test/test_read_write_concern_spec.py diff --git a/test/asynchronous/test_client.py b/test/asynchronous/test_client.py index b46bfed77a..079b243658 100644 --- a/test/asynchronous/test_client.py +++ b/test/asynchronous/test_client.py @@ -897,6 +897,7 @@ async def test_repr(self): connect=False, document_class=SON, ) + the_repr = repr(client) self.assertIn("AsyncMongoClient(host=", the_repr) self.assertIn("document_class=bson.son.SON, tz_aware=False, connect=False, ", the_repr) @@ -904,6 +905,7 @@ async def test_repr(self): self.assertIn("replicaset='replset'", the_repr) self.assertIn("w=1", the_repr) self.assertIn("wtimeoutms=100", the_repr) + async with eval(the_repr) as client_two: self.assertEqual(client_two, client) @@ -1051,7 +1053,9 @@ async def test_close_stops_kill_cursors_thread(self): self.assertTrue(client._kill_cursors_executor._stopped) async def test_uri_connect_option(self): + # Ensure that topology is not opened if connect=False. client = await self.async_rs_client(connect=False) + self.assertFalse(client._topology._opened) # Ensure kill cursors thread has not been started. if _IS_SYNC: diff --git a/test/asynchronous/test_read_write_concern_spec.py b/test/asynchronous/test_read_write_concern_spec.py new file mode 100644 index 0000000000..3fb13ba194 --- /dev/null +++ b/test/asynchronous/test_read_write_concern_spec.py @@ -0,0 +1,344 @@ +# Copyright 2018-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Run the read and write concern tests.""" +from __future__ import annotations + +import json +import os +import sys +import warnings +from pathlib import Path + +sys.path[0:0] = [""] + +from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest +from test.asynchronous.unified_format import generate_test_classes +from test.utils import OvertCommandListener + +from pymongo import DESCENDING +from pymongo.asynchronous.mongo_client import AsyncMongoClient +from pymongo.errors import ( + BulkWriteError, + ConfigurationError, + WriteConcernError, + WriteError, + WTimeoutError, +) +from pymongo.operations import IndexModel, InsertOne +from pymongo.read_concern import ReadConcern +from pymongo.write_concern import WriteConcern + +_IS_SYNC = False + +# Location of JSON test specifications. +if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "read_write_concern") +else: + TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "read_write_concern") + + +class TestReadWriteConcernSpec(AsyncIntegrationTest): + async def test_omit_default_read_write_concern(self): + listener = OvertCommandListener() + # Client with default readConcern and writeConcern + client = await self.async_rs_or_single_client(event_listeners=[listener]) + collection = client.pymongo_test.collection + # Prepare for tests of find() and aggregate(). + await collection.insert_many([{} for _ in range(10)]) + self.addAsyncCleanup(collection.drop) + self.addAsyncCleanup(client.pymongo_test.collection2.drop) + # Commands MUST NOT send the default read/write concern to the server. + + async def rename_and_drop(): + # Ensure collection exists. + await collection.insert_one({}) + await collection.rename("collection2") + await client.pymongo_test.collection2.drop() + + async def insert_command_default_write_concern(): + await collection.database.command( + "insert", "collection", documents=[{}], write_concern=WriteConcern() + ) + + async def aggregate_op(): + await (await collection.aggregate([])).to_list() + + ops = [ + ("aggregate", aggregate_op), + ("find", lambda: collection.find().to_list()), + ("insert_one", lambda: collection.insert_one({})), + ("update_one", lambda: collection.update_one({}, {"$set": {"x": 1}})), + ("update_many", lambda: collection.update_many({}, {"$set": {"x": 1}})), + ("delete_one", lambda: collection.delete_one({})), + ("delete_many", lambda: collection.delete_many({})), + ("bulk_write", lambda: collection.bulk_write([InsertOne({})])), + ("rename_and_drop", rename_and_drop), + ("command", insert_command_default_write_concern), + ] + + for name, f in ops: + listener.reset() + await f() + + self.assertGreaterEqual(len(listener.started_events), 1) + for _i, event in enumerate(listener.started_events): + self.assertNotIn( + "readConcern", + event.command, + f"{name} sent default readConcern with {event.command_name}", + ) + self.assertNotIn( + "writeConcern", + event.command, + f"{name} sent default writeConcern with {event.command_name}", + ) + + async def assertWriteOpsRaise(self, write_concern, expected_exception): + wc = write_concern.document + # Set socket timeout to avoid indefinite stalls + client = await self.async_rs_or_single_client( + w=wc["w"], wTimeoutMS=wc["wtimeout"], socketTimeoutMS=30000 + ) + db = client.get_database("pymongo_test") + coll = db.test + + async def insert_command(): + await coll.database.command( + "insert", + "new_collection", + documents=[{}], + writeConcern=write_concern.document, + parse_write_concern_error=True, + ) + + ops = [ + ("insert_one", lambda: coll.insert_one({})), + ("insert_many", lambda: coll.insert_many([{}, {}])), + ("update_one", lambda: coll.update_one({}, {"$set": {"x": 1}})), + ("update_many", lambda: coll.update_many({}, {"$set": {"x": 1}})), + ("delete_one", lambda: coll.delete_one({})), + ("delete_many", lambda: coll.delete_many({})), + ("bulk_write", lambda: coll.bulk_write([InsertOne({})])), + ("command", insert_command), + ("aggregate", lambda: coll.aggregate([{"$out": "out"}])), + # SERVER-46668 Delete all the documents in the collection to + # workaround a hang in createIndexes. + ("delete_many", lambda: coll.delete_many({})), + ("create_index", lambda: coll.create_index([("a", DESCENDING)])), + ("create_indexes", lambda: coll.create_indexes([IndexModel("b")])), + ("drop_index", lambda: coll.drop_index([("a", DESCENDING)])), + ("create", lambda: db.create_collection("new")), + ("rename", lambda: coll.rename("new")), + ("drop", lambda: db.new.drop()), + ] + # SERVER-47194: dropDatabase does not respect wtimeout in 3.6. + if async_client_context.version[:2] != (3, 6): + ops.append(("drop_database", lambda: client.drop_database(db))) + + for name, f in ops: + # Ensure insert_many and bulk_write still raise BulkWriteError. + if name in ("insert_many", "bulk_write"): + expected = BulkWriteError + else: + expected = expected_exception + with self.assertRaises(expected, msg=name) as cm: + await f() + if expected == BulkWriteError: + bulk_result = cm.exception.details + assert bulk_result is not None + wc_errors = bulk_result["writeConcernErrors"] + self.assertTrue(wc_errors) + + @async_client_context.require_replica_set + async def test_raise_write_concern_error(self): + self.addAsyncCleanup(async_client_context.client.drop_database, "pymongo_test") + assert async_client_context.w is not None + await self.assertWriteOpsRaise( + WriteConcern(w=async_client_context.w + 1, wtimeout=1), WriteConcernError + ) + + @async_client_context.require_secondaries_count(1) + @async_client_context.require_test_commands + async def test_raise_wtimeout(self): + self.addAsyncCleanup(async_client_context.client.drop_database, "pymongo_test") + self.addAsyncCleanup(self.enable_replication, async_client_context.client) + # Disable replication to guarantee a wtimeout error. + await self.disable_replication(async_client_context.client) + await self.assertWriteOpsRaise( + WriteConcern(w=async_client_context.w, wtimeout=1), WTimeoutError + ) + + @async_client_context.require_failCommand_fail_point + async def test_error_includes_errInfo(self): + expected_wce = { + "code": 100, + "codeName": "UnsatisfiableWriteConcern", + "errmsg": "Not enough data-bearing nodes", + "errInfo": {"writeConcern": {"w": 2, "wtimeout": 0, "provenance": "clientSupplied"}}, + } + cause_wce = { + "configureFailPoint": "failCommand", + "mode": {"times": 2}, + "data": {"failCommands": ["insert"], "writeConcernError": expected_wce}, + } + async with self.fail_point(cause_wce): + # Write concern error on insert includes errInfo. + with self.assertRaises(WriteConcernError) as ctx: + await self.db.test.insert_one({}) + self.assertEqual(ctx.exception.details, expected_wce) + + # Test bulk_write as well. + with self.assertRaises(BulkWriteError) as ctx: + await self.db.test.bulk_write([InsertOne({})]) + expected_details = { + "writeErrors": [], + "writeConcernErrors": [expected_wce], + "nInserted": 1, + "nUpserted": 0, + "nMatched": 0, + "nModified": 0, + "nRemoved": 0, + "upserted": [], + } + self.assertEqual(ctx.exception.details, expected_details) + + @async_client_context.require_version_min(4, 9) + async def test_write_error_details_exposes_errinfo(self): + listener = OvertCommandListener() + client = await self.async_rs_or_single_client(event_listeners=[listener]) + db = client.errinfotest + self.addAsyncCleanup(client.drop_database, "errinfotest") + validator = {"x": {"$type": "string"}} + await db.create_collection("test", validator=validator) + with self.assertRaises(WriteError) as ctx: + await db.test.insert_one({"x": 1}) + self.assertEqual(ctx.exception.code, 121) + self.assertIsNotNone(ctx.exception.details) + assert ctx.exception.details is not None + self.assertIsNotNone(ctx.exception.details.get("errInfo")) + for event in listener.succeeded_events: + if event.command_name == "insert": + self.assertEqual(event.reply["writeErrors"][0], ctx.exception.details) + break + else: + self.fail("Couldn't find insert event.") + + +def normalize_write_concern(concern): + result = {} + for key in concern: + if key.lower() == "wtimeoutms": + result["wtimeout"] = concern[key] + elif key == "journal": + result["j"] = concern[key] + else: + result[key] = concern[key] + return result + + +def create_connection_string_test(test_case): + def run_test(self): + uri = test_case["uri"] + valid = test_case["valid"] + warning = test_case["warning"] + + if not valid: + if warning is False: + self.assertRaises( + (ConfigurationError, ValueError), AsyncMongoClient, uri, connect=False + ) + else: + with warnings.catch_warnings(): + warnings.simplefilter("error", UserWarning) + self.assertRaises(UserWarning, AsyncMongoClient, uri, connect=False) + else: + client = AsyncMongoClient(uri, connect=False) + if "writeConcern" in test_case: + document = client.write_concern.document + self.assertEqual(document, normalize_write_concern(test_case["writeConcern"])) + if "readConcern" in test_case: + document = client.read_concern.document + self.assertEqual(document, test_case["readConcern"]) + + return run_test + + +def create_document_test(test_case): + def run_test(self): + valid = test_case["valid"] + + if "writeConcern" in test_case: + normalized = normalize_write_concern(test_case["writeConcern"]) + if not valid: + self.assertRaises((ConfigurationError, ValueError), WriteConcern, **normalized) + else: + write_concern = WriteConcern(**normalized) + self.assertEqual(write_concern.document, test_case["writeConcernDocument"]) + self.assertEqual(write_concern.acknowledged, test_case["isAcknowledged"]) + self.assertEqual(write_concern.is_server_default, test_case["isServerDefault"]) + if "readConcern" in test_case: + # Any string for 'level' is equally valid + read_concern = ReadConcern(**test_case["readConcern"]) + self.assertEqual(read_concern.document, test_case["readConcernDocument"]) + self.assertEqual(not bool(read_concern.level), test_case["isServerDefault"]) + + return run_test + + +def create_tests(): + for dirpath, _, filenames in os.walk(TEST_PATH): + dirname = os.path.split(dirpath)[-1] + + if dirname == "operation": + # This directory is tested by TestOperations. + continue + elif dirname == "connection-string": + create_test = create_connection_string_test + else: + create_test = create_document_test + + for filename in filenames: + with open(os.path.join(dirpath, filename)) as test_stream: + test_cases = json.load(test_stream)["tests"] + + fname = os.path.splitext(filename)[0] + for test_case in test_cases: + new_test = create_test(test_case) + test_name = "test_{}_{}_{}".format( + dirname.replace("-", "_"), + fname.replace("-", "_"), + str(test_case["description"].lower().replace(" ", "_")), + ) + + new_test.__name__ = test_name + setattr(TestReadWriteConcernSpec, new_test.__name__, new_test) + + +create_tests() + + +# Generate unified tests. +# PyMongo does not support MapReduce. +globals().update( + generate_test_classes( + os.path.join(TEST_PATH, "operation"), + module=__name__, + expected_failures=["MapReduce .*"], + ) +) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_client.py b/test/test_client.py index 487704b306..cf50347ab5 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -872,6 +872,7 @@ def test_repr(self): connect=False, document_class=SON, ) + the_repr = repr(client) self.assertIn("MongoClient(host=", the_repr) self.assertIn("document_class=bson.son.SON, tz_aware=False, connect=False, ", the_repr) @@ -879,6 +880,7 @@ def test_repr(self): self.assertIn("replicaset='replset'", the_repr) self.assertIn("w=1", the_repr) self.assertIn("wtimeoutms=100", the_repr) + with eval(the_repr) as client_two: self.assertEqual(client_two, client) @@ -1024,7 +1026,9 @@ def test_close_stops_kill_cursors_thread(self): self.assertTrue(client._kill_cursors_executor._stopped) def test_uri_connect_option(self): + # Ensure that topology is not opened if connect=False. client = self.rs_client(connect=False) + self.assertFalse(client._topology._opened) # Ensure kill cursors thread has not been started. if _IS_SYNC: diff --git a/test/test_read_write_concern_spec.py b/test/test_read_write_concern_spec.py new file mode 100644 index 0000000000..8543991f72 --- /dev/null +++ b/test/test_read_write_concern_spec.py @@ -0,0 +1,340 @@ +# Copyright 2018-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Run the read and write concern tests.""" +from __future__ import annotations + +import json +import os +import sys +import warnings +from pathlib import Path + +sys.path[0:0] = [""] + +from test import IntegrationTest, client_context, unittest +from test.unified_format import generate_test_classes +from test.utils import OvertCommandListener + +from pymongo import DESCENDING +from pymongo.errors import ( + BulkWriteError, + ConfigurationError, + WriteConcernError, + WriteError, + WTimeoutError, +) +from pymongo.operations import IndexModel, InsertOne +from pymongo.read_concern import ReadConcern +from pymongo.synchronous.mongo_client import MongoClient +from pymongo.write_concern import WriteConcern + +_IS_SYNC = True + +# Location of JSON test specifications. +if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "read_write_concern") +else: + TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "read_write_concern") + + +class TestReadWriteConcernSpec(IntegrationTest): + def test_omit_default_read_write_concern(self): + listener = OvertCommandListener() + # Client with default readConcern and writeConcern + client = self.rs_or_single_client(event_listeners=[listener]) + collection = client.pymongo_test.collection + # Prepare for tests of find() and aggregate(). + collection.insert_many([{} for _ in range(10)]) + self.addCleanup(collection.drop) + self.addCleanup(client.pymongo_test.collection2.drop) + # Commands MUST NOT send the default read/write concern to the server. + + def rename_and_drop(): + # Ensure collection exists. + collection.insert_one({}) + collection.rename("collection2") + client.pymongo_test.collection2.drop() + + def insert_command_default_write_concern(): + collection.database.command( + "insert", "collection", documents=[{}], write_concern=WriteConcern() + ) + + def aggregate_op(): + (collection.aggregate([])).to_list() + + ops = [ + ("aggregate", aggregate_op), + ("find", lambda: collection.find().to_list()), + ("insert_one", lambda: collection.insert_one({})), + ("update_one", lambda: collection.update_one({}, {"$set": {"x": 1}})), + ("update_many", lambda: collection.update_many({}, {"$set": {"x": 1}})), + ("delete_one", lambda: collection.delete_one({})), + ("delete_many", lambda: collection.delete_many({})), + ("bulk_write", lambda: collection.bulk_write([InsertOne({})])), + ("rename_and_drop", rename_and_drop), + ("command", insert_command_default_write_concern), + ] + + for name, f in ops: + listener.reset() + f() + + self.assertGreaterEqual(len(listener.started_events), 1) + for _i, event in enumerate(listener.started_events): + self.assertNotIn( + "readConcern", + event.command, + f"{name} sent default readConcern with {event.command_name}", + ) + self.assertNotIn( + "writeConcern", + event.command, + f"{name} sent default writeConcern with {event.command_name}", + ) + + def assertWriteOpsRaise(self, write_concern, expected_exception): + wc = write_concern.document + # Set socket timeout to avoid indefinite stalls + client = self.rs_or_single_client( + w=wc["w"], wTimeoutMS=wc["wtimeout"], socketTimeoutMS=30000 + ) + db = client.get_database("pymongo_test") + coll = db.test + + def insert_command(): + coll.database.command( + "insert", + "new_collection", + documents=[{}], + writeConcern=write_concern.document, + parse_write_concern_error=True, + ) + + ops = [ + ("insert_one", lambda: coll.insert_one({})), + ("insert_many", lambda: coll.insert_many([{}, {}])), + ("update_one", lambda: coll.update_one({}, {"$set": {"x": 1}})), + ("update_many", lambda: coll.update_many({}, {"$set": {"x": 1}})), + ("delete_one", lambda: coll.delete_one({})), + ("delete_many", lambda: coll.delete_many({})), + ("bulk_write", lambda: coll.bulk_write([InsertOne({})])), + ("command", insert_command), + ("aggregate", lambda: coll.aggregate([{"$out": "out"}])), + # SERVER-46668 Delete all the documents in the collection to + # workaround a hang in createIndexes. + ("delete_many", lambda: coll.delete_many({})), + ("create_index", lambda: coll.create_index([("a", DESCENDING)])), + ("create_indexes", lambda: coll.create_indexes([IndexModel("b")])), + ("drop_index", lambda: coll.drop_index([("a", DESCENDING)])), + ("create", lambda: db.create_collection("new")), + ("rename", lambda: coll.rename("new")), + ("drop", lambda: db.new.drop()), + ] + # SERVER-47194: dropDatabase does not respect wtimeout in 3.6. + if client_context.version[:2] != (3, 6): + ops.append(("drop_database", lambda: client.drop_database(db))) + + for name, f in ops: + # Ensure insert_many and bulk_write still raise BulkWriteError. + if name in ("insert_many", "bulk_write"): + expected = BulkWriteError + else: + expected = expected_exception + with self.assertRaises(expected, msg=name) as cm: + f() + if expected == BulkWriteError: + bulk_result = cm.exception.details + assert bulk_result is not None + wc_errors = bulk_result["writeConcernErrors"] + self.assertTrue(wc_errors) + + @client_context.require_replica_set + def test_raise_write_concern_error(self): + self.addCleanup(client_context.client.drop_database, "pymongo_test") + assert client_context.w is not None + self.assertWriteOpsRaise( + WriteConcern(w=client_context.w + 1, wtimeout=1), WriteConcernError + ) + + @client_context.require_secondaries_count(1) + @client_context.require_test_commands + def test_raise_wtimeout(self): + self.addCleanup(client_context.client.drop_database, "pymongo_test") + self.addCleanup(self.enable_replication, client_context.client) + # Disable replication to guarantee a wtimeout error. + self.disable_replication(client_context.client) + self.assertWriteOpsRaise(WriteConcern(w=client_context.w, wtimeout=1), WTimeoutError) + + @client_context.require_failCommand_fail_point + def test_error_includes_errInfo(self): + expected_wce = { + "code": 100, + "codeName": "UnsatisfiableWriteConcern", + "errmsg": "Not enough data-bearing nodes", + "errInfo": {"writeConcern": {"w": 2, "wtimeout": 0, "provenance": "clientSupplied"}}, + } + cause_wce = { + "configureFailPoint": "failCommand", + "mode": {"times": 2}, + "data": {"failCommands": ["insert"], "writeConcernError": expected_wce}, + } + with self.fail_point(cause_wce): + # Write concern error on insert includes errInfo. + with self.assertRaises(WriteConcernError) as ctx: + self.db.test.insert_one({}) + self.assertEqual(ctx.exception.details, expected_wce) + + # Test bulk_write as well. + with self.assertRaises(BulkWriteError) as ctx: + self.db.test.bulk_write([InsertOne({})]) + expected_details = { + "writeErrors": [], + "writeConcernErrors": [expected_wce], + "nInserted": 1, + "nUpserted": 0, + "nMatched": 0, + "nModified": 0, + "nRemoved": 0, + "upserted": [], + } + self.assertEqual(ctx.exception.details, expected_details) + + @client_context.require_version_min(4, 9) + def test_write_error_details_exposes_errinfo(self): + listener = OvertCommandListener() + client = self.rs_or_single_client(event_listeners=[listener]) + db = client.errinfotest + self.addCleanup(client.drop_database, "errinfotest") + validator = {"x": {"$type": "string"}} + db.create_collection("test", validator=validator) + with self.assertRaises(WriteError) as ctx: + db.test.insert_one({"x": 1}) + self.assertEqual(ctx.exception.code, 121) + self.assertIsNotNone(ctx.exception.details) + assert ctx.exception.details is not None + self.assertIsNotNone(ctx.exception.details.get("errInfo")) + for event in listener.succeeded_events: + if event.command_name == "insert": + self.assertEqual(event.reply["writeErrors"][0], ctx.exception.details) + break + else: + self.fail("Couldn't find insert event.") + + +def normalize_write_concern(concern): + result = {} + for key in concern: + if key.lower() == "wtimeoutms": + result["wtimeout"] = concern[key] + elif key == "journal": + result["j"] = concern[key] + else: + result[key] = concern[key] + return result + + +def create_connection_string_test(test_case): + def run_test(self): + uri = test_case["uri"] + valid = test_case["valid"] + warning = test_case["warning"] + + if not valid: + if warning is False: + self.assertRaises((ConfigurationError, ValueError), MongoClient, uri, connect=False) + else: + with warnings.catch_warnings(): + warnings.simplefilter("error", UserWarning) + self.assertRaises(UserWarning, MongoClient, uri, connect=False) + else: + client = MongoClient(uri, connect=False) + if "writeConcern" in test_case: + document = client.write_concern.document + self.assertEqual(document, normalize_write_concern(test_case["writeConcern"])) + if "readConcern" in test_case: + document = client.read_concern.document + self.assertEqual(document, test_case["readConcern"]) + + return run_test + + +def create_document_test(test_case): + def run_test(self): + valid = test_case["valid"] + + if "writeConcern" in test_case: + normalized = normalize_write_concern(test_case["writeConcern"]) + if not valid: + self.assertRaises((ConfigurationError, ValueError), WriteConcern, **normalized) + else: + write_concern = WriteConcern(**normalized) + self.assertEqual(write_concern.document, test_case["writeConcernDocument"]) + self.assertEqual(write_concern.acknowledged, test_case["isAcknowledged"]) + self.assertEqual(write_concern.is_server_default, test_case["isServerDefault"]) + if "readConcern" in test_case: + # Any string for 'level' is equally valid + read_concern = ReadConcern(**test_case["readConcern"]) + self.assertEqual(read_concern.document, test_case["readConcernDocument"]) + self.assertEqual(not bool(read_concern.level), test_case["isServerDefault"]) + + return run_test + + +def create_tests(): + for dirpath, _, filenames in os.walk(TEST_PATH): + dirname = os.path.split(dirpath)[-1] + + if dirname == "operation": + # This directory is tested by TestOperations. + continue + elif dirname == "connection-string": + create_test = create_connection_string_test + else: + create_test = create_document_test + + for filename in filenames: + with open(os.path.join(dirpath, filename)) as test_stream: + test_cases = json.load(test_stream)["tests"] + + fname = os.path.splitext(filename)[0] + for test_case in test_cases: + new_test = create_test(test_case) + test_name = "test_{}_{}_{}".format( + dirname.replace("-", "_"), + fname.replace("-", "_"), + str(test_case["description"].lower().replace(" ", "_")), + ) + + new_test.__name__ = test_name + setattr(TestReadWriteConcernSpec, new_test.__name__, new_test) + + +create_tests() + + +# Generate unified tests. +# PyMongo does not support MapReduce. +globals().update( + generate_test_classes( + os.path.join(TEST_PATH, "operation"), + module=__name__, + expected_failures=["MapReduce .*"], + ) +) + + +if __name__ == "__main__": + unittest.main() diff --git a/uv.lock b/uv.lock index 3189a88696..a2e951e76c 100644 --- a/uv.lock +++ b/uv.lock @@ -997,6 +997,7 @@ sdist = { url = "https://files.pythonhosted.org/packages/07/e9/ae44ea7d7605df9e5 [[package]] name = "pymongo" +version = "4.12.0.dev0" source = { editable = "." } dependencies = [ { name = "dnspython" }, From 511fcc47d1db2a15fcef1f699fbe54566f1cf825 Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Wed, 12 Mar 2025 14:52:50 -0700 Subject: [PATCH 20/56] bringing back a previously deleted test --- test/asynchronous/test_client.py | 9 ++++++++- test/test_client.py | 9 ++++++++- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/test/asynchronous/test_client.py b/test/asynchronous/test_client.py index 079b243658..02c676b1b4 100644 --- a/test/asynchronous/test_client.py +++ b/test/asynchronous/test_client.py @@ -1074,6 +1074,13 @@ async def test_uri_connect_option(self): kc_task = client._kill_cursors_executor._task self.assertTrue(kc_task and not kc_task.done()) + async def test_close_does_not_open_servers(self): + client = await self.async_rs_client(connect=False) + topology = client._topology + self.assertEqual(topology._servers, {}) + await client.close() + self.assertEqual(topology._servers, {}) + async def test_close_closes_sockets(self): client = await self.async_rs_client() await client.test.test.find_one() @@ -1891,7 +1898,7 @@ async def test_service_name_from_kwargs(self): client = AsyncMongoClient( "mongodb+srv://user:password@test22.test.build.10gen.cc", srvServiceName="customname", - connect=False, + connect=True, ) await client.aconnect() self.assertEqual(client._topology_settings.srv_service_name, "customname") diff --git a/test/test_client.py b/test/test_client.py index cf50347ab5..3ebef5048d 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -1047,6 +1047,13 @@ def test_uri_connect_option(self): kc_task = client._kill_cursors_executor._task self.assertTrue(kc_task and not kc_task.done()) + def test_close_does_not_open_servers(self): + client = self.rs_client(connect=False) + topology = client._topology + self.assertEqual(topology._servers, {}) + client.close() + self.assertEqual(topology._servers, {}) + def test_close_closes_sockets(self): client = self.rs_client() client.test.test.find_one() @@ -1848,7 +1855,7 @@ def test_service_name_from_kwargs(self): client = MongoClient( "mongodb+srv://user:password@test22.test.build.10gen.cc", srvServiceName="customname", - connect=False, + connect=True, ) client._connect() self.assertEqual(client._topology_settings.srv_service_name, "customname") From 40509a1f9aa4a02f7b212f1669c2a31eb5e8849b Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Wed, 12 Mar 2025 14:53:48 -0700 Subject: [PATCH 21/56] undoing unintended changes --- test/asynchronous/test_client.py | 2 +- test/test_client.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/asynchronous/test_client.py b/test/asynchronous/test_client.py index 02c676b1b4..0a9ebcd860 100644 --- a/test/asynchronous/test_client.py +++ b/test/asynchronous/test_client.py @@ -1898,7 +1898,7 @@ async def test_service_name_from_kwargs(self): client = AsyncMongoClient( "mongodb+srv://user:password@test22.test.build.10gen.cc", srvServiceName="customname", - connect=True, + connect=False, ) await client.aconnect() self.assertEqual(client._topology_settings.srv_service_name, "customname") diff --git a/test/test_client.py b/test/test_client.py index 3ebef5048d..a4ebab22a1 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -1855,7 +1855,7 @@ def test_service_name_from_kwargs(self): client = MongoClient( "mongodb+srv://user:password@test22.test.build.10gen.cc", srvServiceName="customname", - connect=True, + connect=False, ) client._connect() self.assertEqual(client._topology_settings.srv_service_name, "customname") From 97e0778b78210844e87223bc92d8a3c8729362d4 Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Wed, 12 Mar 2025 15:21:53 -0700 Subject: [PATCH 22/56] some refactoring --- pymongo/asynchronous/mongo_client.py | 108 ++++++++++----------------- pymongo/synchronous/mongo_client.py | 108 ++++++++++----------------- 2 files changed, 80 insertions(+), 136 deletions(-) diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index 046882aeb2..1c17b125ab 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -830,48 +830,13 @@ def __init__( # Username and password passed as kwargs override user info in URI. username = opts.get("username", username) password = opts.get("password", password) - self._options = options = ClientOptions(username, password, dbase, opts, _IS_SYNC) + self._options = ClientOptions(username, password, dbase, opts, _IS_SYNC) self._default_database_name = dbase self._lock = _async_create_lock() self._kill_cursors_queue: list = [] - self._event_listeners = options.pool_options._event_listeners - super().__init__( - options.codec_options, - options.read_preference, - options.write_concern, - options.read_concern, - ) self._encrypter: Optional[_Encrypter] = None - if not is_srv: - self._topology_settings = TopologySettings( - seeds=seeds, - replica_set_name=options.replica_set_name, - pool_class=pool_class, - pool_options=options.pool_options, - monitor_class=monitor_class, - condition_class=condition_class, - local_threshold_ms=options.local_threshold_ms, - server_selection_timeout=options.server_selection_timeout, - server_selector=options.server_selector, - heartbeat_frequency=options.heartbeat_frequency, - fqdn=fqdn, - direct_connection=options.direct_connection, - load_balanced=options.load_balanced, - srv_service_name=srv_service_name, - srv_max_hosts=srv_max_hosts, - server_monitoring_mode=options.server_monitoring_mode, - ) - if self._options.auto_encryption_opts: - from pymongo.asynchronous.encryption import _Encrypter - - self._encrypter = _Encrypter(self, self._options.auto_encryption_opts) - - self._opened = False - self._closed = False - if not is_srv: - self._init_background(first=True) self._resolve_srv_info.update( { @@ -885,11 +850,17 @@ def __init__( "condition_class": condition_class, } ) + if not is_srv: + self._init_based_on_options(seeds, srv_max_hosts, srv_service_name) + + self._opened = False + self._closed = False + if not is_srv: + self._init_background(first=True) + if _IS_SYNC and connect: self._get_topology() # type: ignore[unused-coroutine] - self._timeout = self._options.timeout - def _resolve_srv(self) -> None: keyword_opts = self._resolve_srv_info["keyword_opts"] seeds = set() @@ -957,38 +928,39 @@ def _resolve_srv(self) -> None: username, password, self._resolve_srv_info["dbase"], opts, _IS_SYNC ) - self._event_listeners = self._options.pool_options._event_listeners - super().__init__( - self._options.codec_options, - self._options.read_preference, - self._options.write_concern, - self._options.read_concern, - ) + self._init_based_on_options(seeds, srv_max_hosts, srv_service_name) - self._topology_settings = TopologySettings( - seeds=seeds, - replica_set_name=self._options.replica_set_name, - pool_class=self._resolve_srv_info["pool_class"], - pool_options=self._options.pool_options, - monitor_class=self._resolve_srv_info["monitor_class"], - condition_class=self._resolve_srv_info["condition_class"], - local_threshold_ms=self._options.local_threshold_ms, - server_selection_timeout=self._options.server_selection_timeout, - server_selector=self._options.server_selector, - heartbeat_frequency=self._options.heartbeat_frequency, - fqdn=self._resolve_srv_info["fqdn"], - direct_connection=self._options.direct_connection, - load_balanced=self._options.load_balanced, - srv_service_name=srv_service_name, - srv_max_hosts=srv_max_hosts, - server_monitoring_mode=self._options.server_monitoring_mode, - ) - - if self._options.auto_encryption_opts: - from pymongo.asynchronous.encryption import _Encrypter + def _init_based_on_options(self, seeds, srv_max_hosts, srv_service_name): + self._event_listeners = self._options.pool_options._event_listeners + super().__init__( + self._options.codec_options, + self._options.read_preference, + self._options.write_concern, + self._options.read_concern, + ) + self._topology_settings = TopologySettings( + seeds=seeds, + replica_set_name=self._options.replica_set_name, + pool_class=self._resolve_srv_info["pool_class"], + pool_options=self._options.pool_options, + monitor_class=self._resolve_srv_info["monitor_class"], + condition_class=self._resolve_srv_info["condition_class"], + local_threshold_ms=self._options.local_threshold_ms, + server_selection_timeout=self._options.server_selection_timeout, + server_selector=self._options.server_selector, + heartbeat_frequency=self._options.heartbeat_frequency, + fqdn=self._resolve_srv_info["fqdn"], + direct_connection=self._options.direct_connection, + load_balanced=self._options.load_balanced, + srv_service_name=srv_service_name, + srv_max_hosts=srv_max_hosts, + server_monitoring_mode=self._options.server_monitoring_mode, + ) + if self._options.auto_encryption_opts: + from pymongo.asynchronous.encryption import _Encrypter - self._encrypter = _Encrypter(self, self._options.auto_encryption_opts) - self._timeout = self._options.timeout + self._encrypter = _Encrypter(self, self._options.auto_encryption_opts) + self._timeout = self._options.timeout def _normalize_and_validate_options( self, opts: common._CaseInsensitiveDictionary, seeds: set[tuple[str, int | None]] diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 3e1057bcfd..4a6427eea0 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -828,48 +828,13 @@ def __init__( # Username and password passed as kwargs override user info in URI. username = opts.get("username", username) password = opts.get("password", password) - self._options = options = ClientOptions(username, password, dbase, opts, _IS_SYNC) + self._options = ClientOptions(username, password, dbase, opts, _IS_SYNC) self._default_database_name = dbase self._lock = _create_lock() self._kill_cursors_queue: list = [] - self._event_listeners = options.pool_options._event_listeners - super().__init__( - options.codec_options, - options.read_preference, - options.write_concern, - options.read_concern, - ) self._encrypter: Optional[_Encrypter] = None - if not is_srv: - self._topology_settings = TopologySettings( - seeds=seeds, - replica_set_name=options.replica_set_name, - pool_class=pool_class, - pool_options=options.pool_options, - monitor_class=monitor_class, - condition_class=condition_class, - local_threshold_ms=options.local_threshold_ms, - server_selection_timeout=options.server_selection_timeout, - server_selector=options.server_selector, - heartbeat_frequency=options.heartbeat_frequency, - fqdn=fqdn, - direct_connection=options.direct_connection, - load_balanced=options.load_balanced, - srv_service_name=srv_service_name, - srv_max_hosts=srv_max_hosts, - server_monitoring_mode=options.server_monitoring_mode, - ) - if self._options.auto_encryption_opts: - from pymongo.synchronous.encryption import _Encrypter - - self._encrypter = _Encrypter(self, self._options.auto_encryption_opts) - - self._opened = False - self._closed = False - if not is_srv: - self._init_background(first=True) self._resolve_srv_info.update( { @@ -883,11 +848,17 @@ def __init__( "condition_class": condition_class, } ) + if not is_srv: + self._init_based_on_options(seeds, srv_max_hosts, srv_service_name) + + self._opened = False + self._closed = False + if not is_srv: + self._init_background(first=True) + if _IS_SYNC and connect: self._get_topology() # type: ignore[unused-coroutine] - self._timeout = self._options.timeout - def _resolve_srv(self) -> None: keyword_opts = self._resolve_srv_info["keyword_opts"] seeds = set() @@ -955,38 +926,39 @@ def _resolve_srv(self) -> None: username, password, self._resolve_srv_info["dbase"], opts, _IS_SYNC ) - self._event_listeners = self._options.pool_options._event_listeners - super().__init__( - self._options.codec_options, - self._options.read_preference, - self._options.write_concern, - self._options.read_concern, - ) + self._init_based_on_options(seeds, srv_max_hosts, srv_service_name) - self._topology_settings = TopologySettings( - seeds=seeds, - replica_set_name=self._options.replica_set_name, - pool_class=self._resolve_srv_info["pool_class"], - pool_options=self._options.pool_options, - monitor_class=self._resolve_srv_info["monitor_class"], - condition_class=self._resolve_srv_info["condition_class"], - local_threshold_ms=self._options.local_threshold_ms, - server_selection_timeout=self._options.server_selection_timeout, - server_selector=self._options.server_selector, - heartbeat_frequency=self._options.heartbeat_frequency, - fqdn=self._resolve_srv_info["fqdn"], - direct_connection=self._options.direct_connection, - load_balanced=self._options.load_balanced, - srv_service_name=srv_service_name, - srv_max_hosts=srv_max_hosts, - server_monitoring_mode=self._options.server_monitoring_mode, - ) - - if self._options.auto_encryption_opts: - from pymongo.synchronous.encryption import _Encrypter + def _init_based_on_options(self, seeds, srv_max_hosts, srv_service_name): + self._event_listeners = self._options.pool_options._event_listeners + super().__init__( + self._options.codec_options, + self._options.read_preference, + self._options.write_concern, + self._options.read_concern, + ) + self._topology_settings = TopologySettings( + seeds=seeds, + replica_set_name=self._options.replica_set_name, + pool_class=self._resolve_srv_info["pool_class"], + pool_options=self._options.pool_options, + monitor_class=self._resolve_srv_info["monitor_class"], + condition_class=self._resolve_srv_info["condition_class"], + local_threshold_ms=self._options.local_threshold_ms, + server_selection_timeout=self._options.server_selection_timeout, + server_selector=self._options.server_selector, + heartbeat_frequency=self._options.heartbeat_frequency, + fqdn=self._resolve_srv_info["fqdn"], + direct_connection=self._options.direct_connection, + load_balanced=self._options.load_balanced, + srv_service_name=srv_service_name, + srv_max_hosts=srv_max_hosts, + server_monitoring_mode=self._options.server_monitoring_mode, + ) + if self._options.auto_encryption_opts: + from pymongo.synchronous.encryption import _Encrypter - self._encrypter = _Encrypter(self, self._options.auto_encryption_opts) - self._timeout = self._options.timeout + self._encrypter = _Encrypter(self, self._options.auto_encryption_opts) + self._timeout = self._options.timeout def _normalize_and_validate_options( self, opts: common._CaseInsensitiveDictionary, seeds: set[tuple[str, int | None]] From 2653a56461aacdd99921498581cdb8c271b8a676 Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Wed, 12 Mar 2025 16:21:02 -0700 Subject: [PATCH 23/56] fix typing --- pymongo/asynchronous/mongo_client.py | 5 ++++- pymongo/synchronous/mongo_client.py | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index 1c17b125ab..b5e8df7103 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -44,6 +44,7 @@ AsyncContextManager, AsyncGenerator, Callable, + Collection, Coroutine, FrozenSet, Generic, @@ -930,7 +931,9 @@ def _resolve_srv(self) -> None: self._init_based_on_options(seeds, srv_max_hosts, srv_service_name) - def _init_based_on_options(self, seeds, srv_max_hosts, srv_service_name): + def _init_based_on_options( + self, seeds: Collection[tuple[str, int]], srv_max_hosts: Any, srv_service_name: Any + ) -> None: self._event_listeners = self._options.pool_options._event_listeners super().__init__( self._options.codec_options, diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 4a6427eea0..5308504882 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -42,6 +42,7 @@ TYPE_CHECKING, Any, Callable, + Collection, ContextManager, FrozenSet, Generator, @@ -928,7 +929,9 @@ def _resolve_srv(self) -> None: self._init_based_on_options(seeds, srv_max_hosts, srv_service_name) - def _init_based_on_options(self, seeds, srv_max_hosts, srv_service_name): + def _init_based_on_options( + self, seeds: Collection[tuple[str, int]], srv_max_hosts: Any, srv_service_name: Any + ) -> None: self._event_listeners = self._options.pool_options._event_listeners super().__init__( self._options.codec_options, From 4c23ee038a36ad0ed57ef4ce491d38ce527b460b Mon Sep 17 00:00:00 2001 From: Iris <58442094+sleepyStick@users.noreply.github.com> Date: Thu, 13 Mar 2025 10:07:54 -0700 Subject: [PATCH 24/56] Update pymongo/asynchronous/mongo_client.py Co-authored-by: Noah Stapp --- pymongo/asynchronous/mongo_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index b5e8df7103..c84649d0c2 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -1219,7 +1219,7 @@ def __eq__(self, other: Any) -> bool: if hasattr(self, "_topology"): return self._topology == other._topology else: - raise InvalidOperation("Cannot perform operation until client is connected") + raise InvalidOperation("Cannot compare client equality until both clients are connected") return NotImplemented def __ne__(self, other: Any) -> bool: From bc6119980cff10b9d88c180f61c801264562672e Mon Sep 17 00:00:00 2001 From: Iris <58442094+sleepyStick@users.noreply.github.com> Date: Thu, 13 Mar 2025 10:08:16 -0700 Subject: [PATCH 25/56] Update pymongo/asynchronous/mongo_client.py Co-authored-by: Noah Stapp --- pymongo/asynchronous/mongo_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index c84649d0c2..15053ae527 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -1229,7 +1229,7 @@ def __hash__(self) -> int: if hasattr(self, "_topology"): return hash(self._topology) else: - raise InvalidOperation("Cannot perform operation until client is connected") + raise InvalidOperation("Cannot hash client until it is connected") def _repr_helper(self) -> str: def option_repr(option: str, value: Any) -> str: From 8c2b36852ff39195f2ea9baca7990be09469c11a Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Thu, 13 Mar 2025 11:20:30 -0700 Subject: [PATCH 26/56] respond to comments and move srv_resolver to async --- pymongo/asynchronous/mongo_client.py | 25 ++-- pymongo/asynchronous/monitor.py | 2 +- pymongo/asynchronous/srv_resolver.py | 155 ++++++++++++++++++++++ pymongo/synchronous/mongo_client.py | 27 ++-- pymongo/synchronous/monitor.py | 2 +- pymongo/{ => synchronous}/srv_resolver.py | 19 ++- pymongo/uri_parser.py | 2 +- test/asynchronous/test_client.py | 6 +- test/asynchronous/test_srv_polling.py | 17 ++- test/test_client.py | 6 +- test/test_srv_polling.py | 17 ++- 11 files changed, 228 insertions(+), 50 deletions(-) create mode 100644 pymongo/asynchronous/srv_resolver.py rename pymongo/{ => synchronous}/srv_resolver.py (92%) diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index 15053ae527..1d9e32e58a 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -1216,10 +1216,12 @@ def options(self) -> ClientOptions: def __eq__(self, other: Any) -> bool: if isinstance(other, self.__class__): - if hasattr(self, "_topology"): + if hasattr(self, "_topology") and hasattr(other, "_topology"): return self._topology == other._topology else: - raise InvalidOperation("Cannot compare client equality until both clients are connected") + raise InvalidOperation( + "Cannot compare client equality until both clients are connected" + ) return NotImplemented def __ne__(self, other: Any) -> bool: @@ -1245,13 +1247,16 @@ def option_repr(option: str, value: Any) -> str: return f"{option}={value!r}" # Host first... - options = [ - "host=%r" - % [ - "%s:%d" % (host, port) if port is not None else host - for host, port in self._topology_settings.seeds + if hasattr(self, "_topology"): + options = [ + "host=%r" + % [ + "%s:%d" % (host, port) if port is not None else host + for host, port in self._topology_settings.seeds + ] ] - ] + else: + options = [] # ... then everything in self._constructor_args... options.extend( option_repr(key, self._options._options[key]) for key in self._constructor_args @@ -1265,9 +1270,7 @@ def option_repr(option: str, value: Any) -> str: return ", ".join(options) def __repr__(self) -> str: - if hasattr(self, "_topology"): - return f"{type(self).__name__}({self._repr_helper()})" - raise InvalidOperation("Cannot perform operation until client is connected") + return f"{type(self).__name__}({self._repr_helper()})" def __getattr__(self, name: str) -> database.AsyncDatabase[_DocumentType]: """Get a database by name. diff --git a/pymongo/asynchronous/monitor.py b/pymongo/asynchronous/monitor.py index d7f87b718a..870aaef079 100644 --- a/pymongo/asynchronous/monitor.py +++ b/pymongo/asynchronous/monitor.py @@ -25,6 +25,7 @@ from pymongo import common, periodic_executor from pymongo._csot import MovingMinimum +from pymongo.asynchronous.srv_resolver import _SrvResolver from pymongo.errors import NetworkTimeout, _OperationCancelled from pymongo.hello import Hello from pymongo.lock import _async_create_lock @@ -33,7 +34,6 @@ from pymongo.pool_options import _is_faas from pymongo.read_preferences import MovingAverage from pymongo.server_description import ServerDescription -from pymongo.srv_resolver import _SrvResolver if TYPE_CHECKING: from pymongo.asynchronous.pool import AsyncConnection, Pool, _CancellationContext diff --git a/pymongo/asynchronous/srv_resolver.py b/pymongo/asynchronous/srv_resolver.py new file mode 100644 index 0000000000..2e9506bc91 --- /dev/null +++ b/pymongo/asynchronous/srv_resolver.py @@ -0,0 +1,155 @@ +# Copyright 2019-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you +# may not use this file except in compliance with the License. You +# may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. + +"""Support for resolving hosts and options from mongodb+srv:// URIs.""" +from __future__ import annotations + +import ipaddress +import random +from typing import TYPE_CHECKING, Any, Optional, Union + +from pymongo.common import CONNECT_TIMEOUT +from pymongo.errors import ConfigurationError + +if TYPE_CHECKING: + from dns import resolver + +_IS_SYNC = False + + +def _have_dnspython() -> bool: + try: + import dns # noqa: F401 + + return True + except ImportError: + return False + + +# dnspython can return bytes or str from various parts +# of its API depending on version. We always want str. +def maybe_decode(text: Union[str, bytes]) -> str: + if isinstance(text, bytes): + return text.decode() + return text + + +# PYTHON-2667 Lazily call dns.resolver methods for compatibility with eventlet. +async def _resolve(*args: Any, **kwargs: Any) -> resolver.Answer: + if _IS_SYNC: + from dns import resolver + + if hasattr(resolver, "resolve"): + # dnspython >= 2 + return resolver.resolve(*args, **kwargs) + # dnspython 1.X + return resolver.query(*args, **kwargs) + else: + from dns.asyncresolver import Resolver + + return await Resolver.resolve(*args, **kwargs) + + +_INVALID_HOST_MSG = ( + "Invalid URI host: %s is not a valid hostname for 'mongodb+srv://'. " + "Did you mean to use 'mongodb://'?" +) + + +class _SrvResolver: + def __init__( + self, + fqdn: str, + connect_timeout: Optional[float], + srv_service_name: str, + srv_max_hosts: int = 0, + ): + self.__fqdn = fqdn + self.__srv = srv_service_name + self.__connect_timeout = connect_timeout or CONNECT_TIMEOUT + self.__srv_max_hosts = srv_max_hosts or 0 + # Validate the fully qualified domain name. + try: + ipaddress.ip_address(fqdn) + raise ConfigurationError(_INVALID_HOST_MSG % ("an IP address",)) + except ValueError: + pass + + try: + self.__plist = self.__fqdn.split(".")[1:] + except Exception: + raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,)) from None + self.__slen = len(self.__plist) + if self.__slen < 2: + raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,)) + + async def get_options(self) -> Optional[str]: + from dns import resolver + + try: + results = await _resolve(self.__fqdn, "TXT", lifetime=self.__connect_timeout) + except (resolver.NoAnswer, resolver.NXDOMAIN): + # No TXT records + return None + except Exception as exc: + raise ConfigurationError(str(exc)) from None + if len(results) > 1: + raise ConfigurationError("Only one TXT record is supported") + return (b"&".join([b"".join(res.strings) for res in results])).decode("utf-8") # type: ignore[attr-defined] + + async def _resolve_uri(self, encapsulate_errors: bool) -> resolver.Answer: + try: + results = await _resolve( + "_" + self.__srv + "._tcp." + self.__fqdn, "SRV", lifetime=self.__connect_timeout + ) + except Exception as exc: + if not encapsulate_errors: + # Raise the original error. + raise + # Else, raise all errors as ConfigurationError. + raise ConfigurationError(str(exc)) from None + return results + + async def _get_srv_response_and_hosts( + self, encapsulate_errors: bool + ) -> tuple[resolver.Answer, list[tuple[str, Any]]]: + results = await self._resolve_uri(encapsulate_errors) + + # Construct address tuples + nodes = [ + (maybe_decode(res.target.to_text(omit_final_dot=True)), res.port) # type: ignore[attr-defined] + for res in results + ] + + # Validate hosts + for node in nodes: + try: + nlist = node[0].lower().split(".")[1:][-self.__slen :] + except Exception: + raise ConfigurationError(f"Invalid SRV host: {node[0]}") from None + if self.__plist != nlist: + raise ConfigurationError(f"Invalid SRV host: {node[0]}") + if self.__srv_max_hosts: + nodes = random.sample(nodes, min(self.__srv_max_hosts, len(nodes))) + return results, nodes + + async def get_hosts(self) -> list[tuple[str, Any]]: + _, nodes = await self._get_srv_response_and_hosts(True) + return nodes + + async def get_hosts_and_min_ttl(self) -> tuple[list[tuple[str, Any]], int]: + results, nodes = await self._get_srv_response_and_hosts(False) + rrset = results.rrset + ttl = rrset.ttl if rrset else 0 + return nodes, ttl diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 5308504882..dc7da2d513 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -1214,10 +1214,12 @@ def options(self) -> ClientOptions: def __eq__(self, other: Any) -> bool: if isinstance(other, self.__class__): - if hasattr(self, "_topology"): + if hasattr(self, "_topology") and hasattr(other, "_topology"): return self._topology == other._topology else: - raise InvalidOperation("Cannot perform operation until client is connected") + raise InvalidOperation( + "Cannot compare client equality until both clients are connected" + ) return NotImplemented def __ne__(self, other: Any) -> bool: @@ -1227,7 +1229,7 @@ def __hash__(self) -> int: if hasattr(self, "_topology"): return hash(self._topology) else: - raise InvalidOperation("Cannot perform operation until client is connected") + raise InvalidOperation("Cannot hash client until it is connected") def _repr_helper(self) -> str: def option_repr(option: str, value: Any) -> str: @@ -1243,13 +1245,16 @@ def option_repr(option: str, value: Any) -> str: return f"{option}={value!r}" # Host first... - options = [ - "host=%r" - % [ - "%s:%d" % (host, port) if port is not None else host - for host, port in self._topology_settings.seeds + if hasattr(self, "_topology"): + options = [ + "host=%r" + % [ + "%s:%d" % (host, port) if port is not None else host + for host, port in self._topology_settings.seeds + ] ] - ] + else: + options = [] # ... then everything in self._constructor_args... options.extend( option_repr(key, self._options._options[key]) for key in self._constructor_args @@ -1263,9 +1268,7 @@ def option_repr(option: str, value: Any) -> str: return ", ".join(options) def __repr__(self) -> str: - if hasattr(self, "_topology"): - return f"{type(self).__name__}({self._repr_helper()})" - raise InvalidOperation("Cannot perform operation until client is connected") + return f"{type(self).__name__}({self._repr_helper()})" def __getattr__(self, name: str) -> database.Database[_DocumentType]: """Get a database by name. diff --git a/pymongo/synchronous/monitor.py b/pymongo/synchronous/monitor.py index c39a57c392..4015534cb2 100644 --- a/pymongo/synchronous/monitor.py +++ b/pymongo/synchronous/monitor.py @@ -33,7 +33,7 @@ from pymongo.pool_options import _is_faas from pymongo.read_preferences import MovingAverage from pymongo.server_description import ServerDescription -from pymongo.srv_resolver import _SrvResolver +from pymongo.synchronous.srv_resolver import _SrvResolver if TYPE_CHECKING: from pymongo.synchronous.pool import Connection, Pool, _CancellationContext diff --git a/pymongo/srv_resolver.py b/pymongo/synchronous/srv_resolver.py similarity index 92% rename from pymongo/srv_resolver.py rename to pymongo/synchronous/srv_resolver.py index 5be6cb98db..844b142ea8 100644 --- a/pymongo/srv_resolver.py +++ b/pymongo/synchronous/srv_resolver.py @@ -25,6 +25,8 @@ if TYPE_CHECKING: from dns import resolver +_IS_SYNC = True + def _have_dnspython() -> bool: try: @@ -45,13 +47,18 @@ def maybe_decode(text: Union[str, bytes]) -> str: # PYTHON-2667 Lazily call dns.resolver methods for compatibility with eventlet. def _resolve(*args: Any, **kwargs: Any) -> resolver.Answer: - from dns import resolver + if _IS_SYNC: + from dns import resolver + + if hasattr(resolver, "resolve"): + # dnspython >= 2 + return resolver.resolve(*args, **kwargs) + # dnspython 1.X + return resolver.query(*args, **kwargs) + else: + from dns.asyncresolver import Resolver - if hasattr(resolver, "resolve"): - # dnspython >= 2 - return resolver.resolve(*args, **kwargs) - # dnspython 1.X - return resolver.query(*args, **kwargs) + return Resolver.resolve(*args, **kwargs) _INVALID_HOST_MSG = ( diff --git a/pymongo/uri_parser.py b/pymongo/uri_parser.py index 09f974fe8a..5c28de22e3 100644 --- a/pymongo/uri_parser.py +++ b/pymongo/uri_parser.py @@ -34,6 +34,7 @@ ) from urllib.parse import unquote_plus +from pymongo.asynchronous.srv_resolver import _have_dnspython, _SrvResolver from pymongo.client_options import _parse_ssl_options from pymongo.common import ( INTERNAL_URI_OPTION_NAME_MAP, @@ -43,7 +44,6 @@ get_validated_options, ) from pymongo.errors import ConfigurationError, InvalidURI -from pymongo.srv_resolver import _have_dnspython, _SrvResolver from pymongo.typings import _Address if TYPE_CHECKING: diff --git a/test/asynchronous/test_client.py b/test/asynchronous/test_client.py index fa7222860c..49874e90c1 100644 --- a/test/asynchronous/test_client.py +++ b/test/asynchronous/test_client.py @@ -512,13 +512,13 @@ async def test_uri_option_precedence(self): async def test_connection_timeout_ms_propagates_to_DNS_resolver(self): # Patch the resolver. - from pymongo.srv_resolver import _resolve + from pymongo.asynchronous.srv_resolver import _resolve patched_resolver = FunctionCallRecorder(_resolve) - pymongo.srv_resolver._resolve = patched_resolver + pymongo.asynchronous.srv_resolver._resolve = patched_resolver def reset_resolver(): - pymongo.srv_resolver._resolve = _resolve + pymongo.asynchronous.srv_resolver._resolve = _resolve self.addCleanup(reset_resolver) diff --git a/test/asynchronous/test_srv_polling.py b/test/asynchronous/test_srv_polling.py index bf7807eb97..47b68411a8 100644 --- a/test/asynchronous/test_srv_polling.py +++ b/test/asynchronous/test_srv_polling.py @@ -28,8 +28,8 @@ import pymongo from pymongo import common +from pymongo.asynchronous.srv_resolver import _have_dnspython from pymongo.errors import ConfigurationError -from pymongo.srv_resolver import _have_dnspython _IS_SYNC = False @@ -54,7 +54,9 @@ def __init__( def enable(self): self.old_min_srv_rescan_interval = common.MIN_SRV_RESCAN_INTERVAL - self.old_dns_resolver_response = pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl + self.old_dns_resolver_response = ( + pymongo.asynchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl + ) if self.min_srv_rescan_interval is not None: common.MIN_SRV_RESCAN_INTERVAL = self.min_srv_rescan_interval @@ -74,14 +76,14 @@ def mock_get_hosts_and_min_ttl(resolver, *args): else: patch_func = mock_get_hosts_and_min_ttl - pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl = patch_func # type: ignore + pymongo.asynchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl = patch_func # type: ignore def __enter__(self): self.enable() def disable(self): common.MIN_SRV_RESCAN_INTERVAL = self.old_min_srv_rescan_interval # type: ignore - pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl = ( # type: ignore + pymongo.asynchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl = ( # type: ignore self.old_dns_resolver_response ) @@ -134,7 +136,10 @@ async def assert_nodelist_nochange(self, expected_nodelist, client, timeout=(100 def predicate(): if set(expected_nodelist) == set(self.get_nodelist(client)): - return pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count >= 1 + return ( + pymongo.asynchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count + >= 1 + ) return False await async_wait_until(predicate, "Node list equals expected nodelist", timeout=timeout) @@ -144,7 +149,7 @@ def predicate(): msg = "Client nodelist %s changed unexpectedly (expected %s)" raise self.fail(msg % (nodelist, expected_nodelist)) self.assertGreaterEqual( - pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count, # type: ignore + pymongo.asynchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count, # type: ignore 1, "resolver was never called", ) diff --git a/test/test_client.py b/test/test_client.py index f6dc6e3c87..8251d7fff8 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -505,13 +505,13 @@ def test_uri_option_precedence(self): def test_connection_timeout_ms_propagates_to_DNS_resolver(self): # Patch the resolver. - from pymongo.srv_resolver import _resolve + from pymongo.synchronous.srv_resolver import _resolve patched_resolver = FunctionCallRecorder(_resolve) - pymongo.srv_resolver._resolve = patched_resolver + pymongo.synchronous.srv_resolver._resolve = patched_resolver def reset_resolver(): - pymongo.srv_resolver._resolve = _resolve + pymongo.synchronous.srv_resolver._resolve = _resolve self.addCleanup(reset_resolver) diff --git a/test/test_srv_polling.py b/test/test_srv_polling.py index 6812465074..df802acb43 100644 --- a/test/test_srv_polling.py +++ b/test/test_srv_polling.py @@ -29,7 +29,7 @@ import pymongo from pymongo import common from pymongo.errors import ConfigurationError -from pymongo.srv_resolver import _have_dnspython +from pymongo.synchronous.srv_resolver import _have_dnspython _IS_SYNC = True @@ -54,7 +54,9 @@ def __init__( def enable(self): self.old_min_srv_rescan_interval = common.MIN_SRV_RESCAN_INTERVAL - self.old_dns_resolver_response = pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl + self.old_dns_resolver_response = ( + pymongo.synchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl + ) if self.min_srv_rescan_interval is not None: common.MIN_SRV_RESCAN_INTERVAL = self.min_srv_rescan_interval @@ -74,14 +76,14 @@ def mock_get_hosts_and_min_ttl(resolver, *args): else: patch_func = mock_get_hosts_and_min_ttl - pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl = patch_func # type: ignore + pymongo.synchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl = patch_func # type: ignore def __enter__(self): self.enable() def disable(self): common.MIN_SRV_RESCAN_INTERVAL = self.old_min_srv_rescan_interval # type: ignore - pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl = ( # type: ignore + pymongo.synchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl = ( # type: ignore self.old_dns_resolver_response ) @@ -134,7 +136,10 @@ def assert_nodelist_nochange(self, expected_nodelist, client, timeout=(100 * WAI def predicate(): if set(expected_nodelist) == set(self.get_nodelist(client)): - return pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count >= 1 + return ( + pymongo.synchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count + >= 1 + ) return False wait_until(predicate, "Node list equals expected nodelist", timeout=timeout) @@ -144,7 +149,7 @@ def predicate(): msg = "Client nodelist %s changed unexpectedly (expected %s)" raise self.fail(msg % (nodelist, expected_nodelist)) self.assertGreaterEqual( - pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count, # type: ignore + pymongo.synchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count, # type: ignore 1, "resolver was never called", ) From e38c2ad1fd704514f8359ff7ed2cafda409cc98f Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Thu, 13 Mar 2025 11:57:12 -0700 Subject: [PATCH 27/56] refactor part 2 --- pymongo/asynchronous/encryption.py | 2 +- pymongo/asynchronous/mongo_client.py | 19 +- pymongo/asynchronous/monitor.py | 6 +- pymongo/asynchronous/srv_resolver.py | 2 +- pymongo/asynchronous/uri_parser.py | 271 ++++++++++++++++++ pymongo/encryption_options.py | 2 +- pymongo/synchronous/encryption.py | 2 +- pymongo/synchronous/mongo_client.py | 13 +- pymongo/synchronous/srv_resolver.py | 2 +- pymongo/synchronous/uri_parser.py | 271 ++++++++++++++++++ .../{uri_parser.py => uri_parser_shared.py} | 250 +--------------- test/__init__.py | 2 +- test/asynchronous/__init__.py | 2 +- test/asynchronous/helpers.py | 2 +- .../test_discovery_and_monitoring.py | 2 +- test/asynchronous/test_dns.py | 3 +- test/auth_aws/test_auth_aws.py | 2 +- test/auth_oidc/test_auth_oidc.py | 2 +- test/helpers.py | 2 +- test/test_default_exports.py | 4 +- test/test_discovery_and_monitoring.py | 2 +- test/test_dns.py | 3 +- test/test_uri_parser.py | 4 +- test/test_uri_spec.py | 2 +- 24 files changed, 585 insertions(+), 287 deletions(-) create mode 100644 pymongo/asynchronous/uri_parser.py create mode 100644 pymongo/synchronous/uri_parser.py rename pymongo/{uri_parser.py => uri_parser_shared.py} (66%) diff --git a/pymongo/asynchronous/encryption.py b/pymongo/asynchronous/encryption.py index ef8d817b2c..92eb2fba37 100644 --- a/pymongo/asynchronous/encryption.py +++ b/pymongo/asynchronous/encryption.py @@ -87,7 +87,7 @@ from pymongo.results import BulkWriteResult, DeleteResult from pymongo.ssl_support import get_ssl_context from pymongo.typings import _DocumentType, _DocumentTypeArg -from pymongo.uri_parser import parse_host +from pymongo.uri_parser_shared import parse_host from pymongo.write_concern import WriteConcern if TYPE_CHECKING: diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index 1d9e32e58a..5c32edf837 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -59,10 +59,11 @@ cast, ) +import pymongo.asynchronous.uri_parser from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions, TypeRegistry from bson.timestamp import Timestamp -from pymongo import _csot, common, helpers_shared, periodic_executor, uri_parser -from pymongo.asynchronous import client_session, database +from pymongo import _csot, common, helpers_shared, periodic_executor, uri_parser_shared +from pymongo.asynchronous import client_session, database, uri_parser from pymongo.asynchronous.change_stream import AsyncChangeStream, AsyncClusterChangeStream from pymongo.asynchronous.client_bulk import _AsyncClientBulk from pymongo.asynchronous.client_session import _EmptyServerSession @@ -114,7 +115,7 @@ _DocumentTypeArg, _Pipeline, ) -from pymongo.uri_parser import ( +from pymongo.uri_parser_shared import ( SRV_SCHEME, _check_options, _handle_option_deprecations, @@ -783,7 +784,7 @@ def __init__( # it must be a URI, # https://en.wikipedia.org/wiki/Hostname#Restrictions_on_valid_host_names if "/" in entity: - res = uri_parser._validate_uri( + res = pymongo.asynchronous.uri_parser._validate_uri( entity, port, validate=True, @@ -799,7 +800,7 @@ def __init__( opts = res["options"] fqdn = res["fqdn"] else: - seeds.update(uri_parser.split_hosts(entity, self._port)) + seeds.update(uri_parser_shared.split_hosts(entity, self._port)) if not seeds: raise ConfigurationError("need to specify at least one host") @@ -862,7 +863,7 @@ def __init__( if _IS_SYNC and connect: self._get_topology() # type: ignore[unused-coroutine] - def _resolve_srv(self) -> None: + async def _resolve_srv(self) -> None: keyword_opts = self._resolve_srv_info["keyword_opts"] seeds = set() opts = common._CaseInsensitiveDictionary() @@ -879,7 +880,7 @@ def _resolve_srv(self) -> None: timeout = common.validate_timeout_or_none_or_zero( keyword_opts.cased_key("connecttimeoutms"), timeout ) - res = uri_parser._parse_srv( + res = await uri_parser._parse_srv( entity, self._port, validate=True, @@ -892,7 +893,7 @@ def _resolve_srv(self) -> None: seeds.update(res["nodelist"]) opts = res["options"] else: - seeds.update(uri_parser.split_hosts(entity, self._port)) + seeds.update(uri_parser_shared.split_hosts(entity, self._port)) if not seeds: raise ConfigurationError("need to specify at least one host") @@ -1694,7 +1695,7 @@ async def _get_topology(self) -> Topology: """ if not self._opened: if self._resolve_srv_info["is_srv"]: - self._resolve_srv() + await self._resolve_srv() self._init_background(first=True) await self._topology.open() async with self._lock: diff --git a/pymongo/asynchronous/monitor.py b/pymongo/asynchronous/monitor.py index 870aaef079..28bc2e4479 100644 --- a/pymongo/asynchronous/monitor.py +++ b/pymongo/asynchronous/monitor.py @@ -395,7 +395,7 @@ async def _run(self) -> None: # Don't poll right after creation, wait 60 seconds first if time.monotonic() < self._startup_time + common.MIN_SRV_RESCAN_INTERVAL: return - seedlist = self._get_seedlist() + seedlist = await self._get_seedlist() if seedlist: self._seedlist = seedlist try: @@ -404,7 +404,7 @@ async def _run(self) -> None: # Topology was garbage-collected. await self.close() - def _get_seedlist(self) -> Optional[list[tuple[str, Any]]]: + async def _get_seedlist(self) -> Optional[list[tuple[str, Any]]]: """Poll SRV records for a seedlist. Returns a list of ServerDescriptions. @@ -415,7 +415,7 @@ def _get_seedlist(self) -> Optional[list[tuple[str, Any]]]: self._settings.pool_options.connect_timeout, self._settings.srv_service_name, ) - seedlist, ttl = resolver.get_hosts_and_min_ttl() + seedlist, ttl = await resolver.get_hosts_and_min_ttl() if len(seedlist) == 0: # As per the spec: this should be treated as a failure. raise Exception diff --git a/pymongo/asynchronous/srv_resolver.py b/pymongo/asynchronous/srv_resolver.py index 2e9506bc91..b7e94a3f06 100644 --- a/pymongo/asynchronous/srv_resolver.py +++ b/pymongo/asynchronous/srv_resolver.py @@ -58,7 +58,7 @@ async def _resolve(*args: Any, **kwargs: Any) -> resolver.Answer: else: from dns.asyncresolver import Resolver - return await Resolver.resolve(*args, **kwargs) + return await Resolver.resolve(*args, **kwargs) # type:ignore[return-value] _INVALID_HOST_MSG = ( diff --git a/pymongo/asynchronous/uri_parser.py b/pymongo/asynchronous/uri_parser.py new file mode 100644 index 0000000000..3953393044 --- /dev/null +++ b/pymongo/asynchronous/uri_parser.py @@ -0,0 +1,271 @@ +from __future__ import annotations + +import sys +from typing import Any, Optional +from urllib.parse import unquote_plus + +from pymongo.asynchronous.srv_resolver import _have_dnspython, _SrvResolver +from pymongo.common import SRV_SERVICE_NAME, _CaseInsensitiveDictionary +from pymongo.errors import ConfigurationError, InvalidURI +from pymongo.uri_parser_shared import ( + _ALLOWED_TXT_OPTS, + _BAD_DB_CHARS, + DEFAULT_PORT, + SCHEME, + SCHEME_LEN, + SRV_SCHEME, + SRV_SCHEME_LEN, + _check_options, + parse_userinfo, + split_hosts, + split_options, +) + +_IS_SYNC = False + + +async def parse_uri( + uri: str, + default_port: Optional[int] = DEFAULT_PORT, + validate: bool = True, + warn: bool = False, + normalize: bool = True, + connect_timeout: Optional[float] = None, + srv_service_name: Optional[str] = None, + srv_max_hosts: Optional[int] = None, +) -> dict[str, Any]: + """Parse and validate a MongoDB URI. + + Returns a dict of the form:: + + { + 'nodelist': , + 'username': or None, + 'password': or None, + 'database': or None, + 'collection': or None, + 'options': , + 'fqdn': or None + } + + If the URI scheme is "mongodb+srv://" DNS SRV and TXT lookups will be done + to build nodelist and options. + + :param uri: The MongoDB URI to parse. + :param default_port: The port number to use when one wasn't specified + for a host in the URI. + :param validate: If ``True`` (the default), validate and + normalize all options. Default: ``True``. + :param warn: When validating, if ``True`` then will warn + the user then ignore any invalid options or values. If ``False``, + validation will error when options are unsupported or values are + invalid. Default: ``False``. + :param normalize: If ``True``, convert names of URI options + to their internally-used names. Default: ``True``. + :param connect_timeout: The maximum time in milliseconds to + wait for a response from the DNS server. + :param srv_service_name: A custom SRV service name + + .. versionchanged:: 4.6 + The delimiting slash (``/``) between hosts and connection options is now optional. + For example, "mongodb://example.com?tls=true" is now a valid URI. + + .. versionchanged:: 4.0 + To better follow RFC 3986, unquoted percent signs ("%") are no longer + supported. + + .. versionchanged:: 3.9 + Added the ``normalize`` parameter. + + .. versionchanged:: 3.6 + Added support for mongodb+srv:// URIs. + + .. versionchanged:: 3.5 + Return the original value of the ``readPreference`` MongoDB URI option + instead of the validated read preference mode. + + .. versionchanged:: 3.1 + ``warn`` added so invalid options can be ignored. + """ + result = _validate_uri(uri, default_port, validate, warn, normalize, srv_max_hosts) + result.update( + await _parse_srv( + uri, + default_port, + validate, + warn, + normalize, + connect_timeout, + srv_service_name, + srv_max_hosts, + ) + ) + return result + + +def _validate_uri( + uri: str, + default_port: Optional[int] = DEFAULT_PORT, + validate: bool = True, + warn: bool = False, + normalize: bool = True, + srv_max_hosts: Optional[int] = None, +) -> dict[str, Any]: + if uri.startswith(SCHEME): + is_srv = False + scheme_free = uri[SCHEME_LEN:] + elif uri.startswith(SRV_SCHEME): + if not _have_dnspython(): + python_path = sys.executable or "python" + raise ConfigurationError( + 'The "dnspython" module must be ' + "installed to use mongodb+srv:// URIs. " + "To fix this error install pymongo again:\n " + "%s -m pip install pymongo>=4.3" % (python_path) + ) + is_srv = True + scheme_free = uri[SRV_SCHEME_LEN:] + else: + raise InvalidURI(f"Invalid URI scheme: URI must begin with '{SCHEME}' or '{SRV_SCHEME}'") + + if not scheme_free: + raise InvalidURI("Must provide at least one hostname or IP") + + user = None + passwd = None + dbase = None + collection = None + options = _CaseInsensitiveDictionary() + + host_plus_db_part, _, opts = scheme_free.partition("?") + if "/" in host_plus_db_part: + host_part, _, dbase = host_plus_db_part.partition("/") + else: + host_part = host_plus_db_part + + if dbase: + dbase = unquote_plus(dbase) + if "." in dbase: + dbase, collection = dbase.split(".", 1) + if _BAD_DB_CHARS.search(dbase): + raise InvalidURI('Bad database name "%s"' % dbase) + else: + dbase = None + + if opts: + options.update(split_options(opts, validate, warn, normalize)) + if "@" in host_part: + userinfo, _, hosts = host_part.rpartition("@") + user, passwd = parse_userinfo(userinfo) + else: + hosts = host_part + + if "/" in hosts: + raise InvalidURI("Any '/' in a unix domain socket must be percent-encoded: %s" % host_part) + + hosts = unquote_plus(hosts) + fqdn = None + srv_max_hosts = srv_max_hosts or options.get("srvMaxHosts") + if is_srv: + if options.get("directConnection"): + raise ConfigurationError(f"Cannot specify directConnection=true with {SRV_SCHEME} URIs") + nodes = split_hosts(hosts, default_port=None) + if len(nodes) != 1: + raise InvalidURI(f"{SRV_SCHEME} URIs must include one, and only one, hostname") + fqdn, port = nodes[0] + if port is not None: + raise InvalidURI(f"{SRV_SCHEME} URIs must not include a port number") + elif not is_srv and options.get("srvServiceName") is not None: + raise ConfigurationError( + "The srvServiceName option is only allowed with 'mongodb+srv://' URIs" + ) + elif not is_srv and srv_max_hosts: + raise ConfigurationError( + "The srvMaxHosts option is only allowed with 'mongodb+srv://' URIs" + ) + else: + nodes = split_hosts(hosts, default_port=default_port) + + _check_options(nodes, options) + + return { + "nodelist": nodes, + "username": user, + "password": passwd, + "database": dbase, + "collection": collection, + "options": options, + "fqdn": fqdn, + } + + +async def _parse_srv( + uri: str, + default_port: Optional[int] = DEFAULT_PORT, + validate: bool = True, + warn: bool = False, + normalize: bool = True, + connect_timeout: Optional[float] = None, + srv_service_name: Optional[str] = None, + srv_max_hosts: Optional[int] = None, +) -> dict[str, Any]: + if uri.startswith(SCHEME): + is_srv = False + scheme_free = uri[SCHEME_LEN:] + else: + is_srv = True + scheme_free = uri[SRV_SCHEME_LEN:] + + options = _CaseInsensitiveDictionary() + + host_plus_db_part, _, opts = scheme_free.partition("?") + if "/" in host_plus_db_part: + host_part, _, _ = host_plus_db_part.partition("/") + else: + host_part = host_plus_db_part + + if opts: + options.update(split_options(opts, validate, warn, normalize)) + if srv_service_name is None: + srv_service_name = options.get("srvServiceName", SRV_SERVICE_NAME) + if "@" in host_part: + _, _, hosts = host_part.rpartition("@") + else: + hosts = host_part + + hosts = unquote_plus(hosts) + srv_max_hosts = srv_max_hosts or options.get("srvMaxHosts") + if is_srv: + nodes = split_hosts(hosts, default_port=None) + fqdn, port = nodes[0] + + # Use the connection timeout. connectTimeoutMS passed as a keyword + # argument overrides the same option passed in the connection string. + connect_timeout = connect_timeout or options.get("connectTimeoutMS") + dns_resolver = _SrvResolver(fqdn, connect_timeout, srv_service_name, srv_max_hosts) + nodes = await dns_resolver.get_hosts() + dns_options = await dns_resolver.get_options() + if dns_options: + parsed_dns_options = split_options(dns_options, validate, warn, normalize) + if set(parsed_dns_options) - _ALLOWED_TXT_OPTS: + raise ConfigurationError( + "Only authSource, replicaSet, and loadBalanced are supported from DNS" + ) + for opt, val in parsed_dns_options.items(): + if opt not in options: + options[opt] = val + if options.get("loadBalanced") and srv_max_hosts: + raise InvalidURI("You cannot specify loadBalanced with srvMaxHosts") + if options.get("replicaSet") and srv_max_hosts: + raise InvalidURI("You cannot specify replicaSet with srvMaxHosts") + if "tls" not in options and "ssl" not in options: + options["tls"] = True if validate else "true" + else: + nodes = split_hosts(hosts, default_port=default_port) + + _check_options(nodes, options) + + return { + "nodelist": nodes, + "options": options, + } diff --git a/pymongo/encryption_options.py b/pymongo/encryption_options.py index a1c40dc7b2..e7437cab0f 100644 --- a/pymongo/encryption_options.py +++ b/pymongo/encryption_options.py @@ -32,7 +32,7 @@ from bson import int64 from pymongo.common import validate_is_mapping from pymongo.errors import ConfigurationError -from pymongo.uri_parser import _parse_kms_tls_options +from pymongo.uri_parser_shared import _parse_kms_tls_options if TYPE_CHECKING: from pymongo.typings import _AgnosticMongoClient, _DocumentTypeArg diff --git a/pymongo/synchronous/encryption.py b/pymongo/synchronous/encryption.py index a97534ed41..47a3b1bf6b 100644 --- a/pymongo/synchronous/encryption.py +++ b/pymongo/synchronous/encryption.py @@ -86,7 +86,7 @@ _raise_connection_failure, ) from pymongo.typings import _DocumentType, _DocumentTypeArg -from pymongo.uri_parser import parse_host +from pymongo.uri_parser_shared import parse_host from pymongo.write_concern import WriteConcern if TYPE_CHECKING: diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index dc7da2d513..21e695f348 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -58,9 +58,10 @@ cast, ) +import pymongo.synchronous.uri_parser from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions, TypeRegistry from bson.timestamp import Timestamp -from pymongo import _csot, common, helpers_shared, periodic_executor, uri_parser +from pymongo import _csot, common, helpers_shared, periodic_executor, uri_parser_shared from pymongo.client_options import ClientOptions from pymongo.errors import ( AutoReconnect, @@ -97,7 +98,7 @@ from pymongo.results import ClientBulkWriteResult from pymongo.server_selectors import writable_server_selector from pymongo.server_type import SERVER_TYPE -from pymongo.synchronous import client_session, database +from pymongo.synchronous import client_session, database, uri_parser from pymongo.synchronous.change_stream import ChangeStream, ClusterChangeStream from pymongo.synchronous.client_bulk import _ClientBulk from pymongo.synchronous.client_session import _EmptyServerSession @@ -113,7 +114,7 @@ _DocumentTypeArg, _Pipeline, ) -from pymongo.uri_parser import ( +from pymongo.uri_parser_shared import ( SRV_SCHEME, _check_options, _handle_option_deprecations, @@ -781,7 +782,7 @@ def __init__( # it must be a URI, # https://en.wikipedia.org/wiki/Hostname#Restrictions_on_valid_host_names if "/" in entity: - res = uri_parser._validate_uri( + res = pymongo.synchronous.uri_parser._validate_uri( entity, port, validate=True, @@ -797,7 +798,7 @@ def __init__( opts = res["options"] fqdn = res["fqdn"] else: - seeds.update(uri_parser.split_hosts(entity, self._port)) + seeds.update(uri_parser_shared.split_hosts(entity, self._port)) if not seeds: raise ConfigurationError("need to specify at least one host") @@ -890,7 +891,7 @@ def _resolve_srv(self) -> None: seeds.update(res["nodelist"]) opts = res["options"] else: - seeds.update(uri_parser.split_hosts(entity, self._port)) + seeds.update(uri_parser_shared.split_hosts(entity, self._port)) if not seeds: raise ConfigurationError("need to specify at least one host") diff --git a/pymongo/synchronous/srv_resolver.py b/pymongo/synchronous/srv_resolver.py index 844b142ea8..c424e6e9af 100644 --- a/pymongo/synchronous/srv_resolver.py +++ b/pymongo/synchronous/srv_resolver.py @@ -58,7 +58,7 @@ def _resolve(*args: Any, **kwargs: Any) -> resolver.Answer: else: from dns.asyncresolver import Resolver - return Resolver.resolve(*args, **kwargs) + return Resolver.resolve(*args, **kwargs) # type:ignore[return-value] _INVALID_HOST_MSG = ( diff --git a/pymongo/synchronous/uri_parser.py b/pymongo/synchronous/uri_parser.py new file mode 100644 index 0000000000..a9bd6986b7 --- /dev/null +++ b/pymongo/synchronous/uri_parser.py @@ -0,0 +1,271 @@ +from __future__ import annotations + +import sys +from typing import Any, Optional +from urllib.parse import unquote_plus + +from pymongo.common import SRV_SERVICE_NAME, _CaseInsensitiveDictionary +from pymongo.errors import ConfigurationError, InvalidURI +from pymongo.synchronous.srv_resolver import _have_dnspython, _SrvResolver +from pymongo.uri_parser_shared import ( + _ALLOWED_TXT_OPTS, + _BAD_DB_CHARS, + DEFAULT_PORT, + SCHEME, + SCHEME_LEN, + SRV_SCHEME, + SRV_SCHEME_LEN, + _check_options, + parse_userinfo, + split_hosts, + split_options, +) + +_IS_SYNC = True + + +def parse_uri( + uri: str, + default_port: Optional[int] = DEFAULT_PORT, + validate: bool = True, + warn: bool = False, + normalize: bool = True, + connect_timeout: Optional[float] = None, + srv_service_name: Optional[str] = None, + srv_max_hosts: Optional[int] = None, +) -> dict[str, Any]: + """Parse and validate a MongoDB URI. + + Returns a dict of the form:: + + { + 'nodelist': , + 'username': or None, + 'password': or None, + 'database': or None, + 'collection': or None, + 'options': , + 'fqdn': or None + } + + If the URI scheme is "mongodb+srv://" DNS SRV and TXT lookups will be done + to build nodelist and options. + + :param uri: The MongoDB URI to parse. + :param default_port: The port number to use when one wasn't specified + for a host in the URI. + :param validate: If ``True`` (the default), validate and + normalize all options. Default: ``True``. + :param warn: When validating, if ``True`` then will warn + the user then ignore any invalid options or values. If ``False``, + validation will error when options are unsupported or values are + invalid. Default: ``False``. + :param normalize: If ``True``, convert names of URI options + to their internally-used names. Default: ``True``. + :param connect_timeout: The maximum time in milliseconds to + wait for a response from the DNS server. + :param srv_service_name: A custom SRV service name + + .. versionchanged:: 4.6 + The delimiting slash (``/``) between hosts and connection options is now optional. + For example, "mongodb://example.com?tls=true" is now a valid URI. + + .. versionchanged:: 4.0 + To better follow RFC 3986, unquoted percent signs ("%") are no longer + supported. + + .. versionchanged:: 3.9 + Added the ``normalize`` parameter. + + .. versionchanged:: 3.6 + Added support for mongodb+srv:// URIs. + + .. versionchanged:: 3.5 + Return the original value of the ``readPreference`` MongoDB URI option + instead of the validated read preference mode. + + .. versionchanged:: 3.1 + ``warn`` added so invalid options can be ignored. + """ + result = _validate_uri(uri, default_port, validate, warn, normalize, srv_max_hosts) + result.update( + _parse_srv( + uri, + default_port, + validate, + warn, + normalize, + connect_timeout, + srv_service_name, + srv_max_hosts, + ) + ) + return result + + +def _validate_uri( + uri: str, + default_port: Optional[int] = DEFAULT_PORT, + validate: bool = True, + warn: bool = False, + normalize: bool = True, + srv_max_hosts: Optional[int] = None, +) -> dict[str, Any]: + if uri.startswith(SCHEME): + is_srv = False + scheme_free = uri[SCHEME_LEN:] + elif uri.startswith(SRV_SCHEME): + if not _have_dnspython(): + python_path = sys.executable or "python" + raise ConfigurationError( + 'The "dnspython" module must be ' + "installed to use mongodb+srv:// URIs. " + "To fix this error install pymongo again:\n " + "%s -m pip install pymongo>=4.3" % (python_path) + ) + is_srv = True + scheme_free = uri[SRV_SCHEME_LEN:] + else: + raise InvalidURI(f"Invalid URI scheme: URI must begin with '{SCHEME}' or '{SRV_SCHEME}'") + + if not scheme_free: + raise InvalidURI("Must provide at least one hostname or IP") + + user = None + passwd = None + dbase = None + collection = None + options = _CaseInsensitiveDictionary() + + host_plus_db_part, _, opts = scheme_free.partition("?") + if "/" in host_plus_db_part: + host_part, _, dbase = host_plus_db_part.partition("/") + else: + host_part = host_plus_db_part + + if dbase: + dbase = unquote_plus(dbase) + if "." in dbase: + dbase, collection = dbase.split(".", 1) + if _BAD_DB_CHARS.search(dbase): + raise InvalidURI('Bad database name "%s"' % dbase) + else: + dbase = None + + if opts: + options.update(split_options(opts, validate, warn, normalize)) + if "@" in host_part: + userinfo, _, hosts = host_part.rpartition("@") + user, passwd = parse_userinfo(userinfo) + else: + hosts = host_part + + if "/" in hosts: + raise InvalidURI("Any '/' in a unix domain socket must be percent-encoded: %s" % host_part) + + hosts = unquote_plus(hosts) + fqdn = None + srv_max_hosts = srv_max_hosts or options.get("srvMaxHosts") + if is_srv: + if options.get("directConnection"): + raise ConfigurationError(f"Cannot specify directConnection=true with {SRV_SCHEME} URIs") + nodes = split_hosts(hosts, default_port=None) + if len(nodes) != 1: + raise InvalidURI(f"{SRV_SCHEME} URIs must include one, and only one, hostname") + fqdn, port = nodes[0] + if port is not None: + raise InvalidURI(f"{SRV_SCHEME} URIs must not include a port number") + elif not is_srv and options.get("srvServiceName") is not None: + raise ConfigurationError( + "The srvServiceName option is only allowed with 'mongodb+srv://' URIs" + ) + elif not is_srv and srv_max_hosts: + raise ConfigurationError( + "The srvMaxHosts option is only allowed with 'mongodb+srv://' URIs" + ) + else: + nodes = split_hosts(hosts, default_port=default_port) + + _check_options(nodes, options) + + return { + "nodelist": nodes, + "username": user, + "password": passwd, + "database": dbase, + "collection": collection, + "options": options, + "fqdn": fqdn, + } + + +def _parse_srv( + uri: str, + default_port: Optional[int] = DEFAULT_PORT, + validate: bool = True, + warn: bool = False, + normalize: bool = True, + connect_timeout: Optional[float] = None, + srv_service_name: Optional[str] = None, + srv_max_hosts: Optional[int] = None, +) -> dict[str, Any]: + if uri.startswith(SCHEME): + is_srv = False + scheme_free = uri[SCHEME_LEN:] + else: + is_srv = True + scheme_free = uri[SRV_SCHEME_LEN:] + + options = _CaseInsensitiveDictionary() + + host_plus_db_part, _, opts = scheme_free.partition("?") + if "/" in host_plus_db_part: + host_part, _, _ = host_plus_db_part.partition("/") + else: + host_part = host_plus_db_part + + if opts: + options.update(split_options(opts, validate, warn, normalize)) + if srv_service_name is None: + srv_service_name = options.get("srvServiceName", SRV_SERVICE_NAME) + if "@" in host_part: + _, _, hosts = host_part.rpartition("@") + else: + hosts = host_part + + hosts = unquote_plus(hosts) + srv_max_hosts = srv_max_hosts or options.get("srvMaxHosts") + if is_srv: + nodes = split_hosts(hosts, default_port=None) + fqdn, port = nodes[0] + + # Use the connection timeout. connectTimeoutMS passed as a keyword + # argument overrides the same option passed in the connection string. + connect_timeout = connect_timeout or options.get("connectTimeoutMS") + dns_resolver = _SrvResolver(fqdn, connect_timeout, srv_service_name, srv_max_hosts) + nodes = dns_resolver.get_hosts() + dns_options = dns_resolver.get_options() + if dns_options: + parsed_dns_options = split_options(dns_options, validate, warn, normalize) + if set(parsed_dns_options) - _ALLOWED_TXT_OPTS: + raise ConfigurationError( + "Only authSource, replicaSet, and loadBalanced are supported from DNS" + ) + for opt, val in parsed_dns_options.items(): + if opt not in options: + options[opt] = val + if options.get("loadBalanced") and srv_max_hosts: + raise InvalidURI("You cannot specify loadBalanced with srvMaxHosts") + if options.get("replicaSet") and srv_max_hosts: + raise InvalidURI("You cannot specify replicaSet with srvMaxHosts") + if "tls" not in options and "ssl" not in options: + options["tls"] = True if validate else "true" + else: + nodes = split_hosts(hosts, default_port=default_port) + + _check_options(nodes, options) + + return { + "nodelist": nodes, + "options": options, + } diff --git a/pymongo/uri_parser.py b/pymongo/uri_parser_shared.py similarity index 66% rename from pymongo/uri_parser.py rename to pymongo/uri_parser_shared.py index 5c28de22e3..17ee2b187d 100644 --- a/pymongo/uri_parser.py +++ b/pymongo/uri_parser_shared.py @@ -34,11 +34,10 @@ ) from urllib.parse import unquote_plus -from pymongo.asynchronous.srv_resolver import _have_dnspython, _SrvResolver +from pymongo.asynchronous.uri_parser import parse_uri from pymongo.client_options import _parse_ssl_options from pymongo.common import ( INTERNAL_URI_OPTION_NAME_MAP, - SRV_SERVICE_NAME, URI_OPTIONS_DEPRECATION_MAP, _CaseInsensitiveDictionary, get_validated_options, @@ -421,253 +420,6 @@ def _check_options(nodes: Sized, options: Mapping[str, Any]) -> None: raise ConfigurationError("Cannot specify replicaSet with loadBalanced=true") -def parse_uri( - uri: str, - default_port: Optional[int] = DEFAULT_PORT, - validate: bool = True, - warn: bool = False, - normalize: bool = True, - connect_timeout: Optional[float] = None, - srv_service_name: Optional[str] = None, - srv_max_hosts: Optional[int] = None, -) -> dict[str, Any]: - """Parse and validate a MongoDB URI. - - Returns a dict of the form:: - - { - 'nodelist': , - 'username': or None, - 'password': or None, - 'database': or None, - 'collection': or None, - 'options': , - 'fqdn': or None - } - - If the URI scheme is "mongodb+srv://" DNS SRV and TXT lookups will be done - to build nodelist and options. - - :param uri: The MongoDB URI to parse. - :param default_port: The port number to use when one wasn't specified - for a host in the URI. - :param validate: If ``True`` (the default), validate and - normalize all options. Default: ``True``. - :param warn: When validating, if ``True`` then will warn - the user then ignore any invalid options or values. If ``False``, - validation will error when options are unsupported or values are - invalid. Default: ``False``. - :param normalize: If ``True``, convert names of URI options - to their internally-used names. Default: ``True``. - :param connect_timeout: The maximum time in milliseconds to - wait for a response from the DNS server. - :param srv_service_name: A custom SRV service name - - .. versionchanged:: 4.6 - The delimiting slash (``/``) between hosts and connection options is now optional. - For example, "mongodb://example.com?tls=true" is now a valid URI. - - .. versionchanged:: 4.0 - To better follow RFC 3986, unquoted percent signs ("%") are no longer - supported. - - .. versionchanged:: 3.9 - Added the ``normalize`` parameter. - - .. versionchanged:: 3.6 - Added support for mongodb+srv:// URIs. - - .. versionchanged:: 3.5 - Return the original value of the ``readPreference`` MongoDB URI option - instead of the validated read preference mode. - - .. versionchanged:: 3.1 - ``warn`` added so invalid options can be ignored. - """ - result = _validate_uri(uri, default_port, validate, warn, normalize, srv_max_hosts) - result.update( - _parse_srv( - uri, - default_port, - validate, - warn, - normalize, - connect_timeout, - srv_service_name, - srv_max_hosts, - ) - ) - return result - - -def _validate_uri( - uri: str, - default_port: Optional[int] = DEFAULT_PORT, - validate: bool = True, - warn: bool = False, - normalize: bool = True, - srv_max_hosts: Optional[int] = None, -) -> dict[str, Any]: - if uri.startswith(SCHEME): - is_srv = False - scheme_free = uri[SCHEME_LEN:] - elif uri.startswith(SRV_SCHEME): - if not _have_dnspython(): - python_path = sys.executable or "python" - raise ConfigurationError( - 'The "dnspython" module must be ' - "installed to use mongodb+srv:// URIs. " - "To fix this error install pymongo again:\n " - "%s -m pip install pymongo>=4.3" % (python_path) - ) - is_srv = True - scheme_free = uri[SRV_SCHEME_LEN:] - else: - raise InvalidURI(f"Invalid URI scheme: URI must begin with '{SCHEME}' or '{SRV_SCHEME}'") - - if not scheme_free: - raise InvalidURI("Must provide at least one hostname or IP") - - user = None - passwd = None - dbase = None - collection = None - options = _CaseInsensitiveDictionary() - - host_plus_db_part, _, opts = scheme_free.partition("?") - if "/" in host_plus_db_part: - host_part, _, dbase = host_plus_db_part.partition("/") - else: - host_part = host_plus_db_part - - if dbase: - dbase = unquote_plus(dbase) - if "." in dbase: - dbase, collection = dbase.split(".", 1) - if _BAD_DB_CHARS.search(dbase): - raise InvalidURI('Bad database name "%s"' % dbase) - else: - dbase = None - - if opts: - options.update(split_options(opts, validate, warn, normalize)) - if "@" in host_part: - userinfo, _, hosts = host_part.rpartition("@") - user, passwd = parse_userinfo(userinfo) - else: - hosts = host_part - - if "/" in hosts: - raise InvalidURI("Any '/' in a unix domain socket must be percent-encoded: %s" % host_part) - - hosts = unquote_plus(hosts) - fqdn = None - srv_max_hosts = srv_max_hosts or options.get("srvMaxHosts") - if is_srv: - if options.get("directConnection"): - raise ConfigurationError(f"Cannot specify directConnection=true with {SRV_SCHEME} URIs") - nodes = split_hosts(hosts, default_port=None) - if len(nodes) != 1: - raise InvalidURI(f"{SRV_SCHEME} URIs must include one, and only one, hostname") - fqdn, port = nodes[0] - if port is not None: - raise InvalidURI(f"{SRV_SCHEME} URIs must not include a port number") - elif not is_srv and options.get("srvServiceName") is not None: - raise ConfigurationError( - "The srvServiceName option is only allowed with 'mongodb+srv://' URIs" - ) - elif not is_srv and srv_max_hosts: - raise ConfigurationError( - "The srvMaxHosts option is only allowed with 'mongodb+srv://' URIs" - ) - else: - nodes = split_hosts(hosts, default_port=default_port) - - _check_options(nodes, options) - - return { - "nodelist": nodes, - "username": user, - "password": passwd, - "database": dbase, - "collection": collection, - "options": options, - "fqdn": fqdn, - } - - -def _parse_srv( - uri: str, - default_port: Optional[int] = DEFAULT_PORT, - validate: bool = True, - warn: bool = False, - normalize: bool = True, - connect_timeout: Optional[float] = None, - srv_service_name: Optional[str] = None, - srv_max_hosts: Optional[int] = None, -) -> dict[str, Any]: - if uri.startswith(SCHEME): - is_srv = False - scheme_free = uri[SCHEME_LEN:] - else: - is_srv = True - scheme_free = uri[SRV_SCHEME_LEN:] - - options = _CaseInsensitiveDictionary() - - host_plus_db_part, _, opts = scheme_free.partition("?") - if "/" in host_plus_db_part: - host_part, _, _ = host_plus_db_part.partition("/") - else: - host_part = host_plus_db_part - - if opts: - options.update(split_options(opts, validate, warn, normalize)) - if srv_service_name is None: - srv_service_name = options.get("srvServiceName", SRV_SERVICE_NAME) - if "@" in host_part: - _, _, hosts = host_part.rpartition("@") - else: - hosts = host_part - - hosts = unquote_plus(hosts) - srv_max_hosts = srv_max_hosts or options.get("srvMaxHosts") - if is_srv: - nodes = split_hosts(hosts, default_port=None) - fqdn, port = nodes[0] - - # Use the connection timeout. connectTimeoutMS passed as a keyword - # argument overrides the same option passed in the connection string. - connect_timeout = connect_timeout or options.get("connectTimeoutMS") - dns_resolver = _SrvResolver(fqdn, connect_timeout, srv_service_name, srv_max_hosts) - nodes = dns_resolver.get_hosts() - dns_options = dns_resolver.get_options() - if dns_options: - parsed_dns_options = split_options(dns_options, validate, warn, normalize) - if set(parsed_dns_options) - _ALLOWED_TXT_OPTS: - raise ConfigurationError( - "Only authSource, replicaSet, and loadBalanced are supported from DNS" - ) - for opt, val in parsed_dns_options.items(): - if opt not in options: - options[opt] = val - if options.get("loadBalanced") and srv_max_hosts: - raise InvalidURI("You cannot specify loadBalanced with srvMaxHosts") - if options.get("replicaSet") and srv_max_hosts: - raise InvalidURI("You cannot specify replicaSet with srvMaxHosts") - if "tls" not in options and "ssl" not in options: - options["tls"] = True if validate else "true" - else: - nodes = split_hosts(hosts, default_port=default_port) - - _check_options(nodes, options) - - return { - "nodelist": nodes, - "options": options, - } - - def _parse_kms_tls_options(kms_tls_options: Optional[Mapping[str, Any]]) -> dict[str, SSLContext]: """Parse KMS TLS connection options.""" if not kms_tls_options: diff --git a/test/__init__.py b/test/__init__.py index 307780271d..8e362de5ad 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -32,7 +32,7 @@ import warnings from asyncio import iscoroutinefunction -from pymongo.uri_parser import parse_uri +from pymongo.synchronous.uri_parser import parse_uri try: import ipaddress diff --git a/test/asynchronous/__init__.py b/test/asynchronous/__init__.py index f03fcf4eeb..885eb70313 100644 --- a/test/asynchronous/__init__.py +++ b/test/asynchronous/__init__.py @@ -32,7 +32,7 @@ import warnings from asyncio import iscoroutinefunction -from pymongo.uri_parser import parse_uri +from pymongo.asynchronous.uri_parser import parse_uri try: import ipaddress diff --git a/test/asynchronous/helpers.py b/test/asynchronous/helpers.py index 98e00e9385..c1cc57b2bc 100644 --- a/test/asynchronous/helpers.py +++ b/test/asynchronous/helpers.py @@ -45,9 +45,9 @@ from bson.son import SON from pymongo import common, message +from pymongo.asynchronous.uri_parser import parse_uri from pymongo.read_preferences import ReadPreference from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined] -from pymongo.uri_parser import parse_uri if HAVE_SSL: import ssl diff --git a/test/asynchronous/test_discovery_and_monitoring.py b/test/asynchronous/test_discovery_and_monitoring.py index b3de2c5a4d..46c111b258 100644 --- a/test/asynchronous/test_discovery_and_monitoring.py +++ b/test/asynchronous/test_discovery_and_monitoring.py @@ -54,6 +54,7 @@ from pymongo import common, monitoring from pymongo.asynchronous.settings import TopologySettings from pymongo.asynchronous.topology import Topology, _ErrorContext +from pymongo.asynchronous.uri_parser import parse_uri from pymongo.errors import ( AutoReconnect, ConfigurationError, @@ -66,7 +67,6 @@ from pymongo.monitoring import ServerHeartbeatFailedEvent, ServerHeartbeatStartedEvent from pymongo.server_description import SERVER_TYPE, ServerDescription from pymongo.topology_description import TOPOLOGY_TYPE -from pymongo.uri_parser import parse_uri _IS_SYNC = False diff --git a/test/asynchronous/test_dns.py b/test/asynchronous/test_dns.py index bc6c26be31..ecef7251f0 100644 --- a/test/asynchronous/test_dns.py +++ b/test/asynchronous/test_dns.py @@ -31,9 +31,10 @@ ) from test.utils_shared import async_wait_until +from pymongo.asynchronous.uri_parser import parse_uri from pymongo.common import validate_read_preference_tags from pymongo.errors import ConfigurationError -from pymongo.uri_parser import parse_uri, split_hosts +from pymongo.uri_parser_shared import split_hosts _IS_SYNC = False diff --git a/test/auth_aws/test_auth_aws.py b/test/auth_aws/test_auth_aws.py index a7660f2f67..db17de99d2 100644 --- a/test/auth_aws/test_auth_aws.py +++ b/test/auth_aws/test_auth_aws.py @@ -31,8 +31,8 @@ pass from pymongo import MongoClient +from pymongo.asynchronous.uri_parser import parse_uri from pymongo.errors import OperationFailure -from pymongo.uri_parser import parse_uri pytestmark = pytest.mark.auth_aws diff --git a/test/auth_oidc/test_auth_oidc.py b/test/auth_oidc/test_auth_oidc.py index 6dc36dc8a4..f3c516cb35 100644 --- a/test/auth_oidc/test_auth_oidc.py +++ b/test/auth_oidc/test_auth_oidc.py @@ -37,6 +37,7 @@ from pymongo import MongoClient from pymongo._azure_helpers import _get_azure_response from pymongo._gcp_helpers import _get_gcp_response +from pymongo.asynchronous.uri_parser import parse_uri from pymongo.auth_oidc_shared import _get_k8s_token from pymongo.auth_shared import _build_credentials_tuple from pymongo.cursor_shared import CursorType @@ -49,7 +50,6 @@ OIDCCallbackResult, _get_authenticator, ) -from pymongo.uri_parser import parse_uri ROOT = Path(__file__).parent.parent.resolve() TEST_PATH = ROOT / "auth" / "unified" diff --git a/test/helpers.py b/test/helpers.py index 627be182b5..12c55ade1b 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -47,7 +47,7 @@ from pymongo import common, message from pymongo.read_preferences import ReadPreference from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined] -from pymongo.uri_parser import parse_uri +from pymongo.synchronous.uri_parser import parse_uri if HAVE_SSL: import ssl diff --git a/test/test_default_exports.py b/test/test_default_exports.py index d9301d2223..9035414c75 100644 --- a/test/test_default_exports.py +++ b/test/test_default_exports.py @@ -69,6 +69,7 @@ def test_bson(self): def test_pymongo_imports(self): import pymongo + from pymongo.asynchronous.uri_parser import parse_uri from pymongo.auth import MECHANISMS from pymongo.auth_oidc import ( OIDCCallback, @@ -198,10 +199,9 @@ def test_pymongo_imports(self): from pymongo.server_api import ServerApi, ServerApiVersion from pymongo.server_description import ServerDescription from pymongo.topology_description import TopologyDescription - from pymongo.uri_parser import ( + from pymongo.uri_parser_shared import ( parse_host, parse_ipv6_literal_host, - parse_uri, parse_userinfo, split_hosts, split_options, diff --git a/test/test_discovery_and_monitoring.py b/test/test_discovery_and_monitoring.py index 00021310c9..07720473ca 100644 --- a/test/test_discovery_and_monitoring.py +++ b/test/test_discovery_and_monitoring.py @@ -65,8 +65,8 @@ from pymongo.server_description import SERVER_TYPE, ServerDescription from pymongo.synchronous.settings import TopologySettings from pymongo.synchronous.topology import Topology, _ErrorContext +from pymongo.synchronous.uri_parser import parse_uri from pymongo.topology_description import TOPOLOGY_TYPE -from pymongo.uri_parser import parse_uri _IS_SYNC = True diff --git a/test/test_dns.py b/test/test_dns.py index 045393832f..0290eb16d9 100644 --- a/test/test_dns.py +++ b/test/test_dns.py @@ -33,7 +33,8 @@ from pymongo.common import validate_read_preference_tags from pymongo.errors import ConfigurationError -from pymongo.uri_parser import parse_uri, split_hosts +from pymongo.synchronous.uri_parser import parse_uri +from pymongo.uri_parser_shared import split_hosts _IS_SYNC = True diff --git a/test/test_uri_parser.py b/test/test_uri_parser.py index f95717e95f..85021cbf7e 100644 --- a/test/test_uri_parser.py +++ b/test/test_uri_parser.py @@ -27,9 +27,9 @@ from bson.binary import JAVA_LEGACY from pymongo import ReadPreference +from pymongo.asynchronous.uri_parser import parse_uri from pymongo.errors import ConfigurationError, InvalidURI -from pymongo.uri_parser import ( - parse_uri, +from pymongo.uri_parser_shared import ( parse_userinfo, split_hosts, split_options, diff --git a/test/test_uri_spec.py b/test/test_uri_spec.py index 29cde7e078..3a1ddd8352 100644 --- a/test/test_uri_spec.py +++ b/test/test_uri_spec.py @@ -27,9 +27,9 @@ from test import unittest from test.helpers import clear_warning_registry +from pymongo.asynchronous.uri_parser import parse_uri from pymongo.common import INTERNAL_URI_OPTION_NAME_MAP, validate from pymongo.compression_support import _have_snappy -from pymongo.uri_parser import parse_uri CONN_STRING_TEST_PATH = os.path.join( os.path.dirname(os.path.realpath(__file__)), os.path.join("connection_string", "test") From 94fec4452f04f3508ed2006665dce0195080463c Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Thu, 13 Mar 2025 12:15:54 -0700 Subject: [PATCH 28/56] fix circular import --- pymongo/asynchronous/uri_parser.py | 15 +++++++++++++++ pymongo/synchronous/uri_parser.py | 15 +++++++++++++++ pymongo/uri_parser_shared.py | 12 ------------ 3 files changed, 30 insertions(+), 12 deletions(-) diff --git a/pymongo/asynchronous/uri_parser.py b/pymongo/asynchronous/uri_parser.py index 3953393044..4fa74fe869 100644 --- a/pymongo/asynchronous/uri_parser.py +++ b/pymongo/asynchronous/uri_parser.py @@ -269,3 +269,18 @@ async def _parse_srv( "nodelist": nodes, "options": options, } + + +if __name__ == "__main__": + import pprint + + try: + if _IS_SYNC: + pprint.pprint(parse_uri(sys.argv[1])) # noqa: T203 + else: + import asyncio + + pprint.pprint(asyncio.run(parse_uri(sys.argv[1]))) # type:ignore[arg-type] # noqa: T203 + except InvalidURI as exc: + print(exc) # noqa: T201 + sys.exit(0) diff --git a/pymongo/synchronous/uri_parser.py b/pymongo/synchronous/uri_parser.py index a9bd6986b7..8990cf7e44 100644 --- a/pymongo/synchronous/uri_parser.py +++ b/pymongo/synchronous/uri_parser.py @@ -269,3 +269,18 @@ def _parse_srv( "nodelist": nodes, "options": options, } + + +if __name__ == "__main__": + import pprint + + try: + if _IS_SYNC: + pprint.pprint(parse_uri(sys.argv[1])) # noqa: T203 + else: + import asyncio + + pprint.pprint(asyncio.run(parse_uri(sys.argv[1]))) # type:ignore[arg-type] # noqa: T203 + except InvalidURI as exc: + print(exc) # noqa: T201 + sys.exit(0) diff --git a/pymongo/uri_parser_shared.py b/pymongo/uri_parser_shared.py index 17ee2b187d..32eabaea73 100644 --- a/pymongo/uri_parser_shared.py +++ b/pymongo/uri_parser_shared.py @@ -20,7 +20,6 @@ from __future__ import annotations import re -import sys import warnings from typing import ( TYPE_CHECKING, @@ -34,7 +33,6 @@ ) from urllib.parse import unquote_plus -from pymongo.asynchronous.uri_parser import parse_uri from pymongo.client_options import _parse_ssl_options from pymongo.common import ( INTERNAL_URI_OPTION_NAME_MAP, @@ -451,13 +449,3 @@ def _parse_kms_tls_options(kms_tls_options: Optional[Mapping[str, Any]]) -> dict raise ConfigurationError(f"Insecure TLS options prohibited: {n}") contexts[provider] = ssl_context return contexts - - -if __name__ == "__main__": - import pprint - - try: - pprint.pprint(parse_uri(sys.argv[1])) # noqa: T203 - except InvalidURI as exc: - print(exc) # noqa: T201 - sys.exit(0) From 32fabb9e05f7464d6f490cbc35713e82e8de1fb0 Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Thu, 13 Mar 2025 12:28:12 -0700 Subject: [PATCH 29/56] fix tests --- pymongo/asynchronous/srv_resolver.py | 4 ++-- pymongo/synchronous/srv_resolver.py | 4 ++-- test/asynchronous/test_client.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/pymongo/asynchronous/srv_resolver.py b/pymongo/asynchronous/srv_resolver.py index b7e94a3f06..da0b6bd032 100644 --- a/pymongo/asynchronous/srv_resolver.py +++ b/pymongo/asynchronous/srv_resolver.py @@ -56,9 +56,9 @@ async def _resolve(*args: Any, **kwargs: Any) -> resolver.Answer: # dnspython 1.X return resolver.query(*args, **kwargs) else: - from dns.asyncresolver import Resolver + from dns import asyncresolver - return await Resolver.resolve(*args, **kwargs) # type:ignore[return-value] + return await asyncresolver.resolve(*args, **kwargs) # type:ignore[return-value] _INVALID_HOST_MSG = ( diff --git a/pymongo/synchronous/srv_resolver.py b/pymongo/synchronous/srv_resolver.py index c424e6e9af..a29c36c586 100644 --- a/pymongo/synchronous/srv_resolver.py +++ b/pymongo/synchronous/srv_resolver.py @@ -56,9 +56,9 @@ def _resolve(*args: Any, **kwargs: Any) -> resolver.Answer: # dnspython 1.X return resolver.query(*args, **kwargs) else: - from dns.asyncresolver import Resolver + from dns import asyncresolver - return Resolver.resolve(*args, **kwargs) # type:ignore[return-value] + return asyncresolver.resolve(*args, **kwargs) # type:ignore[return-value] _INVALID_HOST_MSG = ( diff --git a/test/asynchronous/test_client.py b/test/asynchronous/test_client.py index 49874e90c1..37b40f8d22 100644 --- a/test/asynchronous/test_client.py +++ b/test/asynchronous/test_client.py @@ -607,7 +607,7 @@ def test_validate_suggestion(self): with self.assertRaisesRegex(ConfigurationError, expected): AsyncMongoClient(**{typo: "standard"}) # type: ignore[arg-type] - @patch("pymongo.srv_resolver._SrvResolver.get_hosts") + @patch("pymongo.asynchronous.srv_resolver._SrvResolver.get_hosts") def test_detected_environment_logging(self, mock_get_hosts): normal_hosts = [ "normal.host.com", @@ -629,7 +629,7 @@ def test_detected_environment_logging(self, mock_get_hosts): logs = [record.getMessage() for record in cm.records if record.name == "pymongo.client"] self.assertEqual(len(logs), 7) - @patch("pymongo.srv_resolver._SrvResolver.get_hosts") + @patch("pymongo.asynchronous.srv_resolver._SrvResolver.get_hosts") async def test_detected_environment_warning(self, mock_get_hosts): with self._caplog.at_level(logging.WARN): normal_hosts = [ From af568da39143d0f4dfc60037cf24dd006cdcefea Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Thu, 13 Mar 2025 12:36:11 -0700 Subject: [PATCH 30/56] fix test and repr --- pymongo/asynchronous/mongo_client.py | 2 +- pymongo/synchronous/mongo_client.py | 2 +- test/asynchronous/test_discovery_and_monitoring.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index 5c32edf837..e76cb16613 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -1257,7 +1257,7 @@ def option_repr(option: str, value: Any) -> str: ] ] else: - options = [] + options = ["host={self._host}", "port={self._port}"] # ... then everything in self._constructor_args... options.extend( option_repr(key, self._options._options[key]) for key in self._constructor_args diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 21e695f348..0ffbda5fa3 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -1255,7 +1255,7 @@ def option_repr(option: str, value: Any) -> str: ] ] else: - options = [] + options = ["host={self._host}", "port={self._port}"] # ... then everything in self._constructor_args... options.extend( option_repr(key, self._options._options[key]) for key in self._constructor_args diff --git a/test/asynchronous/test_discovery_and_monitoring.py b/test/asynchronous/test_discovery_and_monitoring.py index 46c111b258..fa62b25dd1 100644 --- a/test/asynchronous/test_discovery_and_monitoring.py +++ b/test/asynchronous/test_discovery_and_monitoring.py @@ -81,7 +81,7 @@ async def create_mock_topology(uri, monitor_class=DummyMonitor): - parsed_uri = parse_uri(uri) + parsed_uri = await parse_uri(uri) replica_set_name = None direct_connection = None load_balanced = None From 82bcd3858c2e4ffa255eb80fd6f919a3ca0758fa Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Thu, 13 Mar 2025 12:49:55 -0700 Subject: [PATCH 31/56] fix test --- test/asynchronous/test_srv_polling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/asynchronous/test_srv_polling.py b/test/asynchronous/test_srv_polling.py index 47b68411a8..3dcd21ef1d 100644 --- a/test/asynchronous/test_srv_polling.py +++ b/test/asynchronous/test_srv_polling.py @@ -61,9 +61,9 @@ def enable(self): if self.min_srv_rescan_interval is not None: common.MIN_SRV_RESCAN_INTERVAL = self.min_srv_rescan_interval - def mock_get_hosts_and_min_ttl(resolver, *args): + async def mock_get_hosts_and_min_ttl(resolver, *args): assert self.old_dns_resolver_response is not None - nodes, ttl = self.old_dns_resolver_response(resolver) + nodes, ttl = await self.old_dns_resolver_response(resolver) if self.nodelist_callback is not None: nodes = self.nodelist_callback() if self.ttl_time is not None: From 60bf17d4ade429742ff2c3fd4c3fd71667849589 Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Thu, 13 Mar 2025 12:58:29 -0700 Subject: [PATCH 32/56] fix import for test --- test/test_uri_parser.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_uri_parser.py b/test/test_uri_parser.py index 85021cbf7e..0baefa0c3a 100644 --- a/test/test_uri_parser.py +++ b/test/test_uri_parser.py @@ -27,8 +27,8 @@ from bson.binary import JAVA_LEGACY from pymongo import ReadPreference -from pymongo.asynchronous.uri_parser import parse_uri from pymongo.errors import ConfigurationError, InvalidURI +from pymongo.synchronous.uri_parser import parse_uri from pymongo.uri_parser_shared import ( parse_userinfo, split_hosts, From 63ba7bec51bdcd973126636fcc409217aa49f65e Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Thu, 13 Mar 2025 13:02:57 -0700 Subject: [PATCH 33/56] change helpers import --- test/asynchronous/helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/asynchronous/helpers.py b/test/asynchronous/helpers.py index c1cc57b2bc..7b021e8b44 100644 --- a/test/asynchronous/helpers.py +++ b/test/asynchronous/helpers.py @@ -45,9 +45,9 @@ from bson.son import SON from pymongo import common, message -from pymongo.asynchronous.uri_parser import parse_uri from pymongo.read_preferences import ReadPreference from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined] +from pymongo.synchronous.uri_parser import parse_uri if HAVE_SSL: import ssl From 7585e04ed327eb3d7357299465875fdc414c2bf2 Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Thu, 13 Mar 2025 13:39:06 -0700 Subject: [PATCH 34/56] fix uri_parser --- pymongo/uri_parser.py | 31 +++++++++++++++++++++++++++++++ test/test_uri_spec.py | 2 +- 2 files changed, 32 insertions(+), 1 deletion(-) create mode 100644 pymongo/uri_parser.py diff --git a/pymongo/uri_parser.py b/pymongo/uri_parser.py new file mode 100644 index 0000000000..0ab300a1b6 --- /dev/null +++ b/pymongo/uri_parser.py @@ -0,0 +1,31 @@ +# Copyright 2025-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Re-import of synchronous Uri Parser API for compatibility.""" +from __future__ import annotations + +from pymongo.synchronous.uri_parser import * # noqa: F403 +from pymongo.synchronous.uri_parser import __doc__ as original_doc +from pymongo.uri_parser_shared import * # noqa: F403 + +__doc__ = original_doc +__all__ = [ # noqa: F405 + "parse_userinfo", + "parse_ipv6_literal_host", + "parse_host", + "validate_options", + "split_options", + "split_hosts", + "parse_uri", +] diff --git a/test/test_uri_spec.py b/test/test_uri_spec.py index 3a1ddd8352..aeb0be94b5 100644 --- a/test/test_uri_spec.py +++ b/test/test_uri_spec.py @@ -27,9 +27,9 @@ from test import unittest from test.helpers import clear_warning_registry -from pymongo.asynchronous.uri_parser import parse_uri from pymongo.common import INTERNAL_URI_OPTION_NAME_MAP, validate from pymongo.compression_support import _have_snappy +from pymongo.synchronous.uri_parser import parse_uri CONN_STRING_TEST_PATH = os.path.join( os.path.dirname(os.path.realpath(__file__)), os.path.join("connection_string", "test") From 2c69412f0be1247c14c9cec483ead2be16dafd71 Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Thu, 13 Mar 2025 13:53:50 -0700 Subject: [PATCH 35/56] fix srv_resolver --- pymongo/srv_resolver.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) create mode 100644 pymongo/srv_resolver.py diff --git a/pymongo/srv_resolver.py b/pymongo/srv_resolver.py new file mode 100644 index 0000000000..d4ede0cd6b --- /dev/null +++ b/pymongo/srv_resolver.py @@ -0,0 +1,22 @@ +# Copyright 2025-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Re-import of synchronous Uri Parser API for compatibility.""" +from __future__ import annotations + +from pymongo.synchronous.srv_resolver import __doc__ as original_doc +from pymongo.synchronous.srv_resolver import _SrvResolver, maybe_decode + +__doc__ = original_doc +__all__ = ["maybe_decode", "_SrvResolver"] From f834b89e146c161dd657790eb80131c900442334 Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Thu, 13 Mar 2025 14:04:06 -0700 Subject: [PATCH 36/56] add missing awaits --- test/asynchronous/__init__.py | 4 ++-- test/asynchronous/test_dns.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/test/asynchronous/__init__.py b/test/asynchronous/__init__.py index 885eb70313..b3f65e5d3c 100644 --- a/test/asynchronous/__init__.py +++ b/test/asynchronous/__init__.py @@ -1027,7 +1027,7 @@ async def _unmanaged_async_mongo_client( auth_mech = kwargs.get("authMechanism", "") if async_client_context.auth_enabled and authenticate and auth_mech != "MONGODB-OIDC": # Only add the default username or password if one is not provided. - res = parse_uri(uri) + res = await parse_uri(uri) if ( not res["username"] and not res["password"] @@ -1058,7 +1058,7 @@ async def _async_mongo_client( auth_mech = kwargs.get("authMechanism", "") if async_client_context.auth_enabled and authenticate and auth_mech != "MONGODB-OIDC": # Only add the default username or password if one is not provided. - res = parse_uri(uri) + res = await parse_uri(uri) if ( not res["username"] and not res["password"] diff --git a/test/asynchronous/test_dns.py b/test/asynchronous/test_dns.py index ecef7251f0..17d03b2043 100644 --- a/test/asynchronous/test_dns.py +++ b/test/asynchronous/test_dns.py @@ -162,7 +162,7 @@ async def run_test(self): # and re-run these assertions. else: try: - parse_uri(uri) + await parse_uri(uri) except (ConfigurationError, ValueError): pass else: From d4504577d355dad8d137e07a9ac0ee5d5c7165b0 Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Thu, 13 Mar 2025 14:49:39 -0700 Subject: [PATCH 37/56] add missing await --- test/asynchronous/test_dns.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/asynchronous/test_dns.py b/test/asynchronous/test_dns.py index 17d03b2043..d0e801e123 100644 --- a/test/asynchronous/test_dns.py +++ b/test/asynchronous/test_dns.py @@ -110,7 +110,7 @@ async def run_test(self): hosts = frozenset(split_hosts(",".join(hosts))) if seeds or num_seeds: - result = parse_uri(uri, validate=True) + result = await parse_uri(uri, validate=True) if seeds is not None: self.assertEqual(sorted(result["nodelist"]), sorted(seeds)) if num_seeds is not None: From 2a8b1b27e5344ed498c89318f43bf3c39c414886 Mon Sep 17 00:00:00 2001 From: Iris <58442094+sleepyStick@users.noreply.github.com> Date: Fri, 14 Mar 2025 10:31:06 -0700 Subject: [PATCH 38/56] Update test/auth_aws/test_auth_aws.py Co-authored-by: Noah Stapp --- test/auth_aws/test_auth_aws.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/auth_aws/test_auth_aws.py b/test/auth_aws/test_auth_aws.py index db17de99d2..df1b285d8e 100644 --- a/test/auth_aws/test_auth_aws.py +++ b/test/auth_aws/test_auth_aws.py @@ -31,7 +31,7 @@ pass from pymongo import MongoClient -from pymongo.asynchronous.uri_parser import parse_uri +from pymongo.synchronous.uri_parser import parse_uri from pymongo.errors import OperationFailure pytestmark = pytest.mark.auth_aws From c82cf500adfb21698edfa42cb38af8dbb56b0d4d Mon Sep 17 00:00:00 2001 From: Iris <58442094+sleepyStick@users.noreply.github.com> Date: Fri, 14 Mar 2025 10:33:17 -0700 Subject: [PATCH 39/56] Update test/asynchronous/helpers.py Co-authored-by: Noah Stapp --- test/asynchronous/helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/asynchronous/helpers.py b/test/asynchronous/helpers.py index 7b021e8b44..1d1545985e 100644 --- a/test/asynchronous/helpers.py +++ b/test/asynchronous/helpers.py @@ -47,7 +47,7 @@ from pymongo import common, message from pymongo.read_preferences import ReadPreference from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined] -from pymongo.synchronous.uri_parser import parse_uri +from pymongo.asynchronous.uri_parser import parse_uri if HAVE_SSL: import ssl From 925680880371b8a3e9b8d8e4dd24c42237fa3908 Mon Sep 17 00:00:00 2001 From: Iris <58442094+sleepyStick@users.noreply.github.com> Date: Fri, 14 Mar 2025 10:33:28 -0700 Subject: [PATCH 40/56] Update test/auth_oidc/test_auth_oidc.py Co-authored-by: Noah Stapp --- test/auth_oidc/test_auth_oidc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/auth_oidc/test_auth_oidc.py b/test/auth_oidc/test_auth_oidc.py index f3c516cb35..d6a07dd1f5 100644 --- a/test/auth_oidc/test_auth_oidc.py +++ b/test/auth_oidc/test_auth_oidc.py @@ -37,7 +37,7 @@ from pymongo import MongoClient from pymongo._azure_helpers import _get_azure_response from pymongo._gcp_helpers import _get_gcp_response -from pymongo.asynchronous.uri_parser import parse_uri +from pymongo.synchronous.uri_parser import parse_uri from pymongo.auth_oidc_shared import _get_k8s_token from pymongo.auth_shared import _build_credentials_tuple from pymongo.cursor_shared import CursorType From c6d2cebd827b474e03f524791ae6fdc9c39a4ce8 Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Fri, 14 Mar 2025 10:53:10 -0700 Subject: [PATCH 41/56] address comments --- pymongo/asynchronous/mongo_client.py | 11 +-- pymongo/asynchronous/srv_resolver.py | 5 +- pymongo/asynchronous/uri_parser.py | 106 +-------------------------- pymongo/synchronous/mongo_client.py | 11 +-- pymongo/synchronous/srv_resolver.py | 5 +- pymongo/synchronous/uri_parser.py | 106 +-------------------------- pymongo/uri_parser_shared.py | 98 +++++++++++++++++++++++++ 7 files changed, 122 insertions(+), 220 deletions(-) diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index e76cb16613..f14a93e3d5 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -59,10 +59,10 @@ cast, ) -import pymongo.asynchronous.uri_parser +import pymongo.uri_parser_shared from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions, TypeRegistry from bson.timestamp import Timestamp -from pymongo import _csot, common, helpers_shared, periodic_executor, uri_parser_shared +from pymongo import _csot, common, helpers_shared, periodic_executor from pymongo.asynchronous import client_session, database, uri_parser from pymongo.asynchronous.change_stream import AsyncChangeStream, AsyncClusterChangeStream from pymongo.asynchronous.client_bulk import _AsyncClientBulk @@ -121,6 +121,7 @@ _handle_option_deprecations, _handle_security_options, _normalize_options, + split_hosts, ) from pymongo.write_concern import DEFAULT_WRITE_CONCERN, WriteConcern @@ -784,7 +785,7 @@ def __init__( # it must be a URI, # https://en.wikipedia.org/wiki/Hostname#Restrictions_on_valid_host_names if "/" in entity: - res = pymongo.asynchronous.uri_parser._validate_uri( + res = pymongo.uri_parser_shared._validate_uri( entity, port, validate=True, @@ -800,7 +801,7 @@ def __init__( opts = res["options"] fqdn = res["fqdn"] else: - seeds.update(uri_parser_shared.split_hosts(entity, self._port)) + seeds.update(split_hosts(entity, self._port)) if not seeds: raise ConfigurationError("need to specify at least one host") @@ -893,7 +894,7 @@ async def _resolve_srv(self) -> None: seeds.update(res["nodelist"]) opts = res["options"] else: - seeds.update(uri_parser_shared.split_hosts(entity, self._port)) + seeds.update(split_hosts(entity, self._port)) if not seeds: raise ConfigurationError("need to specify at least one host") diff --git a/pymongo/asynchronous/srv_resolver.py b/pymongo/asynchronous/srv_resolver.py index da0b6bd032..b5adf6c920 100644 --- a/pymongo/asynchronous/srv_resolver.py +++ b/pymongo/asynchronous/srv_resolver.py @@ -58,7 +58,10 @@ async def _resolve(*args: Any, **kwargs: Any) -> resolver.Answer: else: from dns import asyncresolver - return await asyncresolver.resolve(*args, **kwargs) # type:ignore[return-value] + if hasattr(asyncresolver, "resolve"): + # dnspython >= 2 + return await asyncresolver.resolve(*args, **kwargs) # type:ignore[return-value] + raise ConfigurationError("Upgrade to dnspython version >= 2.0") _INVALID_HOST_MSG = ( diff --git a/pymongo/asynchronous/uri_parser.py b/pymongo/asynchronous/uri_parser.py index 4fa74fe869..9477572940 100644 --- a/pymongo/asynchronous/uri_parser.py +++ b/pymongo/asynchronous/uri_parser.py @@ -4,19 +4,17 @@ from typing import Any, Optional from urllib.parse import unquote_plus -from pymongo.asynchronous.srv_resolver import _have_dnspython, _SrvResolver +from pymongo.asynchronous.srv_resolver import _SrvResolver from pymongo.common import SRV_SERVICE_NAME, _CaseInsensitiveDictionary from pymongo.errors import ConfigurationError, InvalidURI from pymongo.uri_parser_shared import ( _ALLOWED_TXT_OPTS, - _BAD_DB_CHARS, DEFAULT_PORT, SCHEME, SCHEME_LEN, - SRV_SCHEME, SRV_SCHEME_LEN, _check_options, - parse_userinfo, + _validate_uri, split_hosts, split_options, ) @@ -103,102 +101,6 @@ async def parse_uri( return result -def _validate_uri( - uri: str, - default_port: Optional[int] = DEFAULT_PORT, - validate: bool = True, - warn: bool = False, - normalize: bool = True, - srv_max_hosts: Optional[int] = None, -) -> dict[str, Any]: - if uri.startswith(SCHEME): - is_srv = False - scheme_free = uri[SCHEME_LEN:] - elif uri.startswith(SRV_SCHEME): - if not _have_dnspython(): - python_path = sys.executable or "python" - raise ConfigurationError( - 'The "dnspython" module must be ' - "installed to use mongodb+srv:// URIs. " - "To fix this error install pymongo again:\n " - "%s -m pip install pymongo>=4.3" % (python_path) - ) - is_srv = True - scheme_free = uri[SRV_SCHEME_LEN:] - else: - raise InvalidURI(f"Invalid URI scheme: URI must begin with '{SCHEME}' or '{SRV_SCHEME}'") - - if not scheme_free: - raise InvalidURI("Must provide at least one hostname or IP") - - user = None - passwd = None - dbase = None - collection = None - options = _CaseInsensitiveDictionary() - - host_plus_db_part, _, opts = scheme_free.partition("?") - if "/" in host_plus_db_part: - host_part, _, dbase = host_plus_db_part.partition("/") - else: - host_part = host_plus_db_part - - if dbase: - dbase = unquote_plus(dbase) - if "." in dbase: - dbase, collection = dbase.split(".", 1) - if _BAD_DB_CHARS.search(dbase): - raise InvalidURI('Bad database name "%s"' % dbase) - else: - dbase = None - - if opts: - options.update(split_options(opts, validate, warn, normalize)) - if "@" in host_part: - userinfo, _, hosts = host_part.rpartition("@") - user, passwd = parse_userinfo(userinfo) - else: - hosts = host_part - - if "/" in hosts: - raise InvalidURI("Any '/' in a unix domain socket must be percent-encoded: %s" % host_part) - - hosts = unquote_plus(hosts) - fqdn = None - srv_max_hosts = srv_max_hosts or options.get("srvMaxHosts") - if is_srv: - if options.get("directConnection"): - raise ConfigurationError(f"Cannot specify directConnection=true with {SRV_SCHEME} URIs") - nodes = split_hosts(hosts, default_port=None) - if len(nodes) != 1: - raise InvalidURI(f"{SRV_SCHEME} URIs must include one, and only one, hostname") - fqdn, port = nodes[0] - if port is not None: - raise InvalidURI(f"{SRV_SCHEME} URIs must not include a port number") - elif not is_srv and options.get("srvServiceName") is not None: - raise ConfigurationError( - "The srvServiceName option is only allowed with 'mongodb+srv://' URIs" - ) - elif not is_srv and srv_max_hosts: - raise ConfigurationError( - "The srvMaxHosts option is only allowed with 'mongodb+srv://' URIs" - ) - else: - nodes = split_hosts(hosts, default_port=default_port) - - _check_options(nodes, options) - - return { - "nodelist": nodes, - "username": user, - "password": passwd, - "database": dbase, - "collection": collection, - "options": options, - "fqdn": fqdn, - } - - async def _parse_srv( uri: str, default_port: Optional[int] = DEFAULT_PORT, @@ -277,10 +179,6 @@ async def _parse_srv( try: if _IS_SYNC: pprint.pprint(parse_uri(sys.argv[1])) # noqa: T203 - else: - import asyncio - - pprint.pprint(asyncio.run(parse_uri(sys.argv[1]))) # type:ignore[arg-type] # noqa: T203 except InvalidURI as exc: print(exc) # noqa: T201 sys.exit(0) diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 0ffbda5fa3..dfb813f025 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -58,10 +58,10 @@ cast, ) -import pymongo.synchronous.uri_parser +import pymongo.uri_parser_shared from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions, TypeRegistry from bson.timestamp import Timestamp -from pymongo import _csot, common, helpers_shared, periodic_executor, uri_parser_shared +from pymongo import _csot, common, helpers_shared, periodic_executor from pymongo.client_options import ClientOptions from pymongo.errors import ( AutoReconnect, @@ -120,6 +120,7 @@ _handle_option_deprecations, _handle_security_options, _normalize_options, + split_hosts, ) from pymongo.write_concern import DEFAULT_WRITE_CONCERN, WriteConcern @@ -782,7 +783,7 @@ def __init__( # it must be a URI, # https://en.wikipedia.org/wiki/Hostname#Restrictions_on_valid_host_names if "/" in entity: - res = pymongo.synchronous.uri_parser._validate_uri( + res = pymongo.uri_parser_shared._validate_uri( entity, port, validate=True, @@ -798,7 +799,7 @@ def __init__( opts = res["options"] fqdn = res["fqdn"] else: - seeds.update(uri_parser_shared.split_hosts(entity, self._port)) + seeds.update(split_hosts(entity, self._port)) if not seeds: raise ConfigurationError("need to specify at least one host") @@ -891,7 +892,7 @@ def _resolve_srv(self) -> None: seeds.update(res["nodelist"]) opts = res["options"] else: - seeds.update(uri_parser_shared.split_hosts(entity, self._port)) + seeds.update(split_hosts(entity, self._port)) if not seeds: raise ConfigurationError("need to specify at least one host") diff --git a/pymongo/synchronous/srv_resolver.py b/pymongo/synchronous/srv_resolver.py index a29c36c586..486c7e6522 100644 --- a/pymongo/synchronous/srv_resolver.py +++ b/pymongo/synchronous/srv_resolver.py @@ -58,7 +58,10 @@ def _resolve(*args: Any, **kwargs: Any) -> resolver.Answer: else: from dns import asyncresolver - return asyncresolver.resolve(*args, **kwargs) # type:ignore[return-value] + if hasattr(asyncresolver, "resolve"): + # dnspython >= 2 + return asyncresolver.resolve(*args, **kwargs) # type:ignore[return-value] + raise ConfigurationError("Upgrade to dnspython version >= 2.0") _INVALID_HOST_MSG = ( diff --git a/pymongo/synchronous/uri_parser.py b/pymongo/synchronous/uri_parser.py index 8990cf7e44..856004ea6d 100644 --- a/pymongo/synchronous/uri_parser.py +++ b/pymongo/synchronous/uri_parser.py @@ -6,17 +6,15 @@ from pymongo.common import SRV_SERVICE_NAME, _CaseInsensitiveDictionary from pymongo.errors import ConfigurationError, InvalidURI -from pymongo.synchronous.srv_resolver import _have_dnspython, _SrvResolver +from pymongo.synchronous.srv_resolver import _SrvResolver from pymongo.uri_parser_shared import ( _ALLOWED_TXT_OPTS, - _BAD_DB_CHARS, DEFAULT_PORT, SCHEME, SCHEME_LEN, - SRV_SCHEME, SRV_SCHEME_LEN, _check_options, - parse_userinfo, + _validate_uri, split_hosts, split_options, ) @@ -103,102 +101,6 @@ def parse_uri( return result -def _validate_uri( - uri: str, - default_port: Optional[int] = DEFAULT_PORT, - validate: bool = True, - warn: bool = False, - normalize: bool = True, - srv_max_hosts: Optional[int] = None, -) -> dict[str, Any]: - if uri.startswith(SCHEME): - is_srv = False - scheme_free = uri[SCHEME_LEN:] - elif uri.startswith(SRV_SCHEME): - if not _have_dnspython(): - python_path = sys.executable or "python" - raise ConfigurationError( - 'The "dnspython" module must be ' - "installed to use mongodb+srv:// URIs. " - "To fix this error install pymongo again:\n " - "%s -m pip install pymongo>=4.3" % (python_path) - ) - is_srv = True - scheme_free = uri[SRV_SCHEME_LEN:] - else: - raise InvalidURI(f"Invalid URI scheme: URI must begin with '{SCHEME}' or '{SRV_SCHEME}'") - - if not scheme_free: - raise InvalidURI("Must provide at least one hostname or IP") - - user = None - passwd = None - dbase = None - collection = None - options = _CaseInsensitiveDictionary() - - host_plus_db_part, _, opts = scheme_free.partition("?") - if "/" in host_plus_db_part: - host_part, _, dbase = host_plus_db_part.partition("/") - else: - host_part = host_plus_db_part - - if dbase: - dbase = unquote_plus(dbase) - if "." in dbase: - dbase, collection = dbase.split(".", 1) - if _BAD_DB_CHARS.search(dbase): - raise InvalidURI('Bad database name "%s"' % dbase) - else: - dbase = None - - if opts: - options.update(split_options(opts, validate, warn, normalize)) - if "@" in host_part: - userinfo, _, hosts = host_part.rpartition("@") - user, passwd = parse_userinfo(userinfo) - else: - hosts = host_part - - if "/" in hosts: - raise InvalidURI("Any '/' in a unix domain socket must be percent-encoded: %s" % host_part) - - hosts = unquote_plus(hosts) - fqdn = None - srv_max_hosts = srv_max_hosts or options.get("srvMaxHosts") - if is_srv: - if options.get("directConnection"): - raise ConfigurationError(f"Cannot specify directConnection=true with {SRV_SCHEME} URIs") - nodes = split_hosts(hosts, default_port=None) - if len(nodes) != 1: - raise InvalidURI(f"{SRV_SCHEME} URIs must include one, and only one, hostname") - fqdn, port = nodes[0] - if port is not None: - raise InvalidURI(f"{SRV_SCHEME} URIs must not include a port number") - elif not is_srv and options.get("srvServiceName") is not None: - raise ConfigurationError( - "The srvServiceName option is only allowed with 'mongodb+srv://' URIs" - ) - elif not is_srv and srv_max_hosts: - raise ConfigurationError( - "The srvMaxHosts option is only allowed with 'mongodb+srv://' URIs" - ) - else: - nodes = split_hosts(hosts, default_port=default_port) - - _check_options(nodes, options) - - return { - "nodelist": nodes, - "username": user, - "password": passwd, - "database": dbase, - "collection": collection, - "options": options, - "fqdn": fqdn, - } - - def _parse_srv( uri: str, default_port: Optional[int] = DEFAULT_PORT, @@ -277,10 +179,6 @@ def _parse_srv( try: if _IS_SYNC: pprint.pprint(parse_uri(sys.argv[1])) # noqa: T203 - else: - import asyncio - - pprint.pprint(asyncio.run(parse_uri(sys.argv[1]))) # type:ignore[arg-type] # noqa: T203 except InvalidURI as exc: print(exc) # noqa: T201 sys.exit(0) diff --git a/pymongo/uri_parser_shared.py b/pymongo/uri_parser_shared.py index 32eabaea73..0b2f4e7e07 100644 --- a/pymongo/uri_parser_shared.py +++ b/pymongo/uri_parser_shared.py @@ -20,6 +20,7 @@ from __future__ import annotations import re +import sys import warnings from typing import ( TYPE_CHECKING, @@ -33,6 +34,7 @@ ) from urllib.parse import unquote_plus +from pymongo.asynchronous.srv_resolver import _have_dnspython from pymongo.client_options import _parse_ssl_options from pymongo.common import ( INTERNAL_URI_OPTION_NAME_MAP, @@ -449,3 +451,99 @@ def _parse_kms_tls_options(kms_tls_options: Optional[Mapping[str, Any]]) -> dict raise ConfigurationError(f"Insecure TLS options prohibited: {n}") contexts[provider] = ssl_context return contexts + + +def _validate_uri( + uri: str, + default_port: Optional[int] = DEFAULT_PORT, + validate: bool = True, + warn: bool = False, + normalize: bool = True, + srv_max_hosts: Optional[int] = None, +) -> dict[str, Any]: + if uri.startswith(SCHEME): + is_srv = False + scheme_free = uri[SCHEME_LEN:] + elif uri.startswith(SRV_SCHEME): + if not _have_dnspython(): + python_path = sys.executable or "python" + raise ConfigurationError( + 'The "dnspython" module must be ' + "installed to use mongodb+srv:// URIs. " + "To fix this error install pymongo again:\n " + "%s -m pip install pymongo>=4.3" % (python_path) + ) + is_srv = True + scheme_free = uri[SRV_SCHEME_LEN:] + else: + raise InvalidURI(f"Invalid URI scheme: URI must begin with '{SCHEME}' or '{SRV_SCHEME}'") + + if not scheme_free: + raise InvalidURI("Must provide at least one hostname or IP") + + user = None + passwd = None + dbase = None + collection = None + options = _CaseInsensitiveDictionary() + + host_plus_db_part, _, opts = scheme_free.partition("?") + if "/" in host_plus_db_part: + host_part, _, dbase = host_plus_db_part.partition("/") + else: + host_part = host_plus_db_part + + if dbase: + dbase = unquote_plus(dbase) + if "." in dbase: + dbase, collection = dbase.split(".", 1) + if _BAD_DB_CHARS.search(dbase): + raise InvalidURI('Bad database name "%s"' % dbase) + else: + dbase = None + + if opts: + options.update(split_options(opts, validate, warn, normalize)) + if "@" in host_part: + userinfo, _, hosts = host_part.rpartition("@") + user, passwd = parse_userinfo(userinfo) + else: + hosts = host_part + + if "/" in hosts: + raise InvalidURI("Any '/' in a unix domain socket must be percent-encoded: %s" % host_part) + + hosts = unquote_plus(hosts) + fqdn = None + srv_max_hosts = srv_max_hosts or options.get("srvMaxHosts") + if is_srv: + if options.get("directConnection"): + raise ConfigurationError(f"Cannot specify directConnection=true with {SRV_SCHEME} URIs") + nodes = split_hosts(hosts, default_port=None) + if len(nodes) != 1: + raise InvalidURI(f"{SRV_SCHEME} URIs must include one, and only one, hostname") + fqdn, port = nodes[0] + if port is not None: + raise InvalidURI(f"{SRV_SCHEME} URIs must not include a port number") + elif not is_srv and options.get("srvServiceName") is not None: + raise ConfigurationError( + "The srvServiceName option is only allowed with 'mongodb+srv://' URIs" + ) + elif not is_srv and srv_max_hosts: + raise ConfigurationError( + "The srvMaxHosts option is only allowed with 'mongodb+srv://' URIs" + ) + else: + nodes = split_hosts(hosts, default_port=default_port) + + _check_options(nodes, options) + + return { + "nodelist": nodes, + "username": user, + "password": passwd, + "database": dbase, + "collection": collection, + "options": options, + "fqdn": fqdn, + } From 8927a276aa95a57e92c49ff40848610b8e4bd099 Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Fri, 14 Mar 2025 10:58:52 -0700 Subject: [PATCH 42/56] undo import change in helpers --- test/asynchronous/helpers.py | 2 +- test/auth_aws/test_auth_aws.py | 2 +- test/auth_oidc/test_auth_oidc.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/test/asynchronous/helpers.py b/test/asynchronous/helpers.py index 1d1545985e..7b021e8b44 100644 --- a/test/asynchronous/helpers.py +++ b/test/asynchronous/helpers.py @@ -47,7 +47,7 @@ from pymongo import common, message from pymongo.read_preferences import ReadPreference from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined] -from pymongo.asynchronous.uri_parser import parse_uri +from pymongo.synchronous.uri_parser import parse_uri if HAVE_SSL: import ssl diff --git a/test/auth_aws/test_auth_aws.py b/test/auth_aws/test_auth_aws.py index df1b285d8e..9738694d85 100644 --- a/test/auth_aws/test_auth_aws.py +++ b/test/auth_aws/test_auth_aws.py @@ -31,8 +31,8 @@ pass from pymongo import MongoClient -from pymongo.synchronous.uri_parser import parse_uri from pymongo.errors import OperationFailure +from pymongo.synchronous.uri_parser import parse_uri pytestmark = pytest.mark.auth_aws diff --git a/test/auth_oidc/test_auth_oidc.py b/test/auth_oidc/test_auth_oidc.py index d6a07dd1f5..0c8431a1e8 100644 --- a/test/auth_oidc/test_auth_oidc.py +++ b/test/auth_oidc/test_auth_oidc.py @@ -37,7 +37,6 @@ from pymongo import MongoClient from pymongo._azure_helpers import _get_azure_response from pymongo._gcp_helpers import _get_gcp_response -from pymongo.synchronous.uri_parser import parse_uri from pymongo.auth_oidc_shared import _get_k8s_token from pymongo.auth_shared import _build_credentials_tuple from pymongo.cursor_shared import CursorType @@ -50,6 +49,7 @@ OIDCCallbackResult, _get_authenticator, ) +from pymongo.synchronous.uri_parser import parse_uri ROOT = Path(__file__).parent.parent.resolve() TEST_PATH = ROOT / "auth" / "unified" From 76a68b2a2ead521bfc78c7b47de2fc117cd6d388 Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Mon, 17 Mar 2025 16:39:10 -0700 Subject: [PATCH 43/56] change client eq and hash --- pymongo/asynchronous/mongo_client.py | 29 ++++++++++++++++++---------- pymongo/synchronous/mongo_client.py | 29 ++++++++++++++++++---------- 2 files changed, 38 insertions(+), 20 deletions(-) diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index f14a93e3d5..bc7f4e001b 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -847,7 +847,9 @@ def __init__( "username": username, "password": password, "dbase": dbase, + "seeds": seeds, "fqdn": fqdn, + "srv_service_name": srv_service_name, "pool_class": pool_class, "monitor_class": monitor_class, "condition_class": condition_class, @@ -856,6 +858,13 @@ def __init__( if not is_srv: self._init_based_on_options(seeds, srv_max_hosts, srv_service_name) + super().__init__( + self._options.codec_options, + self._options.read_preference, + self._options.write_concern, + self._options.read_concern, + ) + self._opened = False self._closed = False if not is_srv: @@ -937,12 +946,6 @@ def _init_based_on_options( self, seeds: Collection[tuple[str, int]], srv_max_hosts: Any, srv_service_name: Any ) -> None: self._event_listeners = self._options.pool_options._event_listeners - super().__init__( - self._options.codec_options, - self._options.read_preference, - self._options.write_concern, - self._options.read_concern, - ) self._topology_settings = TopologySettings( seeds=seeds, replica_set_name=self._options.replica_set_name, @@ -1216,14 +1219,20 @@ def options(self) -> ClientOptions: """ return self._options + def eq_props(self): + return ( + tuple(sorted(self._resolve_srv_info["seeds"])), + self._options.replica_set_name, + self._resolve_srv_info["fqdn"], + self._resolve_srv_info["srv_service_name"], + ) + def __eq__(self, other: Any) -> bool: if isinstance(other, self.__class__): if hasattr(self, "_topology") and hasattr(other, "_topology"): return self._topology == other._topology else: - raise InvalidOperation( - "Cannot compare client equality until both clients are connected" - ) + return self.eq_props() == other.eq_props() return NotImplemented def __ne__(self, other: Any) -> bool: @@ -1233,7 +1242,7 @@ def __hash__(self) -> int: if hasattr(self, "_topology"): return hash(self._topology) else: - raise InvalidOperation("Cannot hash client until it is connected") + raise hash(self.eq_props()) def _repr_helper(self) -> str: def option_repr(option: str, value: Any) -> str: diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index dfb813f025..0d5dd1d046 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -845,7 +845,9 @@ def __init__( "username": username, "password": password, "dbase": dbase, + "seeds": seeds, "fqdn": fqdn, + "srv_service_name": srv_service_name, "pool_class": pool_class, "monitor_class": monitor_class, "condition_class": condition_class, @@ -854,6 +856,13 @@ def __init__( if not is_srv: self._init_based_on_options(seeds, srv_max_hosts, srv_service_name) + super().__init__( + self._options.codec_options, + self._options.read_preference, + self._options.write_concern, + self._options.read_concern, + ) + self._opened = False self._closed = False if not is_srv: @@ -935,12 +944,6 @@ def _init_based_on_options( self, seeds: Collection[tuple[str, int]], srv_max_hosts: Any, srv_service_name: Any ) -> None: self._event_listeners = self._options.pool_options._event_listeners - super().__init__( - self._options.codec_options, - self._options.read_preference, - self._options.write_concern, - self._options.read_concern, - ) self._topology_settings = TopologySettings( seeds=seeds, replica_set_name=self._options.replica_set_name, @@ -1214,14 +1217,20 @@ def options(self) -> ClientOptions: """ return self._options + def eq_props(self): + return ( + tuple(sorted(self._resolve_srv_info["seeds"])), + self._options.replica_set_name, + self._resolve_srv_info["fqdn"], + self._resolve_srv_info["srv_service_name"], + ) + def __eq__(self, other: Any) -> bool: if isinstance(other, self.__class__): if hasattr(self, "_topology") and hasattr(other, "_topology"): return self._topology == other._topology else: - raise InvalidOperation( - "Cannot compare client equality until both clients are connected" - ) + return self.eq_props() == other.eq_props() return NotImplemented def __ne__(self, other: Any) -> bool: @@ -1231,7 +1240,7 @@ def __hash__(self) -> int: if hasattr(self, "_topology"): return hash(self._topology) else: - raise InvalidOperation("Cannot hash client until it is connected") + raise hash(self.eq_props()) def _repr_helper(self) -> str: def option_repr(option: str, value: Any) -> str: From b60eb60d0a225befa312148a6a86bf664723a54c Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Mon, 17 Mar 2025 16:56:39 -0700 Subject: [PATCH 44/56] address comments part 1 --- pymongo/asynchronous/mongo_client.py | 21 +++++++++++---------- pymongo/asynchronous/uri_parser.py | 12 ------------ pymongo/srv_resolver.py | 2 +- pymongo/synchronous/mongo_client.py | 21 +++++++++++---------- pymongo/synchronous/uri_parser.py | 12 ------------ pymongo/uri_parser.py | 16 ++++++++++++++-- 6 files changed, 37 insertions(+), 47 deletions(-) diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index bc7f4e001b..f4964942e4 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -757,6 +757,7 @@ def __init__( raise TypeError(f"port must be an instance of int, not {type(port)}") self._host = host self._port = port + self._topology = None # _pool_class, _monitor_class, and _condition_class are for deep # customization of PyMongo, e.g. Motor. @@ -1229,20 +1230,20 @@ def eq_props(self): def __eq__(self, other: Any) -> bool: if isinstance(other, self.__class__): - if hasattr(self, "_topology") and hasattr(other, "_topology"): - return self._topology == other._topology - else: + if self._topology is None: return self.eq_props() == other.eq_props() + else: + return self._topology == other._topology return NotImplemented def __ne__(self, other: Any) -> bool: return not self == other def __hash__(self) -> int: - if hasattr(self, "_topology"): - return hash(self._topology) - else: + if self._topology is None: raise hash(self.eq_props()) + else: + return hash(self._topology) def _repr_helper(self) -> str: def option_repr(option: str, value: Any) -> str: @@ -1258,7 +1259,9 @@ def option_repr(option: str, value: Any) -> str: return f"{option}={value!r}" # Host first... - if hasattr(self, "_topology"): + if self._topology is None: + options = ["host={self._host}", "port={self._port}"] + else: options = [ "host=%r" % [ @@ -1266,8 +1269,6 @@ def option_repr(option: str, value: Any) -> str: for host, port in self._topology_settings.seeds ] ] - else: - options = ["host={self._host}", "port={self._port}"] # ... then everything in self._constructor_args... options.extend( option_repr(key, self._options._options[key]) for key in self._constructor_args @@ -1673,7 +1674,7 @@ async def close(self) -> None: .. versionchanged:: 3.6 End all server sessions created by this client. """ - if hasattr(self, "_topology"): + if self._topology is not None: session_ids = self._topology.pop_all_sessions() if session_ids: await self._end_sessions(session_ids) diff --git a/pymongo/asynchronous/uri_parser.py b/pymongo/asynchronous/uri_parser.py index 9477572940..0cc69247eb 100644 --- a/pymongo/asynchronous/uri_parser.py +++ b/pymongo/asynchronous/uri_parser.py @@ -1,6 +1,5 @@ from __future__ import annotations -import sys from typing import Any, Optional from urllib.parse import unquote_plus @@ -171,14 +170,3 @@ async def _parse_srv( "nodelist": nodes, "options": options, } - - -if __name__ == "__main__": - import pprint - - try: - if _IS_SYNC: - pprint.pprint(parse_uri(sys.argv[1])) # noqa: T203 - except InvalidURI as exc: - print(exc) # noqa: T201 - sys.exit(0) diff --git a/pymongo/srv_resolver.py b/pymongo/srv_resolver.py index d4ede0cd6b..1414a86b05 100644 --- a/pymongo/srv_resolver.py +++ b/pymongo/srv_resolver.py @@ -1,4 +1,4 @@ -# Copyright 2025-present MongoDB, Inc. +# Copyright 2019-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 0d5dd1d046..b82f405722 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -755,6 +755,7 @@ def __init__( raise TypeError(f"port must be an instance of int, not {type(port)}") self._host = host self._port = port + self._topology = None # _pool_class, _monitor_class, and _condition_class are for deep # customization of PyMongo, e.g. Motor. @@ -1227,20 +1228,20 @@ def eq_props(self): def __eq__(self, other: Any) -> bool: if isinstance(other, self.__class__): - if hasattr(self, "_topology") and hasattr(other, "_topology"): - return self._topology == other._topology - else: + if self._topology is None: return self.eq_props() == other.eq_props() + else: + return self._topology == other._topology return NotImplemented def __ne__(self, other: Any) -> bool: return not self == other def __hash__(self) -> int: - if hasattr(self, "_topology"): - return hash(self._topology) - else: + if self._topology is None: raise hash(self.eq_props()) + else: + return hash(self._topology) def _repr_helper(self) -> str: def option_repr(option: str, value: Any) -> str: @@ -1256,7 +1257,9 @@ def option_repr(option: str, value: Any) -> str: return f"{option}={value!r}" # Host first... - if hasattr(self, "_topology"): + if self._topology is None: + options = ["host={self._host}", "port={self._port}"] + else: options = [ "host=%r" % [ @@ -1264,8 +1267,6 @@ def option_repr(option: str, value: Any) -> str: for host, port in self._topology_settings.seeds ] ] - else: - options = ["host={self._host}", "port={self._port}"] # ... then everything in self._constructor_args... options.extend( option_repr(key, self._options._options[key]) for key in self._constructor_args @@ -1667,7 +1668,7 @@ def close(self) -> None: .. versionchanged:: 3.6 End all server sessions created by this client. """ - if hasattr(self, "_topology"): + if self._topology is not None: session_ids = self._topology.pop_all_sessions() if session_ids: self._end_sessions(session_ids) diff --git a/pymongo/synchronous/uri_parser.py b/pymongo/synchronous/uri_parser.py index 856004ea6d..2810f0f123 100644 --- a/pymongo/synchronous/uri_parser.py +++ b/pymongo/synchronous/uri_parser.py @@ -1,6 +1,5 @@ from __future__ import annotations -import sys from typing import Any, Optional from urllib.parse import unquote_plus @@ -171,14 +170,3 @@ def _parse_srv( "nodelist": nodes, "options": options, } - - -if __name__ == "__main__": - import pprint - - try: - if _IS_SYNC: - pprint.pprint(parse_uri(sys.argv[1])) # noqa: T203 - except InvalidURI as exc: - print(exc) # noqa: T201 - sys.exit(0) diff --git a/pymongo/uri_parser.py b/pymongo/uri_parser.py index 0ab300a1b6..19a41fd9e8 100644 --- a/pymongo/uri_parser.py +++ b/pymongo/uri_parser.py @@ -1,4 +1,4 @@ -# Copyright 2025-present MongoDB, Inc. +# Copyright 2011-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,9 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Re-import of synchronous Uri Parser API for compatibility.""" +"""Re-import of synchronous URI Parser API for compatibility.""" from __future__ import annotations +import sys + +from pymongo.errors import InvalidURI from pymongo.synchronous.uri_parser import * # noqa: F403 from pymongo.synchronous.uri_parser import __doc__ as original_doc from pymongo.uri_parser_shared import * # noqa: F403 @@ -29,3 +32,12 @@ "split_hosts", "parse_uri", ] + +if __name__ == "__main__": + import pprint + + try: + pprint.pprint(parse_uri(sys.argv[1])) # noqa: F405, T203 + except InvalidURI as exc: + print(exc) # noqa: T201 + sys.exit(0) From 0b6d303a6c39ebcff540b928458fca01d56f0c8c Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Mon, 17 Mar 2025 17:01:11 -0700 Subject: [PATCH 45/56] address comment ish - remove first --- pymongo/asynchronous/mongo_client.py | 8 ++++---- pymongo/synchronous/mongo_client.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index f4964942e4..e4a1464b29 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -869,7 +869,7 @@ def __init__( self._opened = False self._closed = False if not is_srv: - self._init_background(first=True) + self._init_background() if _IS_SYNC and connect: self._get_topology() # type: ignore[unused-coroutine] @@ -1000,9 +1000,9 @@ async def aconnect(self) -> None: """Explicitly connect to MongoDB asynchronously instead of on the first operation.""" await self._get_topology() - def _init_background(self, old_pid: Optional[int] = None, first: bool = False) -> None: + def _init_background(self, old_pid: Optional[int] = None) -> None: self._topology = Topology(self._topology_settings) - if first and _HAS_REGISTER_AT_FORK: + if _HAS_REGISTER_AT_FORK: # Add this client to the list of weakly referenced items. # This will be used later if we fork. AsyncMongoClient._clients[self._topology._topology_id] = self @@ -1707,7 +1707,7 @@ async def _get_topology(self) -> Topology: if not self._opened: if self._resolve_srv_info["is_srv"]: await self._resolve_srv() - self._init_background(first=True) + self._init_background() await self._topology.open() async with self._lock: self._kill_cursors_executor.open() diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index b82f405722..3367366c26 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -867,7 +867,7 @@ def __init__( self._opened = False self._closed = False if not is_srv: - self._init_background(first=True) + self._init_background() if _IS_SYNC and connect: self._get_topology() # type: ignore[unused-coroutine] @@ -998,9 +998,9 @@ def _connect(self) -> None: """Explicitly connect to MongoDB synchronously instead of on the first operation.""" self._get_topology() - def _init_background(self, old_pid: Optional[int] = None, first: bool = False) -> None: + def _init_background(self, old_pid: Optional[int] = None) -> None: self._topology = Topology(self._topology_settings) - if first and _HAS_REGISTER_AT_FORK: + if _HAS_REGISTER_AT_FORK: # Add this client to the list of weakly referenced items. # This will be used later if we fork. MongoClient._clients[self._topology._topology_id] = self @@ -1701,7 +1701,7 @@ def _get_topology(self) -> Topology: if not self._opened: if self._resolve_srv_info["is_srv"]: self._resolve_srv() - self._init_background(first=True) + self._init_background() self._topology.open() with self._lock: self._kill_cursors_executor.open() From 0ca6afd5a5ec13b0932f1d259f508a52ffd3a870 Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Tue, 18 Mar 2025 00:01:51 -0700 Subject: [PATCH 46/56] re-order call to super's init --- pymongo/asynchronous/mongo_client.py | 5 +++-- pymongo/synchronous/mongo_client.py | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index e4a1464b29..138cbcfb25 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -856,8 +856,6 @@ def __init__( "condition_class": condition_class, } ) - if not is_srv: - self._init_based_on_options(seeds, srv_max_hosts, srv_service_name) super().__init__( self._options.codec_options, @@ -866,6 +864,9 @@ def __init__( self._options.read_concern, ) + if not is_srv: + self._init_based_on_options(seeds, srv_max_hosts, srv_service_name) + self._opened = False self._closed = False if not is_srv: diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 3367366c26..db439c847f 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -854,8 +854,6 @@ def __init__( "condition_class": condition_class, } ) - if not is_srv: - self._init_based_on_options(seeds, srv_max_hosts, srv_service_name) super().__init__( self._options.codec_options, @@ -864,6 +862,9 @@ def __init__( self._options.read_concern, ) + if not is_srv: + self._init_based_on_options(seeds, srv_max_hosts, srv_service_name) + self._opened = False self._closed = False if not is_srv: From d616135b243b0851745ec858b396339a6cb18185 Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Tue, 18 Mar 2025 08:45:26 -0700 Subject: [PATCH 47/56] update link to use https based on prev commit on main --- pymongo/uri_parser_shared.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymongo/uri_parser_shared.py b/pymongo/uri_parser_shared.py index 0b2f4e7e07..e7ba4c9fb5 100644 --- a/pymongo/uri_parser_shared.py +++ b/pymongo/uri_parser_shared.py @@ -4,7 +4,7 @@ # may not use this file except in compliance with the License. You # may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -163,7 +163,7 @@ def parse_host(entity: str, default_port: Optional[int] = DEFAULT_PORT) -> _Addr port = int(port) # Normalize hostname to lowercase, since DNS is case-insensitive: - # http://tools.ietf.org/html/rfc4343 + # https://tools.ietf.org/html/rfc4343 # This prevents useless rediscovery if "foo.com" is in the seed list but # "FOO.com" is in the hello response. return host.lower(), port From 379dfb66a4d59cb343baf3bce413079d902fabf4 Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Tue, 18 Mar 2025 09:28:42 -0700 Subject: [PATCH 48/56] fix typing --- pymongo/asynchronous/client_session.py | 1 + pymongo/asynchronous/mongo_client.py | 22 +++++++++++++++++++--- pymongo/asynchronous/pool.py | 1 + pymongo/synchronous/client_session.py | 1 + pymongo/synchronous/mongo_client.py | 22 +++++++++++++++++++--- pymongo/synchronous/pool.py | 1 + 6 files changed, 42 insertions(+), 6 deletions(-) diff --git a/pymongo/asynchronous/client_session.py b/pymongo/asynchronous/client_session.py index 98dd6a4706..1c3d0b6732 100644 --- a/pymongo/asynchronous/client_session.py +++ b/pymongo/asynchronous/client_session.py @@ -996,6 +996,7 @@ def _txn_read_preference(self) -> Optional[_ServerMode]: def _materialize(self, logical_session_timeout_minutes: Optional[int] = None) -> None: if isinstance(self._server_session, _EmptyServerSession): old = self._server_session + assert self._client._topology is not None self._server_session = self._client._topology.get_server_session( logical_session_timeout_minutes ) diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index c1e5f376e5..1cc7298c3a 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -757,7 +757,7 @@ def __init__( raise TypeError(f"port must be an instance of int, not {type(port)}") self._host = host self._port = port - self._topology = None + self._topology: Optional[Topology] = None # _pool_class, _monitor_class, and _condition_class are for deep # customization of PyMongo, e.g. Motor. @@ -1036,6 +1036,7 @@ def _should_pin_cursor(self, session: Optional[AsyncClientSession]) -> Optional[ def _after_fork(self) -> None: """Resets topology in a child after successfully forking.""" + assert self._topology is not None self._init_background(self._topology._pid) # Reset the session pool to avoid duplicate sessions in the child process. self._topology._session_pool.reset() @@ -1195,6 +1196,7 @@ def topology_description(self) -> TopologyDescription: .. versionadded:: 4.0 """ + assert self._topology is not None return self._topology.description @property @@ -1208,6 +1210,7 @@ def nodes(self) -> FrozenSet[_Address]: to any servers, or a network partition causes it to lose connection to all servers. """ + assert self._topology is not None description = self._topology.description return frozenset(s.address for s in description.known_servers) @@ -1221,7 +1224,7 @@ def options(self) -> ClientOptions: """ return self._options - def eq_props(self): + def eq_props(self) -> tuple[tuple[_Address, ...], Optional[str], Optional[str], str]: return ( tuple(sorted(self._resolve_srv_info["seeds"])), self._options.replica_set_name, @@ -1242,7 +1245,7 @@ def __ne__(self, other: Any) -> bool: def __hash__(self) -> int: if self._topology is None: - raise hash(self.eq_props()) + return hash(self.eq_props()) else: return hash(self._topology) @@ -1387,6 +1390,7 @@ def _ensure_session( def _send_cluster_time( self, command: MutableMapping[str, Any], session: Optional[AsyncClientSession] ) -> None: + assert self._topology is not None topology_time = self._topology.max_cluster_time() session_time = session.cluster_time if session else None if topology_time and session_time: @@ -1572,6 +1576,7 @@ async def address(self) -> Optional[tuple[str, int]]: .. versionadded:: 3.0 """ + assert self._topology is not None topology_type = self._topology._description.topology_type if ( topology_type == TOPOLOGY_TYPE.Sharded @@ -1594,6 +1599,7 @@ async def primary(self) -> Optional[tuple[str, int]]: .. versionadded:: 3.0 AsyncMongoClient gained this property in version 3.0. """ + assert self._topology is not None return await self._topology.get_primary() # type: ignore[return-value] @property @@ -1607,6 +1613,7 @@ async def secondaries(self) -> set[_Address]: .. versionadded:: 3.0 AsyncMongoClient gained this property in version 3.0. """ + assert self._topology is not None return await self._topology.get_secondaries() @property @@ -1617,6 +1624,7 @@ async def arbiters(self) -> set[_Address]: connected to a replica set, there are no arbiters, or this client was created without the `replicaSet` option. """ + assert self._topology is not None return await self._topology.get_arbiters() @property @@ -1709,10 +1717,12 @@ async def _get_topology(self) -> Topology: if self._resolve_srv_info["is_srv"]: await self._resolve_srv() self._init_background() + assert self._topology is not None await self._topology.open() async with self._lock: self._kill_cursors_executor.open() self._opened = True + assert self._topology is not None return self._topology @contextlib.asynccontextmanager @@ -1815,6 +1825,7 @@ async def _conn_from_server( # Thread safe: if the type is single it cannot change. # NOTE: We already opened the Topology when selecting a server so there's no need # to call _get_topology() again. + assert self._topology is not None single = self._topology.description.topology_type == TOPOLOGY_TYPE.Single async with self._checkout(server, session) as conn: if single: @@ -2154,6 +2165,7 @@ async def _process_kill_cursors(self) -> None: """Process any pending kill cursors requests.""" address_to_cursor_ids = defaultdict(list) pinned_cursors = [] + assert self._topology is not None # Other threads or the GC may append to the queue concurrently. while True: @@ -2195,6 +2207,7 @@ async def _process_periodic_tasks(self) -> None: """Process any pending kill cursors requests and maintain connection pool parameters. """ + assert self._topology is not None try: await self._process_kill_cursors() await self._topology.update_pool() @@ -2210,6 +2223,7 @@ def _return_server_session( """Internal: return a _ServerSession to the pool.""" if isinstance(server_session, _EmptyServerSession): return None + assert self._topology is not None return self._topology.return_server_session(server_session) @contextlib.asynccontextmanager @@ -2247,6 +2261,7 @@ async def _tmp_session( async def _process_response( self, reply: Mapping[str, Any], session: Optional[AsyncClientSession] ) -> None: + assert self._topology is not None await self._topology.receive_cluster_time(reply.get("$clusterTime")) if session is not None: session._process_response(reply) @@ -2638,6 +2653,7 @@ async def handle( self.completed_handshake, self.service_id, ) + assert self._client._topology is not None await self.client._topology.handle_error(self.server_address, err_ctx) async def __aenter__(self) -> _MongoClientErrorHandler: diff --git a/pymongo/asynchronous/pool.py b/pymongo/asynchronous/pool.py index d06c528e78..215df35fb2 100644 --- a/pymongo/asynchronous/pool.py +++ b/pymongo/asynchronous/pool.py @@ -1315,6 +1315,7 @@ async def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> A raise if handler: + assert handler.client._topology is not None await handler.client._topology.receive_cluster_time(conn._cluster_time) return conn diff --git a/pymongo/synchronous/client_session.py b/pymongo/synchronous/client_session.py index 60c15a9ec0..b89a990637 100644 --- a/pymongo/synchronous/client_session.py +++ b/pymongo/synchronous/client_session.py @@ -991,6 +991,7 @@ def _txn_read_preference(self) -> Optional[_ServerMode]: def _materialize(self, logical_session_timeout_minutes: Optional[int] = None) -> None: if isinstance(self._server_session, _EmptyServerSession): old = self._server_session + assert self._client._topology is not None self._server_session = self._client._topology.get_server_session( logical_session_timeout_minutes ) diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 534c06f705..2c283a046a 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -755,7 +755,7 @@ def __init__( raise TypeError(f"port must be an instance of int, not {type(port)}") self._host = host self._port = port - self._topology = None + self._topology: Optional[Topology] = None # _pool_class, _monitor_class, and _condition_class are for deep # customization of PyMongo, e.g. Motor. @@ -1034,6 +1034,7 @@ def _should_pin_cursor(self, session: Optional[ClientSession]) -> Optional[bool] def _after_fork(self) -> None: """Resets topology in a child after successfully forking.""" + assert self._topology is not None self._init_background(self._topology._pid) # Reset the session pool to avoid duplicate sessions in the child process. self._topology._session_pool.reset() @@ -1193,6 +1194,7 @@ def topology_description(self) -> TopologyDescription: .. versionadded:: 4.0 """ + assert self._topology is not None return self._topology.description @property @@ -1206,6 +1208,7 @@ def nodes(self) -> FrozenSet[_Address]: to any servers, or a network partition causes it to lose connection to all servers. """ + assert self._topology is not None description = self._topology.description return frozenset(s.address for s in description.known_servers) @@ -1219,7 +1222,7 @@ def options(self) -> ClientOptions: """ return self._options - def eq_props(self): + def eq_props(self) -> tuple[tuple[_Address, ...], Optional[str], Optional[str], str]: return ( tuple(sorted(self._resolve_srv_info["seeds"])), self._options.replica_set_name, @@ -1240,7 +1243,7 @@ def __ne__(self, other: Any) -> bool: def __hash__(self) -> int: if self._topology is None: - raise hash(self.eq_props()) + return hash(self.eq_props()) else: return hash(self._topology) @@ -1383,6 +1386,7 @@ def _ensure_session(self, session: Optional[ClientSession] = None) -> Optional[C def _send_cluster_time( self, command: MutableMapping[str, Any], session: Optional[ClientSession] ) -> None: + assert self._topology is not None topology_time = self._topology.max_cluster_time() session_time = session.cluster_time if session else None if topology_time and session_time: @@ -1566,6 +1570,7 @@ def address(self) -> Optional[tuple[str, int]]: .. versionadded:: 3.0 """ + assert self._topology is not None topology_type = self._topology._description.topology_type if ( topology_type == TOPOLOGY_TYPE.Sharded @@ -1588,6 +1593,7 @@ def primary(self) -> Optional[tuple[str, int]]: .. versionadded:: 3.0 MongoClient gained this property in version 3.0. """ + assert self._topology is not None return self._topology.get_primary() # type: ignore[return-value] @property @@ -1601,6 +1607,7 @@ def secondaries(self) -> set[_Address]: .. versionadded:: 3.0 MongoClient gained this property in version 3.0. """ + assert self._topology is not None return self._topology.get_secondaries() @property @@ -1611,6 +1618,7 @@ def arbiters(self) -> set[_Address]: connected to a replica set, there are no arbiters, or this client was created without the `replicaSet` option. """ + assert self._topology is not None return self._topology.get_arbiters() @property @@ -1703,10 +1711,12 @@ def _get_topology(self) -> Topology: if self._resolve_srv_info["is_srv"]: self._resolve_srv() self._init_background() + assert self._topology is not None self._topology.open() with self._lock: self._kill_cursors_executor.open() self._opened = True + assert self._topology is not None return self._topology @contextlib.contextmanager @@ -1809,6 +1819,7 @@ def _conn_from_server( # Thread safe: if the type is single it cannot change. # NOTE: We already opened the Topology when selecting a server so there's no need # to call _get_topology() again. + assert self._topology is not None single = self._topology.description.topology_type == TOPOLOGY_TYPE.Single with self._checkout(server, session) as conn: if single: @@ -2148,6 +2159,7 @@ def _process_kill_cursors(self) -> None: """Process any pending kill cursors requests.""" address_to_cursor_ids = defaultdict(list) pinned_cursors = [] + assert self._topology is not None # Other threads or the GC may append to the queue concurrently. while True: @@ -2189,6 +2201,7 @@ def _process_periodic_tasks(self) -> None: """Process any pending kill cursors requests and maintain connection pool parameters. """ + assert self._topology is not None try: self._process_kill_cursors() self._topology.update_pool() @@ -2204,6 +2217,7 @@ def _return_server_session( """Internal: return a _ServerSession to the pool.""" if isinstance(server_session, _EmptyServerSession): return None + assert self._topology is not None return self._topology.return_server_session(server_session) @contextlib.contextmanager @@ -2239,6 +2253,7 @@ def _tmp_session( yield None def _process_response(self, reply: Mapping[str, Any], session: Optional[ClientSession]) -> None: + assert self._topology is not None self._topology.receive_cluster_time(reply.get("$clusterTime")) if session is not None: session._process_response(reply) @@ -2624,6 +2639,7 @@ def handle( self.completed_handshake, self.service_id, ) + assert self._client._topology is not None self.client._topology.handle_error(self.server_address, err_ctx) def __enter__(self) -> _MongoClientErrorHandler: diff --git a/pymongo/synchronous/pool.py b/pymongo/synchronous/pool.py index cd78e26fea..bbd1bf73b1 100644 --- a/pymongo/synchronous/pool.py +++ b/pymongo/synchronous/pool.py @@ -1309,6 +1309,7 @@ def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> Connect raise if handler: + assert handler.client._topology is not None handler.client._topology.receive_cluster_time(conn._cluster_time) return conn From 54664840c252ce66a52893a07c6a9938fc5ebc7b Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Tue, 18 Mar 2025 09:31:12 -0700 Subject: [PATCH 49/56] oops fix typing pt2 --- pymongo/asynchronous/mongo_client.py | 2 +- pymongo/synchronous/mongo_client.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index 1cc7298c3a..f90220d573 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -2653,7 +2653,7 @@ async def handle( self.completed_handshake, self.service_id, ) - assert self._client._topology is not None + assert self.client._topology is not None await self.client._topology.handle_error(self.server_address, err_ctx) async def __aenter__(self) -> _MongoClientErrorHandler: diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 2c283a046a..5add4a8c5a 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -2639,7 +2639,7 @@ def handle( self.completed_handshake, self.service_id, ) - assert self._client._topology is not None + assert self.client._topology is not None self.client._topology.handle_error(self.server_address, err_ctx) def __enter__(self) -> _MongoClientErrorHandler: From 2900718828941b5182efccff33e5005c602b0fe6 Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Fri, 21 Mar 2025 13:08:22 -0700 Subject: [PATCH 50/56] address comments --- pymongo/asynchronous/mongo_client.py | 12 +++--------- pymongo/asynchronous/uri_parser.py | 16 ++++++++++++++++ pymongo/srv_resolver.py | 22 ---------------------- pymongo/synchronous/mongo_client.py | 12 +++--------- pymongo/synchronous/uri_parser.py | 16 ++++++++++++++++ 5 files changed, 38 insertions(+), 40 deletions(-) delete mode 100644 pymongo/srv_resolver.py diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index f90220d573..9e294d4578 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -1234,20 +1234,14 @@ def eq_props(self) -> tuple[tuple[_Address, ...], Optional[str], Optional[str], def __eq__(self, other: Any) -> bool: if isinstance(other, self.__class__): - if self._topology is None: - return self.eq_props() == other.eq_props() - else: - return self._topology == other._topology + return self.eq_props() == other.eq_props() return NotImplemented def __ne__(self, other: Any) -> bool: return not self == other def __hash__(self) -> int: - if self._topology is None: - return hash(self.eq_props()) - else: - return hash(self._topology) + return hash(self.eq_props()) def _repr_helper(self) -> str: def option_repr(option: str, value: Any) -> str: @@ -1264,7 +1258,7 @@ def option_repr(option: str, value: Any) -> str: # Host first... if self._topology is None: - options = ["host={self._host}", "port={self._port}"] + options = self._resolve_srv_info["seeds"] else: options = [ "host=%r" diff --git a/pymongo/asynchronous/uri_parser.py b/pymongo/asynchronous/uri_parser.py index 0cc69247eb..47c6d72031 100644 --- a/pymongo/asynchronous/uri_parser.py +++ b/pymongo/asynchronous/uri_parser.py @@ -1,3 +1,19 @@ +# Copyright 2011-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you +# may not use this file except in compliance with the License. You +# may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. + + +"""Tools to parse and validate a MongoDB URI.""" from __future__ import annotations from typing import Any, Optional diff --git a/pymongo/srv_resolver.py b/pymongo/srv_resolver.py deleted file mode 100644 index 1414a86b05..0000000000 --- a/pymongo/srv_resolver.py +++ /dev/null @@ -1,22 +0,0 @@ -# Copyright 2019-present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Re-import of synchronous Uri Parser API for compatibility.""" -from __future__ import annotations - -from pymongo.synchronous.srv_resolver import __doc__ as original_doc -from pymongo.synchronous.srv_resolver import _SrvResolver, maybe_decode - -__doc__ = original_doc -__all__ = ["maybe_decode", "_SrvResolver"] diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 5add4a8c5a..617d297fac 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -1232,20 +1232,14 @@ def eq_props(self) -> tuple[tuple[_Address, ...], Optional[str], Optional[str], def __eq__(self, other: Any) -> bool: if isinstance(other, self.__class__): - if self._topology is None: - return self.eq_props() == other.eq_props() - else: - return self._topology == other._topology + return self.eq_props() == other.eq_props() return NotImplemented def __ne__(self, other: Any) -> bool: return not self == other def __hash__(self) -> int: - if self._topology is None: - return hash(self.eq_props()) - else: - return hash(self._topology) + return hash(self.eq_props()) def _repr_helper(self) -> str: def option_repr(option: str, value: Any) -> str: @@ -1262,7 +1256,7 @@ def option_repr(option: str, value: Any) -> str: # Host first... if self._topology is None: - options = ["host={self._host}", "port={self._port}"] + options = self._resolve_srv_info["seeds"] else: options = [ "host=%r" diff --git a/pymongo/synchronous/uri_parser.py b/pymongo/synchronous/uri_parser.py index 2810f0f123..52b59b8fe8 100644 --- a/pymongo/synchronous/uri_parser.py +++ b/pymongo/synchronous/uri_parser.py @@ -1,3 +1,19 @@ +# Copyright 2011-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you +# may not use this file except in compliance with the License. You +# may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. + + +"""Tools to parse and validate a MongoDB URI.""" from __future__ import annotations from typing import Any, Optional From 63676b6a188eb278abe8f872050b151a00e3dd66 Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Fri, 21 Mar 2025 13:39:14 -0700 Subject: [PATCH 51/56] fix patch string --- test/test_client.py | 4 ++-- tools/synchro.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/test/test_client.py b/test/test_client.py index 8251d7fff8..ee199adf0b 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -600,7 +600,7 @@ def test_validate_suggestion(self): with self.assertRaisesRegex(ConfigurationError, expected): MongoClient(**{typo: "standard"}) # type: ignore[arg-type] - @patch("pymongo.srv_resolver._SrvResolver.get_hosts") + @patch("pymongo.synchronous.srv_resolver._SrvResolver.get_hosts") def test_detected_environment_logging(self, mock_get_hosts): normal_hosts = [ "normal.host.com", @@ -622,7 +622,7 @@ def test_detected_environment_logging(self, mock_get_hosts): logs = [record.getMessage() for record in cm.records if record.name == "pymongo.client"] self.assertEqual(len(logs), 7) - @patch("pymongo.srv_resolver._SrvResolver.get_hosts") + @patch("pymongo.synchronous.srv_resolver._SrvResolver.get_hosts") def test_detected_environment_warning(self, mock_get_hosts): with self._caplog.at_level(logging.WARN): normal_hosts = [ diff --git a/tools/synchro.py b/tools/synchro.py index e65270733e..d8760b83bc 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -127,6 +127,7 @@ "async_create_barrier": "create_barrier", "async_barrier_wait": "barrier_wait", "async_joinall": "joinall", + "pymongo.asynchronous.srv_resolver._SrvResolver.get_hosts": "pymongo.synchronous.srv_resolver._SrvResolver.get_hosts", } docstring_replacements: dict[tuple[str, str], str] = { From cd9bd927b6265af60e353ec6a9e18ad44c101931 Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Fri, 21 Mar 2025 15:14:55 -0700 Subject: [PATCH 52/56] address comments pt1 --- pymongo/asynchronous/client_session.py | 1 - pymongo/asynchronous/mongo_client.py | 56 ++++++++++---------------- pymongo/asynchronous/pool.py | 1 - pymongo/synchronous/client_session.py | 1 - pymongo/synchronous/mongo_client.py | 56 ++++++++++---------------- pymongo/synchronous/pool.py | 1 - 6 files changed, 42 insertions(+), 74 deletions(-) diff --git a/pymongo/asynchronous/client_session.py b/pymongo/asynchronous/client_session.py index 1c3d0b6732..98dd6a4706 100644 --- a/pymongo/asynchronous/client_session.py +++ b/pymongo/asynchronous/client_session.py @@ -996,7 +996,6 @@ def _txn_read_preference(self) -> Optional[_ServerMode]: def _materialize(self, logical_session_timeout_minutes: Optional[int] = None) -> None: if isinstance(self._server_session, _EmptyServerSession): old = self._server_session - assert self._client._topology is not None self._server_session = self._client._topology.get_server_session( logical_session_timeout_minutes ) diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index 9e294d4578..4cefc93812 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -757,7 +757,7 @@ def __init__( raise TypeError(f"port must be an instance of int, not {type(port)}") self._host = host self._port = port - self._topology: Optional[Topology] = None + self._topology: Topology = None # type: ignore[assignment] # _pool_class, _monitor_class, and _condition_class are for deep # customization of PyMongo, e.g. Motor. @@ -1036,7 +1036,6 @@ def _should_pin_cursor(self, session: Optional[AsyncClientSession]) -> Optional[ def _after_fork(self) -> None: """Resets topology in a child after successfully forking.""" - assert self._topology is not None self._init_background(self._topology._pid) # Reset the session pool to avoid duplicate sessions in the child process. self._topology._session_pool.reset() @@ -1196,7 +1195,6 @@ def topology_description(self) -> TopologyDescription: .. versionadded:: 4.0 """ - assert self._topology is not None return self._topology.description @property @@ -1210,7 +1208,6 @@ def nodes(self) -> FrozenSet[_Address]: to any servers, or a network partition causes it to lose connection to all servers. """ - assert self._topology is not None description = self._topology.description return frozenset(s.address for s in description.known_servers) @@ -1384,7 +1381,6 @@ def _ensure_session( def _send_cluster_time( self, command: MutableMapping[str, Any], session: Optional[AsyncClientSession] ) -> None: - assert self._topology is not None topology_time = self._topology.max_cluster_time() session_time = session.cluster_time if session else None if topology_time and session_time: @@ -1570,7 +1566,6 @@ async def address(self) -> Optional[tuple[str, int]]: .. versionadded:: 3.0 """ - assert self._topology is not None topology_type = self._topology._description.topology_type if ( topology_type == TOPOLOGY_TYPE.Sharded @@ -1593,7 +1588,6 @@ async def primary(self) -> Optional[tuple[str, int]]: .. versionadded:: 3.0 AsyncMongoClient gained this property in version 3.0. """ - assert self._topology is not None return await self._topology.get_primary() # type: ignore[return-value] @property @@ -1607,7 +1601,6 @@ async def secondaries(self) -> set[_Address]: .. versionadded:: 3.0 AsyncMongoClient gained this property in version 3.0. """ - assert self._topology is not None return await self._topology.get_secondaries() @property @@ -1618,7 +1611,6 @@ async def arbiters(self) -> set[_Address]: connected to a replica set, there are no arbiters, or this client was created without the `replicaSet` option. """ - assert self._topology is not None return await self._topology.get_arbiters() @property @@ -1677,25 +1669,26 @@ async def close(self) -> None: .. versionchanged:: 3.6 End all server sessions created by this client. """ - if self._topology is not None: - session_ids = self._topology.pop_all_sessions() - if session_ids: - await self._end_sessions(session_ids) - # Stop the periodic task thread and then send pending killCursor - # requests before closing the topology. - self._kill_cursors_executor.close() - await self._process_kill_cursors() - await self._topology.close() - if self._encrypter: - # TODO: PYTHON-1921 Encrypted MongoClients cannot be re-opened. - await self._encrypter.close() - self._closed = True - if not _IS_SYNC: - await asyncio.gather( - self._topology.cleanup_monitors(), # type: ignore[func-returns-value] - self._kill_cursors_executor.join(), # type: ignore[func-returns-value] - return_exceptions=True, - ) + if self._topology is None: + return + session_ids = self._topology.pop_all_sessions() + if session_ids: + await self._end_sessions(session_ids) + # Stop the periodic task thread and then send pending killCursor + # requests before closing the topology. + self._kill_cursors_executor.close() + await self._process_kill_cursors() + await self._topology.close() + if self._encrypter: + # TODO: PYTHON-1921 Encrypted MongoClients cannot be re-opened. + await self._encrypter.close() + self._closed = True + if not _IS_SYNC: + await asyncio.gather( + self._topology.cleanup_monitors(), # type: ignore[func-returns-value] + self._kill_cursors_executor.join(), # type: ignore[func-returns-value] + return_exceptions=True, + ) if not _IS_SYNC: # Add support for contextlib.aclosing. @@ -1711,12 +1704,10 @@ async def _get_topology(self) -> Topology: if self._resolve_srv_info["is_srv"]: await self._resolve_srv() self._init_background() - assert self._topology is not None await self._topology.open() async with self._lock: self._kill_cursors_executor.open() self._opened = True - assert self._topology is not None return self._topology @contextlib.asynccontextmanager @@ -1819,7 +1810,6 @@ async def _conn_from_server( # Thread safe: if the type is single it cannot change. # NOTE: We already opened the Topology when selecting a server so there's no need # to call _get_topology() again. - assert self._topology is not None single = self._topology.description.topology_type == TOPOLOGY_TYPE.Single async with self._checkout(server, session) as conn: if single: @@ -2159,7 +2149,6 @@ async def _process_kill_cursors(self) -> None: """Process any pending kill cursors requests.""" address_to_cursor_ids = defaultdict(list) pinned_cursors = [] - assert self._topology is not None # Other threads or the GC may append to the queue concurrently. while True: @@ -2201,7 +2190,6 @@ async def _process_periodic_tasks(self) -> None: """Process any pending kill cursors requests and maintain connection pool parameters. """ - assert self._topology is not None try: await self._process_kill_cursors() await self._topology.update_pool() @@ -2217,7 +2205,6 @@ def _return_server_session( """Internal: return a _ServerSession to the pool.""" if isinstance(server_session, _EmptyServerSession): return None - assert self._topology is not None return self._topology.return_server_session(server_session) @contextlib.asynccontextmanager @@ -2255,7 +2242,6 @@ async def _tmp_session( async def _process_response( self, reply: Mapping[str, Any], session: Optional[AsyncClientSession] ) -> None: - assert self._topology is not None await self._topology.receive_cluster_time(reply.get("$clusterTime")) if session is not None: session._process_response(reply) diff --git a/pymongo/asynchronous/pool.py b/pymongo/asynchronous/pool.py index 215df35fb2..d06c528e78 100644 --- a/pymongo/asynchronous/pool.py +++ b/pymongo/asynchronous/pool.py @@ -1315,7 +1315,6 @@ async def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> A raise if handler: - assert handler.client._topology is not None await handler.client._topology.receive_cluster_time(conn._cluster_time) return conn diff --git a/pymongo/synchronous/client_session.py b/pymongo/synchronous/client_session.py index b89a990637..60c15a9ec0 100644 --- a/pymongo/synchronous/client_session.py +++ b/pymongo/synchronous/client_session.py @@ -991,7 +991,6 @@ def _txn_read_preference(self) -> Optional[_ServerMode]: def _materialize(self, logical_session_timeout_minutes: Optional[int] = None) -> None: if isinstance(self._server_session, _EmptyServerSession): old = self._server_session - assert self._client._topology is not None self._server_session = self._client._topology.get_server_session( logical_session_timeout_minutes ) diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 617d297fac..08e8b228cf 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -755,7 +755,7 @@ def __init__( raise TypeError(f"port must be an instance of int, not {type(port)}") self._host = host self._port = port - self._topology: Optional[Topology] = None + self._topology: Topology = None # type: ignore[assignment] # _pool_class, _monitor_class, and _condition_class are for deep # customization of PyMongo, e.g. Motor. @@ -1034,7 +1034,6 @@ def _should_pin_cursor(self, session: Optional[ClientSession]) -> Optional[bool] def _after_fork(self) -> None: """Resets topology in a child after successfully forking.""" - assert self._topology is not None self._init_background(self._topology._pid) # Reset the session pool to avoid duplicate sessions in the child process. self._topology._session_pool.reset() @@ -1194,7 +1193,6 @@ def topology_description(self) -> TopologyDescription: .. versionadded:: 4.0 """ - assert self._topology is not None return self._topology.description @property @@ -1208,7 +1206,6 @@ def nodes(self) -> FrozenSet[_Address]: to any servers, or a network partition causes it to lose connection to all servers. """ - assert self._topology is not None description = self._topology.description return frozenset(s.address for s in description.known_servers) @@ -1380,7 +1377,6 @@ def _ensure_session(self, session: Optional[ClientSession] = None) -> Optional[C def _send_cluster_time( self, command: MutableMapping[str, Any], session: Optional[ClientSession] ) -> None: - assert self._topology is not None topology_time = self._topology.max_cluster_time() session_time = session.cluster_time if session else None if topology_time and session_time: @@ -1564,7 +1560,6 @@ def address(self) -> Optional[tuple[str, int]]: .. versionadded:: 3.0 """ - assert self._topology is not None topology_type = self._topology._description.topology_type if ( topology_type == TOPOLOGY_TYPE.Sharded @@ -1587,7 +1582,6 @@ def primary(self) -> Optional[tuple[str, int]]: .. versionadded:: 3.0 MongoClient gained this property in version 3.0. """ - assert self._topology is not None return self._topology.get_primary() # type: ignore[return-value] @property @@ -1601,7 +1595,6 @@ def secondaries(self) -> set[_Address]: .. versionadded:: 3.0 MongoClient gained this property in version 3.0. """ - assert self._topology is not None return self._topology.get_secondaries() @property @@ -1612,7 +1605,6 @@ def arbiters(self) -> set[_Address]: connected to a replica set, there are no arbiters, or this client was created without the `replicaSet` option. """ - assert self._topology is not None return self._topology.get_arbiters() @property @@ -1671,25 +1663,26 @@ def close(self) -> None: .. versionchanged:: 3.6 End all server sessions created by this client. """ - if self._topology is not None: - session_ids = self._topology.pop_all_sessions() - if session_ids: - self._end_sessions(session_ids) - # Stop the periodic task thread and then send pending killCursor - # requests before closing the topology. - self._kill_cursors_executor.close() - self._process_kill_cursors() - self._topology.close() - if self._encrypter: - # TODO: PYTHON-1921 Encrypted MongoClients cannot be re-opened. - self._encrypter.close() - self._closed = True - if not _IS_SYNC: - asyncio.gather( - self._topology.cleanup_monitors(), # type: ignore[func-returns-value] - self._kill_cursors_executor.join(), # type: ignore[func-returns-value] - return_exceptions=True, - ) + if self._topology is None: + return + session_ids = self._topology.pop_all_sessions() + if session_ids: + self._end_sessions(session_ids) + # Stop the periodic task thread and then send pending killCursor + # requests before closing the topology. + self._kill_cursors_executor.close() + self._process_kill_cursors() + self._topology.close() + if self._encrypter: + # TODO: PYTHON-1921 Encrypted MongoClients cannot be re-opened. + self._encrypter.close() + self._closed = True + if not _IS_SYNC: + asyncio.gather( + self._topology.cleanup_monitors(), # type: ignore[func-returns-value] + self._kill_cursors_executor.join(), # type: ignore[func-returns-value] + return_exceptions=True, + ) if not _IS_SYNC: # Add support for contextlib.closing. @@ -1705,12 +1698,10 @@ def _get_topology(self) -> Topology: if self._resolve_srv_info["is_srv"]: self._resolve_srv() self._init_background() - assert self._topology is not None self._topology.open() with self._lock: self._kill_cursors_executor.open() self._opened = True - assert self._topology is not None return self._topology @contextlib.contextmanager @@ -1813,7 +1804,6 @@ def _conn_from_server( # Thread safe: if the type is single it cannot change. # NOTE: We already opened the Topology when selecting a server so there's no need # to call _get_topology() again. - assert self._topology is not None single = self._topology.description.topology_type == TOPOLOGY_TYPE.Single with self._checkout(server, session) as conn: if single: @@ -2153,7 +2143,6 @@ def _process_kill_cursors(self) -> None: """Process any pending kill cursors requests.""" address_to_cursor_ids = defaultdict(list) pinned_cursors = [] - assert self._topology is not None # Other threads or the GC may append to the queue concurrently. while True: @@ -2195,7 +2184,6 @@ def _process_periodic_tasks(self) -> None: """Process any pending kill cursors requests and maintain connection pool parameters. """ - assert self._topology is not None try: self._process_kill_cursors() self._topology.update_pool() @@ -2211,7 +2199,6 @@ def _return_server_session( """Internal: return a _ServerSession to the pool.""" if isinstance(server_session, _EmptyServerSession): return None - assert self._topology is not None return self._topology.return_server_session(server_session) @contextlib.contextmanager @@ -2247,7 +2234,6 @@ def _tmp_session( yield None def _process_response(self, reply: Mapping[str, Any], session: Optional[ClientSession]) -> None: - assert self._topology is not None self._topology.receive_cluster_time(reply.get("$clusterTime")) if session is not None: session._process_response(reply) diff --git a/pymongo/synchronous/pool.py b/pymongo/synchronous/pool.py index bbd1bf73b1..cd78e26fea 100644 --- a/pymongo/synchronous/pool.py +++ b/pymongo/synchronous/pool.py @@ -1309,7 +1309,6 @@ def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> Connect raise if handler: - assert handler.client._topology is not None handler.client._topology.receive_cluster_time(conn._cluster_time) return conn From a7c090dbc131d664c674383b25a441869b60c237 Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Fri, 21 Mar 2025 15:55:16 -0700 Subject: [PATCH 53/56] add test for repr and change changelog --- doc/changelog.rst | 2 ++ pymongo/asynchronous/mongo_client.py | 2 +- pymongo/synchronous/mongo_client.py | 2 +- test/asynchronous/test_client.py | 9 +++++++++ test/test_client.py | 9 +++++++++ 5 files changed, 22 insertions(+), 2 deletions(-) diff --git a/doc/changelog.rst b/doc/changelog.rst index b172da6b8e..df623c7e6e 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -9,6 +9,8 @@ PyMongo 4.12 brings a number of changes including: - Support for configuring DEK cache lifetime via the ``key_expiration_ms`` argument to :class:`~pymongo.encryption_options.AutoEncryptionOpts`. - Support for $lookup in CSFLE and QE supported on MongoDB 8.1+. +- AsyncMongoClient no longer performs dns resolution for "mongodb+srv://" connection string in the constructor. + To avoid blocking the asyncio loop, the resolution is now deferred until the client is first connected. Issues Resolved ............... diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index 4cefc93812..8686309746 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -1255,7 +1255,7 @@ def option_repr(option: str, value: Any) -> str: # Host first... if self._topology is None: - options = self._resolve_srv_info["seeds"] + options = [f"host='mongodb+srv://{self._resolve_srv_info['fqdn']}'"] else: options = [ "host=%r" diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 08e8b228cf..eca8f2cdb8 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -1253,7 +1253,7 @@ def option_repr(option: str, value: Any) -> str: # Host first... if self._topology is None: - options = self._resolve_srv_info["seeds"] + options = [f"host='mongodb+srv://{self._resolve_srv_info['fqdn']}'"] else: options = [ "host=%r" diff --git a/test/asynchronous/test_client.py b/test/asynchronous/test_client.py index 37b40f8d22..208b5b1b04 100644 --- a/test/asynchronous/test_client.py +++ b/test/asynchronous/test_client.py @@ -933,6 +933,15 @@ async def test_repr(self): async with eval(the_repr) as client_two: self.assertEqual(client_two, client) + async def test_repr_srv_host(self): + client = AsyncMongoClient("mongodb+srv://test1.test.build.10gen.cc/") + # before srv resolution + self.assertIn("host='mongodb+srv://test1.test.build.10gen.cc'", repr(client)) + await client.aconnect() + # after srv resolution + self.assertIn("host=['localhost.test.build.10gen.cc:", repr(client)) + await client.close() + async def test_getters(self): await async_wait_until( lambda: async_client_context.nodes == self.client.nodes, "find all nodes" diff --git a/test/test_client.py b/test/test_client.py index ee199adf0b..9ef5fbb348 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -908,6 +908,15 @@ def test_repr(self): with eval(the_repr) as client_two: self.assertEqual(client_two, client) + def test_repr_srv_host(self): + client = MongoClient("mongodb+srv://test1.test.build.10gen.cc/") + # before srv resolution + self.assertIn("host='mongodb+srv://test1.test.build.10gen.cc'", repr(client)) + client._connect() + # after srv resolution + self.assertIn("host=['localhost.test.build.10gen.cc:", repr(client)) + client.close() + def test_getters(self): wait_until(lambda: client_context.nodes == self.client.nodes, "find all nodes") From 93bc3c9bb34322fc9c1c79e9f0c2b05d052ebb44 Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Fri, 21 Mar 2025 16:21:57 -0700 Subject: [PATCH 54/56] fix test --- test/asynchronous/test_client.py | 2 +- test/test_client.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/asynchronous/test_client.py b/test/asynchronous/test_client.py index 208b5b1b04..7f70b84825 100644 --- a/test/asynchronous/test_client.py +++ b/test/asynchronous/test_client.py @@ -934,7 +934,7 @@ async def test_repr(self): self.assertEqual(client_two, client) async def test_repr_srv_host(self): - client = AsyncMongoClient("mongodb+srv://test1.test.build.10gen.cc/") + client = AsyncMongoClient("mongodb+srv://test1.test.build.10gen.cc/", connect=False) # before srv resolution self.assertIn("host='mongodb+srv://test1.test.build.10gen.cc'", repr(client)) await client.aconnect() diff --git a/test/test_client.py b/test/test_client.py index 9ef5fbb348..cd4ceb3299 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -909,7 +909,7 @@ def test_repr(self): self.assertEqual(client_two, client) def test_repr_srv_host(self): - client = MongoClient("mongodb+srv://test1.test.build.10gen.cc/") + client = MongoClient("mongodb+srv://test1.test.build.10gen.cc/", connect=False) # before srv resolution self.assertIn("host='mongodb+srv://test1.test.build.10gen.cc'", repr(client)) client._connect() From afd82f40cbd0a44192a2e637adf8dd9c0c74d9e3 Mon Sep 17 00:00:00 2001 From: Shane Harvey Date: Mon, 24 Mar 2025 10:49:26 -0700 Subject: [PATCH 55/56] Update doc/changelog.rst Co-authored-by: Noah Stapp --- doc/changelog.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/changelog.rst b/doc/changelog.rst index df623c7e6e..8256a5736b 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -9,7 +9,7 @@ PyMongo 4.12 brings a number of changes including: - Support for configuring DEK cache lifetime via the ``key_expiration_ms`` argument to :class:`~pymongo.encryption_options.AutoEncryptionOpts`. - Support for $lookup in CSFLE and QE supported on MongoDB 8.1+. -- AsyncMongoClient no longer performs dns resolution for "mongodb+srv://" connection string in the constructor. +- AsyncMongoClient no longer performs DNS resolution for "mongodb+srv://" connection strings on creation. To avoid blocking the asyncio loop, the resolution is now deferred until the client is first connected. Issues Resolved From 99a5c8aca5919149985821e94127fbdd19af9548 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Mon, 24 Mar 2025 15:52:32 -0400 Subject: [PATCH 56/56] Address review --- pymongo/asynchronous/mongo_client.py | 4 ++-- pymongo/asynchronous/srv_resolver.py | 4 +++- pymongo/synchronous/mongo_client.py | 4 ++-- pymongo/synchronous/srv_resolver.py | 4 +++- test/test_default_exports.py | 4 ++-- 5 files changed, 12 insertions(+), 8 deletions(-) diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index 8686309746..754b8325ed 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -59,7 +59,6 @@ cast, ) -import pymongo.uri_parser_shared from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions, TypeRegistry from bson.timestamp import Timestamp from pymongo import _csot, common, helpers_shared, periodic_executor @@ -121,6 +120,7 @@ _handle_option_deprecations, _handle_security_options, _normalize_options, + _validate_uri, split_hosts, ) from pymongo.write_concern import DEFAULT_WRITE_CONCERN, WriteConcern @@ -786,7 +786,7 @@ def __init__( # it must be a URI, # https://en.wikipedia.org/wiki/Hostname#Restrictions_on_valid_host_names if "/" in entity: - res = pymongo.uri_parser_shared._validate_uri( + res = _validate_uri( entity, port, validate=True, diff --git a/pymongo/asynchronous/srv_resolver.py b/pymongo/asynchronous/srv_resolver.py index b5adf6c920..8b811e5dc2 100644 --- a/pymongo/asynchronous/srv_resolver.py +++ b/pymongo/asynchronous/srv_resolver.py @@ -61,7 +61,9 @@ async def _resolve(*args: Any, **kwargs: Any) -> resolver.Answer: if hasattr(asyncresolver, "resolve"): # dnspython >= 2 return await asyncresolver.resolve(*args, **kwargs) # type:ignore[return-value] - raise ConfigurationError("Upgrade to dnspython version >= 2.0") + raise ConfigurationError( + "Upgrade to dnspython version >= 2.0 to use AsyncMongoClient with mongodb+srv:// connections." + ) _INVALID_HOST_MSG = ( diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index eca8f2cdb8..1cedbfe1e2 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -58,7 +58,6 @@ cast, ) -import pymongo.uri_parser_shared from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions, TypeRegistry from bson.timestamp import Timestamp from pymongo import _csot, common, helpers_shared, periodic_executor @@ -120,6 +119,7 @@ _handle_option_deprecations, _handle_security_options, _normalize_options, + _validate_uri, split_hosts, ) from pymongo.write_concern import DEFAULT_WRITE_CONCERN, WriteConcern @@ -784,7 +784,7 @@ def __init__( # it must be a URI, # https://en.wikipedia.org/wiki/Hostname#Restrictions_on_valid_host_names if "/" in entity: - res = pymongo.uri_parser_shared._validate_uri( + res = _validate_uri( entity, port, validate=True, diff --git a/pymongo/synchronous/srv_resolver.py b/pymongo/synchronous/srv_resolver.py index 486c7e6522..1b36efd1c9 100644 --- a/pymongo/synchronous/srv_resolver.py +++ b/pymongo/synchronous/srv_resolver.py @@ -61,7 +61,9 @@ def _resolve(*args: Any, **kwargs: Any) -> resolver.Answer: if hasattr(asyncresolver, "resolve"): # dnspython >= 2 return asyncresolver.resolve(*args, **kwargs) # type:ignore[return-value] - raise ConfigurationError("Upgrade to dnspython version >= 2.0") + raise ConfigurationError( + "Upgrade to dnspython version >= 2.0 to use MongoClient with mongodb+srv:// connections." + ) _INVALID_HOST_MSG = ( diff --git a/test/test_default_exports.py b/test/test_default_exports.py index 9035414c75..d9301d2223 100644 --- a/test/test_default_exports.py +++ b/test/test_default_exports.py @@ -69,7 +69,6 @@ def test_bson(self): def test_pymongo_imports(self): import pymongo - from pymongo.asynchronous.uri_parser import parse_uri from pymongo.auth import MECHANISMS from pymongo.auth_oidc import ( OIDCCallback, @@ -199,9 +198,10 @@ def test_pymongo_imports(self): from pymongo.server_api import ServerApi, ServerApiVersion from pymongo.server_description import ServerDescription from pymongo.topology_description import TopologyDescription - from pymongo.uri_parser_shared import ( + from pymongo.uri_parser import ( parse_host, parse_ipv6_literal_host, + parse_uri, parse_userinfo, split_hosts, split_options,