diff --git a/doc/changelog.rst b/doc/changelog.rst index 12991eeb29..5683fcaaca 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -9,6 +9,8 @@ PyMongo 4.12 brings a number of changes including: - Support for configuring DEK cache lifetime via the ``key_expiration_ms`` argument to :class:`~pymongo.encryption_options.AutoEncryptionOpts`. - Support for $lookup in CSFLE and QE supported on MongoDB 8.1+. +- AsyncMongoClient no longer performs DNS resolution for "mongodb+srv://" connection strings on creation. + To avoid blocking the asyncio loop, the resolution is now deferred until the client is first connected. - Added index hinting support to the :meth:`~pymongo.asynchronous.collection.AsyncCollection.distinct` and :meth:`~pymongo.collection.Collection.distinct` commands. diff --git a/pymongo/asynchronous/encryption.py b/pymongo/asynchronous/encryption.py index 3582bec9ab..68de42db84 100644 --- a/pymongo/asynchronous/encryption.py +++ b/pymongo/asynchronous/encryption.py @@ -87,7 +87,7 @@ from pymongo.results import BulkWriteResult, DeleteResult from pymongo.ssl_support import get_ssl_context from pymongo.typings import _DocumentType, _DocumentTypeArg -from pymongo.uri_parser import parse_host +from pymongo.uri_parser_shared import parse_host from pymongo.write_concern import WriteConcern if TYPE_CHECKING: diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index ecd57a1886..754b8325ed 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -44,6 +44,7 @@ AsyncContextManager, AsyncGenerator, Callable, + Collection, Coroutine, FrozenSet, Generic, @@ -60,8 +61,8 @@ from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions, TypeRegistry from bson.timestamp import Timestamp -from pymongo import _csot, common, helpers_shared, periodic_executor, uri_parser -from pymongo.asynchronous import client_session, database +from pymongo import _csot, common, helpers_shared, periodic_executor +from pymongo.asynchronous import client_session, database, uri_parser from pymongo.asynchronous.change_stream import AsyncChangeStream, AsyncClusterChangeStream from pymongo.asynchronous.client_bulk import _AsyncClientBulk from pymongo.asynchronous.client_session import _EmptyServerSession @@ -113,11 +114,14 @@ _DocumentTypeArg, _Pipeline, ) -from pymongo.uri_parser import ( +from pymongo.uri_parser_shared import ( + SRV_SCHEME, _check_options, _handle_option_deprecations, _handle_security_options, _normalize_options, + _validate_uri, + split_hosts, ) from pymongo.write_concern import DEFAULT_WRITE_CONCERN, WriteConcern @@ -128,6 +132,7 @@ from pymongo.asynchronous.bulk import _AsyncBulk from pymongo.asynchronous.client_session import AsyncClientSession, _ServerSession from pymongo.asynchronous.cursor import _ConnectionManager + from pymongo.asynchronous.encryption import _Encrypter from pymongo.asynchronous.pool import AsyncConnection from pymongo.asynchronous.server import Server from pymongo.read_concern import ReadConcern @@ -750,6 +755,9 @@ def __init__( port = self.PORT if not isinstance(port, int): raise TypeError(f"port must be an instance of int, not {type(port)}") + self._host = host + self._port = port + self._topology: Topology = None # type: ignore[assignment] # _pool_class, _monitor_class, and _condition_class are for deep # customization of PyMongo, e.g. Motor. @@ -760,8 +768,10 @@ def __init__( # Parse options passed as kwargs. keyword_opts = common._CaseInsensitiveDictionary(kwargs) keyword_opts["document_class"] = doc_class + self._resolve_srv_info: dict[str, Any] = {"keyword_opts": keyword_opts} seeds = set() + is_srv = False username = None password = None dbase = None @@ -769,29 +779,22 @@ def __init__( fqdn = None srv_service_name = keyword_opts.get("srvservicename") srv_max_hosts = keyword_opts.get("srvmaxhosts") - if len([h for h in host if "/" in h]) > 1: + if len([h for h in self._host if "/" in h]) > 1: raise ConfigurationError("host must not contain multiple MongoDB URIs") - for entity in host: + for entity in self._host: # A hostname can only include a-z, 0-9, '-' and '.'. If we find a '/' # it must be a URI, # https://en.wikipedia.org/wiki/Hostname#Restrictions_on_valid_host_names if "/" in entity: - # Determine connection timeout from kwargs. - timeout = keyword_opts.get("connecttimeoutms") - if timeout is not None: - timeout = common.validate_timeout_or_none_or_zero( - keyword_opts.cased_key("connecttimeoutms"), timeout - ) - res = uri_parser.parse_uri( + res = _validate_uri( entity, port, validate=True, warn=True, normalize=False, - connect_timeout=timeout, - srv_service_name=srv_service_name, srv_max_hosts=srv_max_hosts, ) + is_srv = entity.startswith(SRV_SCHEME) seeds.update(res["nodelist"]) username = res["username"] or username password = res["password"] or password @@ -799,7 +802,7 @@ def __init__( opts = res["options"] fqdn = res["fqdn"] else: - seeds.update(uri_parser.split_hosts(entity, port)) + seeds.update(split_hosts(entity, self._port)) if not seeds: raise ConfigurationError("need to specify at least one host") @@ -820,80 +823,179 @@ def __init__( keyword_opts["tz_aware"] = tz_aware keyword_opts["connect"] = connect - # Handle deprecated options in kwarg options. - keyword_opts = _handle_option_deprecations(keyword_opts) - # Validate kwarg options. - keyword_opts = common._CaseInsensitiveDictionary( - dict(common.validate(keyword_opts.cased_key(k), v) for k, v in keyword_opts.items()) - ) - - # Override connection string options with kwarg options. - opts.update(keyword_opts) + opts = self._validate_kwargs_and_update_opts(keyword_opts, opts) if srv_service_name is None: srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME) srv_max_hosts = srv_max_hosts or opts.get("srvmaxhosts") - # Handle security-option conflicts in combined options. - opts = _handle_security_options(opts) - # Normalize combined options. - opts = _normalize_options(opts) - _check_options(seeds, opts) + opts = self._normalize_and_validate_options(opts, seeds) # Username and password passed as kwargs override user info in URI. username = opts.get("username", username) password = opts.get("password", password) - self._options = options = ClientOptions(username, password, dbase, opts, _IS_SYNC) + self._options = ClientOptions(username, password, dbase, opts, _IS_SYNC) self._default_database_name = dbase self._lock = _async_create_lock() self._kill_cursors_queue: list = [] - self._event_listeners = options.pool_options._event_listeners - super().__init__( - options.codec_options, - options.read_preference, - options.write_concern, - options.read_concern, + self._encrypter: Optional[_Encrypter] = None + + self._resolve_srv_info.update( + { + "is_srv": is_srv, + "username": username, + "password": password, + "dbase": dbase, + "seeds": seeds, + "fqdn": fqdn, + "srv_service_name": srv_service_name, + "pool_class": pool_class, + "monitor_class": monitor_class, + "condition_class": condition_class, + } ) - self._topology_settings = TopologySettings( - seeds=seeds, - replica_set_name=options.replica_set_name, - pool_class=pool_class, - pool_options=options.pool_options, - monitor_class=monitor_class, - condition_class=condition_class, - local_threshold_ms=options.local_threshold_ms, - server_selection_timeout=options.server_selection_timeout, - server_selector=options.server_selector, - heartbeat_frequency=options.heartbeat_frequency, - fqdn=fqdn, - direct_connection=options.direct_connection, - load_balanced=options.load_balanced, - srv_service_name=srv_service_name, - srv_max_hosts=srv_max_hosts, - server_monitoring_mode=options.server_monitoring_mode, + super().__init__( + self._options.codec_options, + self._options.read_preference, + self._options.write_concern, + self._options.read_concern, ) + if not is_srv: + self._init_based_on_options(seeds, srv_max_hosts, srv_service_name) + self._opened = False self._closed = False - self._init_background() + if not is_srv: + self._init_background() if _IS_SYNC and connect: self._get_topology() # type: ignore[unused-coroutine] - self._encrypter = None + async def _resolve_srv(self) -> None: + keyword_opts = self._resolve_srv_info["keyword_opts"] + seeds = set() + opts = common._CaseInsensitiveDictionary() + srv_service_name = keyword_opts.get("srvservicename") + srv_max_hosts = keyword_opts.get("srvmaxhosts") + for entity in self._host: + # A hostname can only include a-z, 0-9, '-' and '.'. If we find a '/' + # it must be a URI, + # https://en.wikipedia.org/wiki/Hostname#Restrictions_on_valid_host_names + if "/" in entity: + # Determine connection timeout from kwargs. + timeout = keyword_opts.get("connecttimeoutms") + if timeout is not None: + timeout = common.validate_timeout_or_none_or_zero( + keyword_opts.cased_key("connecttimeoutms"), timeout + ) + res = await uri_parser._parse_srv( + entity, + self._port, + validate=True, + warn=True, + normalize=False, + connect_timeout=timeout, + srv_service_name=srv_service_name, + srv_max_hosts=srv_max_hosts, + ) + seeds.update(res["nodelist"]) + opts = res["options"] + else: + seeds.update(split_hosts(entity, self._port)) + + if not seeds: + raise ConfigurationError("need to specify at least one host") + + for hostname in [node[0] for node in seeds]: + if _detect_external_db(hostname): + break + + # Add options with named keyword arguments to the parsed kwarg options. + tz_aware = keyword_opts["tz_aware"] + connect = keyword_opts["connect"] + if tz_aware is None: + tz_aware = opts.get("tz_aware", False) + if connect is None: + # Default to connect=True unless on a FaaS system, which might use fork. + from pymongo.pool_options import _is_faas + + connect = opts.get("connect", not _is_faas()) + keyword_opts["tz_aware"] = tz_aware + keyword_opts["connect"] = connect + + opts = self._validate_kwargs_and_update_opts(keyword_opts, opts) + + if srv_service_name is None: + srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME) + + srv_max_hosts = srv_max_hosts or opts.get("srvmaxhosts") + opts = self._normalize_and_validate_options(opts, seeds) + + # Username and password passed as kwargs override user info in URI. + username = opts.get("username", self._resolve_srv_info["username"]) + password = opts.get("password", self._resolve_srv_info["password"]) + self._options = ClientOptions( + username, password, self._resolve_srv_info["dbase"], opts, _IS_SYNC + ) + + self._init_based_on_options(seeds, srv_max_hosts, srv_service_name) + + def _init_based_on_options( + self, seeds: Collection[tuple[str, int]], srv_max_hosts: Any, srv_service_name: Any + ) -> None: + self._event_listeners = self._options.pool_options._event_listeners + self._topology_settings = TopologySettings( + seeds=seeds, + replica_set_name=self._options.replica_set_name, + pool_class=self._resolve_srv_info["pool_class"], + pool_options=self._options.pool_options, + monitor_class=self._resolve_srv_info["monitor_class"], + condition_class=self._resolve_srv_info["condition_class"], + local_threshold_ms=self._options.local_threshold_ms, + server_selection_timeout=self._options.server_selection_timeout, + server_selector=self._options.server_selector, + heartbeat_frequency=self._options.heartbeat_frequency, + fqdn=self._resolve_srv_info["fqdn"], + direct_connection=self._options.direct_connection, + load_balanced=self._options.load_balanced, + srv_service_name=srv_service_name, + srv_max_hosts=srv_max_hosts, + server_monitoring_mode=self._options.server_monitoring_mode, + ) if self._options.auto_encryption_opts: from pymongo.asynchronous.encryption import _Encrypter self._encrypter = _Encrypter(self, self._options.auto_encryption_opts) self._timeout = self._options.timeout - if _HAS_REGISTER_AT_FORK: - # Add this client to the list of weakly referenced items. - # This will be used later if we fork. - AsyncMongoClient._clients[self._topology._topology_id] = self + def _normalize_and_validate_options( + self, opts: common._CaseInsensitiveDictionary, seeds: set[tuple[str, int | None]] + ) -> common._CaseInsensitiveDictionary: + # Handle security-option conflicts in combined options. + opts = _handle_security_options(opts) + # Normalize combined options. + opts = _normalize_options(opts) + _check_options(seeds, opts) + return opts + + def _validate_kwargs_and_update_opts( + self, + keyword_opts: common._CaseInsensitiveDictionary, + opts: common._CaseInsensitiveDictionary, + ) -> common._CaseInsensitiveDictionary: + # Handle deprecated options in kwarg options. + keyword_opts = _handle_option_deprecations(keyword_opts) + # Validate kwarg options. + keyword_opts = common._CaseInsensitiveDictionary( + dict(common.validate(keyword_opts.cased_key(k), v) for k, v in keyword_opts.items()) + ) + # Override connection string options with kwarg options. + opts.update(keyword_opts) + return opts async def aconnect(self) -> None: """Explicitly connect to MongoDB asynchronously instead of on the first operation.""" @@ -901,6 +1003,10 @@ async def aconnect(self) -> None: def _init_background(self, old_pid: Optional[int] = None) -> None: self._topology = Topology(self._topology_settings) + if _HAS_REGISTER_AT_FORK: + # Add this client to the list of weakly referenced items. + # This will be used later if we fork. + AsyncMongoClient._clients[self._topology._topology_id] = self # Seed the topology with the old one's pid so we can detect clients # that are opened before a fork and used after. self._topology._pid = old_pid @@ -1115,16 +1221,24 @@ def options(self) -> ClientOptions: """ return self._options + def eq_props(self) -> tuple[tuple[_Address, ...], Optional[str], Optional[str], str]: + return ( + tuple(sorted(self._resolve_srv_info["seeds"])), + self._options.replica_set_name, + self._resolve_srv_info["fqdn"], + self._resolve_srv_info["srv_service_name"], + ) + def __eq__(self, other: Any) -> bool: if isinstance(other, self.__class__): - return self._topology == other._topology + return self.eq_props() == other.eq_props() return NotImplemented def __ne__(self, other: Any) -> bool: return not self == other def __hash__(self) -> int: - return hash(self._topology) + return hash(self.eq_props()) def _repr_helper(self) -> str: def option_repr(option: str, value: Any) -> str: @@ -1140,13 +1254,16 @@ def option_repr(option: str, value: Any) -> str: return f"{option}={value!r}" # Host first... - options = [ - "host=%r" - % [ - "%s:%d" % (host, port) if port is not None else host - for host, port in self._topology_settings.seeds + if self._topology is None: + options = [f"host='mongodb+srv://{self._resolve_srv_info['fqdn']}'"] + else: + options = [ + "host=%r" + % [ + "%s:%d" % (host, port) if port is not None else host + for host, port in self._topology_settings.seeds + ] ] - ] # ... then everything in self._constructor_args... options.extend( option_repr(key, self._options._options[key]) for key in self._constructor_args @@ -1552,6 +1669,8 @@ async def close(self) -> None: .. versionchanged:: 3.6 End all server sessions created by this client. """ + if self._topology is None: + return session_ids = self._topology.pop_all_sessions() if session_ids: await self._end_sessions(session_ids) @@ -1582,6 +1701,9 @@ async def _get_topology(self) -> Topology: launches the connection process in the background. """ if not self._opened: + if self._resolve_srv_info["is_srv"]: + await self._resolve_srv() + self._init_background() await self._topology.open() async with self._lock: self._kill_cursors_executor.open() @@ -2511,6 +2633,7 @@ async def handle( self.completed_handshake, self.service_id, ) + assert self.client._topology is not None await self.client._topology.handle_error(self.server_address, err_ctx) async def __aenter__(self) -> _MongoClientErrorHandler: diff --git a/pymongo/asynchronous/monitor.py b/pymongo/asynchronous/monitor.py index 5cb42f4d46..1b0799e1c4 100644 --- a/pymongo/asynchronous/monitor.py +++ b/pymongo/asynchronous/monitor.py @@ -25,6 +25,7 @@ from pymongo import common, periodic_executor from pymongo._csot import MovingMinimum +from pymongo.asynchronous.srv_resolver import _SrvResolver from pymongo.errors import NetworkTimeout, _OperationCancelled from pymongo.hello import Hello from pymongo.lock import _async_create_lock @@ -33,7 +34,6 @@ from pymongo.pool_options import _is_faas from pymongo.read_preferences import MovingAverage from pymongo.server_description import ServerDescription -from pymongo.srv_resolver import _SrvResolver if TYPE_CHECKING: from pymongo.asynchronous.pool import AsyncConnection, Pool, _CancellationContext @@ -395,7 +395,7 @@ async def _run(self) -> None: # Don't poll right after creation, wait 60 seconds first if time.monotonic() < self._startup_time + common.MIN_SRV_RESCAN_INTERVAL: return - seedlist = self._get_seedlist() + seedlist = await self._get_seedlist() if seedlist: self._seedlist = seedlist try: @@ -404,7 +404,7 @@ async def _run(self) -> None: # Topology was garbage-collected. await self.close() - def _get_seedlist(self) -> Optional[list[tuple[str, Any]]]: + async def _get_seedlist(self) -> Optional[list[tuple[str, Any]]]: """Poll SRV records for a seedlist. Returns a list of ServerDescriptions. @@ -415,7 +415,7 @@ def _get_seedlist(self) -> Optional[list[tuple[str, Any]]]: self._settings.pool_options.connect_timeout, self._settings.srv_service_name, ) - seedlist, ttl = resolver.get_hosts_and_min_ttl() + seedlist, ttl = await resolver.get_hosts_and_min_ttl() if len(seedlist) == 0: # As per the spec: this should be treated as a failure. raise Exception diff --git a/pymongo/asynchronous/srv_resolver.py b/pymongo/asynchronous/srv_resolver.py new file mode 100644 index 0000000000..8b811e5dc2 --- /dev/null +++ b/pymongo/asynchronous/srv_resolver.py @@ -0,0 +1,160 @@ +# Copyright 2019-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you +# may not use this file except in compliance with the License. You +# may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. + +"""Support for resolving hosts and options from mongodb+srv:// URIs.""" +from __future__ import annotations + +import ipaddress +import random +from typing import TYPE_CHECKING, Any, Optional, Union + +from pymongo.common import CONNECT_TIMEOUT +from pymongo.errors import ConfigurationError + +if TYPE_CHECKING: + from dns import resolver + +_IS_SYNC = False + + +def _have_dnspython() -> bool: + try: + import dns # noqa: F401 + + return True + except ImportError: + return False + + +# dnspython can return bytes or str from various parts +# of its API depending on version. We always want str. +def maybe_decode(text: Union[str, bytes]) -> str: + if isinstance(text, bytes): + return text.decode() + return text + + +# PYTHON-2667 Lazily call dns.resolver methods for compatibility with eventlet. +async def _resolve(*args: Any, **kwargs: Any) -> resolver.Answer: + if _IS_SYNC: + from dns import resolver + + if hasattr(resolver, "resolve"): + # dnspython >= 2 + return resolver.resolve(*args, **kwargs) + # dnspython 1.X + return resolver.query(*args, **kwargs) + else: + from dns import asyncresolver + + if hasattr(asyncresolver, "resolve"): + # dnspython >= 2 + return await asyncresolver.resolve(*args, **kwargs) # type:ignore[return-value] + raise ConfigurationError( + "Upgrade to dnspython version >= 2.0 to use AsyncMongoClient with mongodb+srv:// connections." + ) + + +_INVALID_HOST_MSG = ( + "Invalid URI host: %s is not a valid hostname for 'mongodb+srv://'. " + "Did you mean to use 'mongodb://'?" +) + + +class _SrvResolver: + def __init__( + self, + fqdn: str, + connect_timeout: Optional[float], + srv_service_name: str, + srv_max_hosts: int = 0, + ): + self.__fqdn = fqdn + self.__srv = srv_service_name + self.__connect_timeout = connect_timeout or CONNECT_TIMEOUT + self.__srv_max_hosts = srv_max_hosts or 0 + # Validate the fully qualified domain name. + try: + ipaddress.ip_address(fqdn) + raise ConfigurationError(_INVALID_HOST_MSG % ("an IP address",)) + except ValueError: + pass + + try: + self.__plist = self.__fqdn.split(".")[1:] + except Exception: + raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,)) from None + self.__slen = len(self.__plist) + if self.__slen < 2: + raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,)) + + async def get_options(self) -> Optional[str]: + from dns import resolver + + try: + results = await _resolve(self.__fqdn, "TXT", lifetime=self.__connect_timeout) + except (resolver.NoAnswer, resolver.NXDOMAIN): + # No TXT records + return None + except Exception as exc: + raise ConfigurationError(str(exc)) from None + if len(results) > 1: + raise ConfigurationError("Only one TXT record is supported") + return (b"&".join([b"".join(res.strings) for res in results])).decode("utf-8") # type: ignore[attr-defined] + + async def _resolve_uri(self, encapsulate_errors: bool) -> resolver.Answer: + try: + results = await _resolve( + "_" + self.__srv + "._tcp." + self.__fqdn, "SRV", lifetime=self.__connect_timeout + ) + except Exception as exc: + if not encapsulate_errors: + # Raise the original error. + raise + # Else, raise all errors as ConfigurationError. + raise ConfigurationError(str(exc)) from None + return results + + async def _get_srv_response_and_hosts( + self, encapsulate_errors: bool + ) -> tuple[resolver.Answer, list[tuple[str, Any]]]: + results = await self._resolve_uri(encapsulate_errors) + + # Construct address tuples + nodes = [ + (maybe_decode(res.target.to_text(omit_final_dot=True)), res.port) # type: ignore[attr-defined] + for res in results + ] + + # Validate hosts + for node in nodes: + try: + nlist = node[0].lower().split(".")[1:][-self.__slen :] + except Exception: + raise ConfigurationError(f"Invalid SRV host: {node[0]}") from None + if self.__plist != nlist: + raise ConfigurationError(f"Invalid SRV host: {node[0]}") + if self.__srv_max_hosts: + nodes = random.sample(nodes, min(self.__srv_max_hosts, len(nodes))) + return results, nodes + + async def get_hosts(self) -> list[tuple[str, Any]]: + _, nodes = await self._get_srv_response_and_hosts(True) + return nodes + + async def get_hosts_and_min_ttl(self) -> tuple[list[tuple[str, Any]], int]: + results, nodes = await self._get_srv_response_and_hosts(False) + rrset = results.rrset + ttl = rrset.ttl if rrset else 0 + return nodes, ttl diff --git a/pymongo/asynchronous/uri_parser.py b/pymongo/asynchronous/uri_parser.py new file mode 100644 index 0000000000..47c6d72031 --- /dev/null +++ b/pymongo/asynchronous/uri_parser.py @@ -0,0 +1,188 @@ +# Copyright 2011-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you +# may not use this file except in compliance with the License. You +# may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. + + +"""Tools to parse and validate a MongoDB URI.""" +from __future__ import annotations + +from typing import Any, Optional +from urllib.parse import unquote_plus + +from pymongo.asynchronous.srv_resolver import _SrvResolver +from pymongo.common import SRV_SERVICE_NAME, _CaseInsensitiveDictionary +from pymongo.errors import ConfigurationError, InvalidURI +from pymongo.uri_parser_shared import ( + _ALLOWED_TXT_OPTS, + DEFAULT_PORT, + SCHEME, + SCHEME_LEN, + SRV_SCHEME_LEN, + _check_options, + _validate_uri, + split_hosts, + split_options, +) + +_IS_SYNC = False + + +async def parse_uri( + uri: str, + default_port: Optional[int] = DEFAULT_PORT, + validate: bool = True, + warn: bool = False, + normalize: bool = True, + connect_timeout: Optional[float] = None, + srv_service_name: Optional[str] = None, + srv_max_hosts: Optional[int] = None, +) -> dict[str, Any]: + """Parse and validate a MongoDB URI. + + Returns a dict of the form:: + + { + 'nodelist': , + 'username': or None, + 'password': or None, + 'database': or None, + 'collection': or None, + 'options': , + 'fqdn': or None + } + + If the URI scheme is "mongodb+srv://" DNS SRV and TXT lookups will be done + to build nodelist and options. + + :param uri: The MongoDB URI to parse. + :param default_port: The port number to use when one wasn't specified + for a host in the URI. + :param validate: If ``True`` (the default), validate and + normalize all options. Default: ``True``. + :param warn: When validating, if ``True`` then will warn + the user then ignore any invalid options or values. If ``False``, + validation will error when options are unsupported or values are + invalid. Default: ``False``. + :param normalize: If ``True``, convert names of URI options + to their internally-used names. Default: ``True``. + :param connect_timeout: The maximum time in milliseconds to + wait for a response from the DNS server. + :param srv_service_name: A custom SRV service name + + .. versionchanged:: 4.6 + The delimiting slash (``/``) between hosts and connection options is now optional. + For example, "mongodb://example.com?tls=true" is now a valid URI. + + .. versionchanged:: 4.0 + To better follow RFC 3986, unquoted percent signs ("%") are no longer + supported. + + .. versionchanged:: 3.9 + Added the ``normalize`` parameter. + + .. versionchanged:: 3.6 + Added support for mongodb+srv:// URIs. + + .. versionchanged:: 3.5 + Return the original value of the ``readPreference`` MongoDB URI option + instead of the validated read preference mode. + + .. versionchanged:: 3.1 + ``warn`` added so invalid options can be ignored. + """ + result = _validate_uri(uri, default_port, validate, warn, normalize, srv_max_hosts) + result.update( + await _parse_srv( + uri, + default_port, + validate, + warn, + normalize, + connect_timeout, + srv_service_name, + srv_max_hosts, + ) + ) + return result + + +async def _parse_srv( + uri: str, + default_port: Optional[int] = DEFAULT_PORT, + validate: bool = True, + warn: bool = False, + normalize: bool = True, + connect_timeout: Optional[float] = None, + srv_service_name: Optional[str] = None, + srv_max_hosts: Optional[int] = None, +) -> dict[str, Any]: + if uri.startswith(SCHEME): + is_srv = False + scheme_free = uri[SCHEME_LEN:] + else: + is_srv = True + scheme_free = uri[SRV_SCHEME_LEN:] + + options = _CaseInsensitiveDictionary() + + host_plus_db_part, _, opts = scheme_free.partition("?") + if "/" in host_plus_db_part: + host_part, _, _ = host_plus_db_part.partition("/") + else: + host_part = host_plus_db_part + + if opts: + options.update(split_options(opts, validate, warn, normalize)) + if srv_service_name is None: + srv_service_name = options.get("srvServiceName", SRV_SERVICE_NAME) + if "@" in host_part: + _, _, hosts = host_part.rpartition("@") + else: + hosts = host_part + + hosts = unquote_plus(hosts) + srv_max_hosts = srv_max_hosts or options.get("srvMaxHosts") + if is_srv: + nodes = split_hosts(hosts, default_port=None) + fqdn, port = nodes[0] + + # Use the connection timeout. connectTimeoutMS passed as a keyword + # argument overrides the same option passed in the connection string. + connect_timeout = connect_timeout or options.get("connectTimeoutMS") + dns_resolver = _SrvResolver(fqdn, connect_timeout, srv_service_name, srv_max_hosts) + nodes = await dns_resolver.get_hosts() + dns_options = await dns_resolver.get_options() + if dns_options: + parsed_dns_options = split_options(dns_options, validate, warn, normalize) + if set(parsed_dns_options) - _ALLOWED_TXT_OPTS: + raise ConfigurationError( + "Only authSource, replicaSet, and loadBalanced are supported from DNS" + ) + for opt, val in parsed_dns_options.items(): + if opt not in options: + options[opt] = val + if options.get("loadBalanced") and srv_max_hosts: + raise InvalidURI("You cannot specify loadBalanced with srvMaxHosts") + if options.get("replicaSet") and srv_max_hosts: + raise InvalidURI("You cannot specify replicaSet with srvMaxHosts") + if "tls" not in options and "ssl" not in options: + options["tls"] = True if validate else "true" + else: + nodes = split_hosts(hosts, default_port=default_port) + + _check_options(nodes, options) + + return { + "nodelist": nodes, + "options": options, + } diff --git a/pymongo/encryption_options.py b/pymongo/encryption_options.py index 02fcc98e46..4cb94cba30 100644 --- a/pymongo/encryption_options.py +++ b/pymongo/encryption_options.py @@ -32,7 +32,7 @@ from bson import int64 from pymongo.common import validate_is_mapping from pymongo.errors import ConfigurationError -from pymongo.uri_parser import _parse_kms_tls_options +from pymongo.uri_parser_shared import _parse_kms_tls_options if TYPE_CHECKING: from pymongo.typings import _AgnosticMongoClient, _DocumentTypeArg diff --git a/pymongo/synchronous/encryption.py b/pymongo/synchronous/encryption.py index ebffc7d74c..38c28de91e 100644 --- a/pymongo/synchronous/encryption.py +++ b/pymongo/synchronous/encryption.py @@ -86,7 +86,7 @@ _raise_connection_failure, ) from pymongo.typings import _DocumentType, _DocumentTypeArg -from pymongo.uri_parser import parse_host +from pymongo.uri_parser_shared import parse_host from pymongo.write_concern import WriteConcern if TYPE_CHECKING: diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 79b6cf6ed9..1cedbfe1e2 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -42,6 +42,7 @@ TYPE_CHECKING, Any, Callable, + Collection, ContextManager, FrozenSet, Generator, @@ -59,7 +60,7 @@ from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions, TypeRegistry from bson.timestamp import Timestamp -from pymongo import _csot, common, helpers_shared, periodic_executor, uri_parser +from pymongo import _csot, common, helpers_shared, periodic_executor from pymongo.client_options import ClientOptions from pymongo.errors import ( AutoReconnect, @@ -96,7 +97,7 @@ from pymongo.results import ClientBulkWriteResult from pymongo.server_selectors import writable_server_selector from pymongo.server_type import SERVER_TYPE -from pymongo.synchronous import client_session, database +from pymongo.synchronous import client_session, database, uri_parser from pymongo.synchronous.change_stream import ChangeStream, ClusterChangeStream from pymongo.synchronous.client_bulk import _ClientBulk from pymongo.synchronous.client_session import _EmptyServerSession @@ -112,11 +113,14 @@ _DocumentTypeArg, _Pipeline, ) -from pymongo.uri_parser import ( +from pymongo.uri_parser_shared import ( + SRV_SCHEME, _check_options, _handle_option_deprecations, _handle_security_options, _normalize_options, + _validate_uri, + split_hosts, ) from pymongo.write_concern import DEFAULT_WRITE_CONCERN, WriteConcern @@ -130,6 +134,7 @@ from pymongo.synchronous.bulk import _Bulk from pymongo.synchronous.client_session import ClientSession, _ServerSession from pymongo.synchronous.cursor import _ConnectionManager + from pymongo.synchronous.encryption import _Encrypter from pymongo.synchronous.pool import Connection from pymongo.synchronous.server import Server @@ -748,6 +753,9 @@ def __init__( port = self.PORT if not isinstance(port, int): raise TypeError(f"port must be an instance of int, not {type(port)}") + self._host = host + self._port = port + self._topology: Topology = None # type: ignore[assignment] # _pool_class, _monitor_class, and _condition_class are for deep # customization of PyMongo, e.g. Motor. @@ -758,8 +766,10 @@ def __init__( # Parse options passed as kwargs. keyword_opts = common._CaseInsensitiveDictionary(kwargs) keyword_opts["document_class"] = doc_class + self._resolve_srv_info: dict[str, Any] = {"keyword_opts": keyword_opts} seeds = set() + is_srv = False username = None password = None dbase = None @@ -767,29 +777,22 @@ def __init__( fqdn = None srv_service_name = keyword_opts.get("srvservicename") srv_max_hosts = keyword_opts.get("srvmaxhosts") - if len([h for h in host if "/" in h]) > 1: + if len([h for h in self._host if "/" in h]) > 1: raise ConfigurationError("host must not contain multiple MongoDB URIs") - for entity in host: + for entity in self._host: # A hostname can only include a-z, 0-9, '-' and '.'. If we find a '/' # it must be a URI, # https://en.wikipedia.org/wiki/Hostname#Restrictions_on_valid_host_names if "/" in entity: - # Determine connection timeout from kwargs. - timeout = keyword_opts.get("connecttimeoutms") - if timeout is not None: - timeout = common.validate_timeout_or_none_or_zero( - keyword_opts.cased_key("connecttimeoutms"), timeout - ) - res = uri_parser.parse_uri( + res = _validate_uri( entity, port, validate=True, warn=True, normalize=False, - connect_timeout=timeout, - srv_service_name=srv_service_name, srv_max_hosts=srv_max_hosts, ) + is_srv = entity.startswith(SRV_SCHEME) seeds.update(res["nodelist"]) username = res["username"] or username password = res["password"] or password @@ -797,7 +800,7 @@ def __init__( opts = res["options"] fqdn = res["fqdn"] else: - seeds.update(uri_parser.split_hosts(entity, port)) + seeds.update(split_hosts(entity, self._port)) if not seeds: raise ConfigurationError("need to specify at least one host") @@ -818,80 +821,179 @@ def __init__( keyword_opts["tz_aware"] = tz_aware keyword_opts["connect"] = connect - # Handle deprecated options in kwarg options. - keyword_opts = _handle_option_deprecations(keyword_opts) - # Validate kwarg options. - keyword_opts = common._CaseInsensitiveDictionary( - dict(common.validate(keyword_opts.cased_key(k), v) for k, v in keyword_opts.items()) - ) - - # Override connection string options with kwarg options. - opts.update(keyword_opts) + opts = self._validate_kwargs_and_update_opts(keyword_opts, opts) if srv_service_name is None: srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME) srv_max_hosts = srv_max_hosts or opts.get("srvmaxhosts") - # Handle security-option conflicts in combined options. - opts = _handle_security_options(opts) - # Normalize combined options. - opts = _normalize_options(opts) - _check_options(seeds, opts) + opts = self._normalize_and_validate_options(opts, seeds) # Username and password passed as kwargs override user info in URI. username = opts.get("username", username) password = opts.get("password", password) - self._options = options = ClientOptions(username, password, dbase, opts, _IS_SYNC) + self._options = ClientOptions(username, password, dbase, opts, _IS_SYNC) self._default_database_name = dbase self._lock = _create_lock() self._kill_cursors_queue: list = [] - self._event_listeners = options.pool_options._event_listeners - super().__init__( - options.codec_options, - options.read_preference, - options.write_concern, - options.read_concern, + self._encrypter: Optional[_Encrypter] = None + + self._resolve_srv_info.update( + { + "is_srv": is_srv, + "username": username, + "password": password, + "dbase": dbase, + "seeds": seeds, + "fqdn": fqdn, + "srv_service_name": srv_service_name, + "pool_class": pool_class, + "monitor_class": monitor_class, + "condition_class": condition_class, + } ) - self._topology_settings = TopologySettings( - seeds=seeds, - replica_set_name=options.replica_set_name, - pool_class=pool_class, - pool_options=options.pool_options, - monitor_class=monitor_class, - condition_class=condition_class, - local_threshold_ms=options.local_threshold_ms, - server_selection_timeout=options.server_selection_timeout, - server_selector=options.server_selector, - heartbeat_frequency=options.heartbeat_frequency, - fqdn=fqdn, - direct_connection=options.direct_connection, - load_balanced=options.load_balanced, - srv_service_name=srv_service_name, - srv_max_hosts=srv_max_hosts, - server_monitoring_mode=options.server_monitoring_mode, + super().__init__( + self._options.codec_options, + self._options.read_preference, + self._options.write_concern, + self._options.read_concern, ) + if not is_srv: + self._init_based_on_options(seeds, srv_max_hosts, srv_service_name) + self._opened = False self._closed = False - self._init_background() + if not is_srv: + self._init_background() if _IS_SYNC and connect: self._get_topology() # type: ignore[unused-coroutine] - self._encrypter = None + def _resolve_srv(self) -> None: + keyword_opts = self._resolve_srv_info["keyword_opts"] + seeds = set() + opts = common._CaseInsensitiveDictionary() + srv_service_name = keyword_opts.get("srvservicename") + srv_max_hosts = keyword_opts.get("srvmaxhosts") + for entity in self._host: + # A hostname can only include a-z, 0-9, '-' and '.'. If we find a '/' + # it must be a URI, + # https://en.wikipedia.org/wiki/Hostname#Restrictions_on_valid_host_names + if "/" in entity: + # Determine connection timeout from kwargs. + timeout = keyword_opts.get("connecttimeoutms") + if timeout is not None: + timeout = common.validate_timeout_or_none_or_zero( + keyword_opts.cased_key("connecttimeoutms"), timeout + ) + res = uri_parser._parse_srv( + entity, + self._port, + validate=True, + warn=True, + normalize=False, + connect_timeout=timeout, + srv_service_name=srv_service_name, + srv_max_hosts=srv_max_hosts, + ) + seeds.update(res["nodelist"]) + opts = res["options"] + else: + seeds.update(split_hosts(entity, self._port)) + + if not seeds: + raise ConfigurationError("need to specify at least one host") + + for hostname in [node[0] for node in seeds]: + if _detect_external_db(hostname): + break + + # Add options with named keyword arguments to the parsed kwarg options. + tz_aware = keyword_opts["tz_aware"] + connect = keyword_opts["connect"] + if tz_aware is None: + tz_aware = opts.get("tz_aware", False) + if connect is None: + # Default to connect=True unless on a FaaS system, which might use fork. + from pymongo.pool_options import _is_faas + + connect = opts.get("connect", not _is_faas()) + keyword_opts["tz_aware"] = tz_aware + keyword_opts["connect"] = connect + + opts = self._validate_kwargs_and_update_opts(keyword_opts, opts) + + if srv_service_name is None: + srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME) + + srv_max_hosts = srv_max_hosts or opts.get("srvmaxhosts") + opts = self._normalize_and_validate_options(opts, seeds) + + # Username and password passed as kwargs override user info in URI. + username = opts.get("username", self._resolve_srv_info["username"]) + password = opts.get("password", self._resolve_srv_info["password"]) + self._options = ClientOptions( + username, password, self._resolve_srv_info["dbase"], opts, _IS_SYNC + ) + + self._init_based_on_options(seeds, srv_max_hosts, srv_service_name) + + def _init_based_on_options( + self, seeds: Collection[tuple[str, int]], srv_max_hosts: Any, srv_service_name: Any + ) -> None: + self._event_listeners = self._options.pool_options._event_listeners + self._topology_settings = TopologySettings( + seeds=seeds, + replica_set_name=self._options.replica_set_name, + pool_class=self._resolve_srv_info["pool_class"], + pool_options=self._options.pool_options, + monitor_class=self._resolve_srv_info["monitor_class"], + condition_class=self._resolve_srv_info["condition_class"], + local_threshold_ms=self._options.local_threshold_ms, + server_selection_timeout=self._options.server_selection_timeout, + server_selector=self._options.server_selector, + heartbeat_frequency=self._options.heartbeat_frequency, + fqdn=self._resolve_srv_info["fqdn"], + direct_connection=self._options.direct_connection, + load_balanced=self._options.load_balanced, + srv_service_name=srv_service_name, + srv_max_hosts=srv_max_hosts, + server_monitoring_mode=self._options.server_monitoring_mode, + ) if self._options.auto_encryption_opts: from pymongo.synchronous.encryption import _Encrypter self._encrypter = _Encrypter(self, self._options.auto_encryption_opts) self._timeout = self._options.timeout - if _HAS_REGISTER_AT_FORK: - # Add this client to the list of weakly referenced items. - # This will be used later if we fork. - MongoClient._clients[self._topology._topology_id] = self + def _normalize_and_validate_options( + self, opts: common._CaseInsensitiveDictionary, seeds: set[tuple[str, int | None]] + ) -> common._CaseInsensitiveDictionary: + # Handle security-option conflicts in combined options. + opts = _handle_security_options(opts) + # Normalize combined options. + opts = _normalize_options(opts) + _check_options(seeds, opts) + return opts + + def _validate_kwargs_and_update_opts( + self, + keyword_opts: common._CaseInsensitiveDictionary, + opts: common._CaseInsensitiveDictionary, + ) -> common._CaseInsensitiveDictionary: + # Handle deprecated options in kwarg options. + keyword_opts = _handle_option_deprecations(keyword_opts) + # Validate kwarg options. + keyword_opts = common._CaseInsensitiveDictionary( + dict(common.validate(keyword_opts.cased_key(k), v) for k, v in keyword_opts.items()) + ) + # Override connection string options with kwarg options. + opts.update(keyword_opts) + return opts def _connect(self) -> None: """Explicitly connect to MongoDB synchronously instead of on the first operation.""" @@ -899,6 +1001,10 @@ def _connect(self) -> None: def _init_background(self, old_pid: Optional[int] = None) -> None: self._topology = Topology(self._topology_settings) + if _HAS_REGISTER_AT_FORK: + # Add this client to the list of weakly referenced items. + # This will be used later if we fork. + MongoClient._clients[self._topology._topology_id] = self # Seed the topology with the old one's pid so we can detect clients # that are opened before a fork and used after. self._topology._pid = old_pid @@ -1113,16 +1219,24 @@ def options(self) -> ClientOptions: """ return self._options + def eq_props(self) -> tuple[tuple[_Address, ...], Optional[str], Optional[str], str]: + return ( + tuple(sorted(self._resolve_srv_info["seeds"])), + self._options.replica_set_name, + self._resolve_srv_info["fqdn"], + self._resolve_srv_info["srv_service_name"], + ) + def __eq__(self, other: Any) -> bool: if isinstance(other, self.__class__): - return self._topology == other._topology + return self.eq_props() == other.eq_props() return NotImplemented def __ne__(self, other: Any) -> bool: return not self == other def __hash__(self) -> int: - return hash(self._topology) + return hash(self.eq_props()) def _repr_helper(self) -> str: def option_repr(option: str, value: Any) -> str: @@ -1138,13 +1252,16 @@ def option_repr(option: str, value: Any) -> str: return f"{option}={value!r}" # Host first... - options = [ - "host=%r" - % [ - "%s:%d" % (host, port) if port is not None else host - for host, port in self._topology_settings.seeds + if self._topology is None: + options = [f"host='mongodb+srv://{self._resolve_srv_info['fqdn']}'"] + else: + options = [ + "host=%r" + % [ + "%s:%d" % (host, port) if port is not None else host + for host, port in self._topology_settings.seeds + ] ] - ] # ... then everything in self._constructor_args... options.extend( option_repr(key, self._options._options[key]) for key in self._constructor_args @@ -1546,6 +1663,8 @@ def close(self) -> None: .. versionchanged:: 3.6 End all server sessions created by this client. """ + if self._topology is None: + return session_ids = self._topology.pop_all_sessions() if session_ids: self._end_sessions(session_ids) @@ -1576,6 +1695,9 @@ def _get_topology(self) -> Topology: launches the connection process in the background. """ if not self._opened: + if self._resolve_srv_info["is_srv"]: + self._resolve_srv() + self._init_background() self._topology.open() with self._lock: self._kill_cursors_executor.open() @@ -2497,6 +2619,7 @@ def handle( self.completed_handshake, self.service_id, ) + assert self.client._topology is not None self.client._topology.handle_error(self.server_address, err_ctx) def __enter__(self) -> _MongoClientErrorHandler: diff --git a/pymongo/synchronous/monitor.py b/pymongo/synchronous/monitor.py index 5b45ed9a4d..a2b76c4e8a 100644 --- a/pymongo/synchronous/monitor.py +++ b/pymongo/synchronous/monitor.py @@ -33,7 +33,7 @@ from pymongo.pool_options import _is_faas from pymongo.read_preferences import MovingAverage from pymongo.server_description import ServerDescription -from pymongo.srv_resolver import _SrvResolver +from pymongo.synchronous.srv_resolver import _SrvResolver if TYPE_CHECKING: from pymongo.synchronous.pool import Connection, Pool, _CancellationContext diff --git a/pymongo/srv_resolver.py b/pymongo/synchronous/srv_resolver.py similarity index 88% rename from pymongo/srv_resolver.py rename to pymongo/synchronous/srv_resolver.py index 5be6cb98db..1b36efd1c9 100644 --- a/pymongo/srv_resolver.py +++ b/pymongo/synchronous/srv_resolver.py @@ -25,6 +25,8 @@ if TYPE_CHECKING: from dns import resolver +_IS_SYNC = True + def _have_dnspython() -> bool: try: @@ -45,13 +47,23 @@ def maybe_decode(text: Union[str, bytes]) -> str: # PYTHON-2667 Lazily call dns.resolver methods for compatibility with eventlet. def _resolve(*args: Any, **kwargs: Any) -> resolver.Answer: - from dns import resolver + if _IS_SYNC: + from dns import resolver - if hasattr(resolver, "resolve"): - # dnspython >= 2 - return resolver.resolve(*args, **kwargs) - # dnspython 1.X - return resolver.query(*args, **kwargs) + if hasattr(resolver, "resolve"): + # dnspython >= 2 + return resolver.resolve(*args, **kwargs) + # dnspython 1.X + return resolver.query(*args, **kwargs) + else: + from dns import asyncresolver + + if hasattr(asyncresolver, "resolve"): + # dnspython >= 2 + return asyncresolver.resolve(*args, **kwargs) # type:ignore[return-value] + raise ConfigurationError( + "Upgrade to dnspython version >= 2.0 to use MongoClient with mongodb+srv:// connections." + ) _INVALID_HOST_MSG = ( diff --git a/pymongo/synchronous/uri_parser.py b/pymongo/synchronous/uri_parser.py new file mode 100644 index 0000000000..52b59b8fe8 --- /dev/null +++ b/pymongo/synchronous/uri_parser.py @@ -0,0 +1,188 @@ +# Copyright 2011-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you +# may not use this file except in compliance with the License. You +# may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. + + +"""Tools to parse and validate a MongoDB URI.""" +from __future__ import annotations + +from typing import Any, Optional +from urllib.parse import unquote_plus + +from pymongo.common import SRV_SERVICE_NAME, _CaseInsensitiveDictionary +from pymongo.errors import ConfigurationError, InvalidURI +from pymongo.synchronous.srv_resolver import _SrvResolver +from pymongo.uri_parser_shared import ( + _ALLOWED_TXT_OPTS, + DEFAULT_PORT, + SCHEME, + SCHEME_LEN, + SRV_SCHEME_LEN, + _check_options, + _validate_uri, + split_hosts, + split_options, +) + +_IS_SYNC = True + + +def parse_uri( + uri: str, + default_port: Optional[int] = DEFAULT_PORT, + validate: bool = True, + warn: bool = False, + normalize: bool = True, + connect_timeout: Optional[float] = None, + srv_service_name: Optional[str] = None, + srv_max_hosts: Optional[int] = None, +) -> dict[str, Any]: + """Parse and validate a MongoDB URI. + + Returns a dict of the form:: + + { + 'nodelist': , + 'username': or None, + 'password': or None, + 'database': or None, + 'collection': or None, + 'options': , + 'fqdn': or None + } + + If the URI scheme is "mongodb+srv://" DNS SRV and TXT lookups will be done + to build nodelist and options. + + :param uri: The MongoDB URI to parse. + :param default_port: The port number to use when one wasn't specified + for a host in the URI. + :param validate: If ``True`` (the default), validate and + normalize all options. Default: ``True``. + :param warn: When validating, if ``True`` then will warn + the user then ignore any invalid options or values. If ``False``, + validation will error when options are unsupported or values are + invalid. Default: ``False``. + :param normalize: If ``True``, convert names of URI options + to their internally-used names. Default: ``True``. + :param connect_timeout: The maximum time in milliseconds to + wait for a response from the DNS server. + :param srv_service_name: A custom SRV service name + + .. versionchanged:: 4.6 + The delimiting slash (``/``) between hosts and connection options is now optional. + For example, "mongodb://example.com?tls=true" is now a valid URI. + + .. versionchanged:: 4.0 + To better follow RFC 3986, unquoted percent signs ("%") are no longer + supported. + + .. versionchanged:: 3.9 + Added the ``normalize`` parameter. + + .. versionchanged:: 3.6 + Added support for mongodb+srv:// URIs. + + .. versionchanged:: 3.5 + Return the original value of the ``readPreference`` MongoDB URI option + instead of the validated read preference mode. + + .. versionchanged:: 3.1 + ``warn`` added so invalid options can be ignored. + """ + result = _validate_uri(uri, default_port, validate, warn, normalize, srv_max_hosts) + result.update( + _parse_srv( + uri, + default_port, + validate, + warn, + normalize, + connect_timeout, + srv_service_name, + srv_max_hosts, + ) + ) + return result + + +def _parse_srv( + uri: str, + default_port: Optional[int] = DEFAULT_PORT, + validate: bool = True, + warn: bool = False, + normalize: bool = True, + connect_timeout: Optional[float] = None, + srv_service_name: Optional[str] = None, + srv_max_hosts: Optional[int] = None, +) -> dict[str, Any]: + if uri.startswith(SCHEME): + is_srv = False + scheme_free = uri[SCHEME_LEN:] + else: + is_srv = True + scheme_free = uri[SRV_SCHEME_LEN:] + + options = _CaseInsensitiveDictionary() + + host_plus_db_part, _, opts = scheme_free.partition("?") + if "/" in host_plus_db_part: + host_part, _, _ = host_plus_db_part.partition("/") + else: + host_part = host_plus_db_part + + if opts: + options.update(split_options(opts, validate, warn, normalize)) + if srv_service_name is None: + srv_service_name = options.get("srvServiceName", SRV_SERVICE_NAME) + if "@" in host_part: + _, _, hosts = host_part.rpartition("@") + else: + hosts = host_part + + hosts = unquote_plus(hosts) + srv_max_hosts = srv_max_hosts or options.get("srvMaxHosts") + if is_srv: + nodes = split_hosts(hosts, default_port=None) + fqdn, port = nodes[0] + + # Use the connection timeout. connectTimeoutMS passed as a keyword + # argument overrides the same option passed in the connection string. + connect_timeout = connect_timeout or options.get("connectTimeoutMS") + dns_resolver = _SrvResolver(fqdn, connect_timeout, srv_service_name, srv_max_hosts) + nodes = dns_resolver.get_hosts() + dns_options = dns_resolver.get_options() + if dns_options: + parsed_dns_options = split_options(dns_options, validate, warn, normalize) + if set(parsed_dns_options) - _ALLOWED_TXT_OPTS: + raise ConfigurationError( + "Only authSource, replicaSet, and loadBalanced are supported from DNS" + ) + for opt, val in parsed_dns_options.items(): + if opt not in options: + options[opt] = val + if options.get("loadBalanced") and srv_max_hosts: + raise InvalidURI("You cannot specify loadBalanced with srvMaxHosts") + if options.get("replicaSet") and srv_max_hosts: + raise InvalidURI("You cannot specify replicaSet with srvMaxHosts") + if "tls" not in options and "ssl" not in options: + options["tls"] = True if validate else "true" + else: + nodes = split_hosts(hosts, default_port=default_port) + + _check_options(nodes, options) + + return { + "nodelist": nodes, + "options": options, + } diff --git a/pymongo/uri_parser.py b/pymongo/uri_parser.py index ee7ca9c205..fe253b9bbf 100644 --- a/pymongo/uri_parser.py +++ b/pymongo/uri_parser.py @@ -13,627 +13,32 @@ # permissions and limitations under the License. -"""Tools to parse and validate a MongoDB URI. - -.. seealso:: This module is compatible with both the synchronous and asynchronous PyMongo APIs. -""" +"""Re-import of synchronous URI Parser API for compatibility.""" from __future__ import annotations -import re import sys -import warnings -from typing import ( - TYPE_CHECKING, - Any, - Mapping, - MutableMapping, - Optional, - Sized, - Union, - cast, -) -from urllib.parse import unquote_plus - -from pymongo.client_options import _parse_ssl_options -from pymongo.common import ( - INTERNAL_URI_OPTION_NAME_MAP, - SRV_SERVICE_NAME, - URI_OPTIONS_DEPRECATION_MAP, - _CaseInsensitiveDictionary, - get_validated_options, -) -from pymongo.errors import ConfigurationError, InvalidURI -from pymongo.srv_resolver import _have_dnspython, _SrvResolver -from pymongo.typings import _Address - -if TYPE_CHECKING: - from pymongo.pyopenssl_context import SSLContext - -SCHEME = "mongodb://" -SCHEME_LEN = len(SCHEME) -SRV_SCHEME = "mongodb+srv://" -SRV_SCHEME_LEN = len(SRV_SCHEME) -DEFAULT_PORT = 27017 - - -def _unquoted_percent(s: str) -> bool: - """Check for unescaped percent signs. - - :param s: A string. `s` can have things like '%25', '%2525', - and '%E2%85%A8' but cannot have unquoted percent like '%foo'. - """ - for i in range(len(s)): - if s[i] == "%": - sub = s[i : i + 3] - # If unquoting yields the same string this means there was an - # unquoted %. - if unquote_plus(sub) == sub: - return True - return False - - -def parse_userinfo(userinfo: str) -> tuple[str, str]: - """Validates the format of user information in a MongoDB URI. - Reserved characters that are gen-delimiters (":", "/", "?", "#", "[", - "]", "@") as per RFC 3986 must be escaped. - - Returns a 2-tuple containing the unescaped username followed - by the unescaped password. - - :param userinfo: A string of the form : - """ - if "@" in userinfo or userinfo.count(":") > 1 or _unquoted_percent(userinfo): - raise InvalidURI( - "Username and password must be escaped according to " - "RFC 3986, use urllib.parse.quote_plus" - ) - - user, _, passwd = userinfo.partition(":") - # No password is expected with GSSAPI authentication. - if not user: - raise InvalidURI("The empty string is not valid username") - - return unquote_plus(user), unquote_plus(passwd) - - -def parse_ipv6_literal_host( - entity: str, default_port: Optional[int] -) -> tuple[str, Optional[Union[str, int]]]: - """Validates an IPv6 literal host:port string. - - Returns a 2-tuple of IPv6 literal followed by port where - port is default_port if it wasn't specified in entity. - - :param entity: A string that represents an IPv6 literal enclosed - in braces (e.g. '[::1]' or '[::1]:27017'). - :param default_port: The port number to use when one wasn't - specified in entity. - """ - if entity.find("]") == -1: - raise ValueError( - "an IPv6 address literal must be enclosed in '[' and ']' according to RFC 2732." - ) - i = entity.find("]:") - if i == -1: - return entity[1:-1], default_port - return entity[1:i], entity[i + 2 :] - - -def parse_host(entity: str, default_port: Optional[int] = DEFAULT_PORT) -> _Address: - """Validates a host string - - Returns a 2-tuple of host followed by port where port is default_port - if it wasn't specified in the string. - - :param entity: A host or host:port string where host could be a - hostname or IP address. - :param default_port: The port number to use when one wasn't - specified in entity. - """ - host = entity - port: Optional[Union[str, int]] = default_port - if entity[0] == "[": - host, port = parse_ipv6_literal_host(entity, default_port) - elif entity.endswith(".sock"): - return entity, default_port - elif entity.find(":") != -1: - if entity.count(":") > 1: - raise ValueError( - "Reserved characters such as ':' must be " - "escaped according RFC 2396. An IPv6 " - "address literal must be enclosed in '[' " - "and ']' according to RFC 2732." - ) - host, port = host.split(":", 1) - if isinstance(port, str): - if not port.isdigit(): - # Special case check for mistakes like "mongodb://localhost:27017 ". - if all(c.isspace() or c.isdigit() for c in port): - for c in port: - if c.isspace(): - raise ValueError(f"Port contains whitespace character: {c!r}") - - # A non-digit port indicates that the URI is invalid, likely because the password - # or username were not escaped. - raise ValueError( - "Port contains non-digit characters. Hint: username and password must be escaped according to " - "RFC 3986, use urllib.parse.quote_plus" - ) - if int(port) > 65535 or int(port) <= 0: - raise ValueError("Port must be an integer between 0 and 65535") - port = int(port) - - # Normalize hostname to lowercase, since DNS is case-insensitive: - # https://tools.ietf.org/html/rfc4343 - # This prevents useless rediscovery if "foo.com" is in the seed list but - # "FOO.com" is in the hello response. - return host.lower(), port - - -# Options whose values are implicitly determined by tlsInsecure. -_IMPLICIT_TLSINSECURE_OPTS = { - "tlsallowinvalidcertificates", - "tlsallowinvalidhostnames", - "tlsdisableocspendpointcheck", -} - - -def _parse_options(opts: str, delim: Optional[str]) -> _CaseInsensitiveDictionary: - """Helper method for split_options which creates the options dict. - Also handles the creation of a list for the URI tag_sets/ - readpreferencetags portion, and the use of a unicode options string. - """ - options = _CaseInsensitiveDictionary() - for uriopt in opts.split(delim): - key, value = uriopt.split("=") - if key.lower() == "readpreferencetags": - options.setdefault(key, []).append(value) - else: - if key in options: - warnings.warn(f"Duplicate URI option '{key}'.", stacklevel=2) - if key.lower() == "authmechanismproperties": - val = value - else: - val = unquote_plus(value) - options[key] = val - - return options - - -def _handle_security_options(options: _CaseInsensitiveDictionary) -> _CaseInsensitiveDictionary: - """Raise appropriate errors when conflicting TLS options are present in - the options dictionary. - - :param options: Instance of _CaseInsensitiveDictionary containing - MongoDB URI options. - """ - # Implicitly defined options must not be explicitly specified. - tlsinsecure = options.get("tlsinsecure") - if tlsinsecure is not None: - for opt in _IMPLICIT_TLSINSECURE_OPTS: - if opt in options: - err_msg = "URI options %s and %s cannot be specified simultaneously." - raise InvalidURI( - err_msg % (options.cased_key("tlsinsecure"), options.cased_key(opt)) - ) - - # Handle co-occurence of OCSP & tlsAllowInvalidCertificates options. - tlsallowinvalidcerts = options.get("tlsallowinvalidcertificates") - if tlsallowinvalidcerts is not None: - if "tlsdisableocspendpointcheck" in options: - err_msg = "URI options %s and %s cannot be specified simultaneously." - raise InvalidURI( - err_msg - % ("tlsallowinvalidcertificates", options.cased_key("tlsdisableocspendpointcheck")) - ) - if tlsallowinvalidcerts is True: - options["tlsdisableocspendpointcheck"] = True - - # Handle co-occurence of CRL and OCSP-related options. - tlscrlfile = options.get("tlscrlfile") - if tlscrlfile is not None: - for opt in ("tlsinsecure", "tlsallowinvalidcertificates", "tlsdisableocspendpointcheck"): - if options.get(opt) is True: - err_msg = "URI option %s=True cannot be specified when CRL checking is enabled." - raise InvalidURI(err_msg % (opt,)) - - if "ssl" in options and "tls" in options: - - def truth_value(val: Any) -> Any: - if val in ("true", "false"): - return val == "true" - if isinstance(val, bool): - return val - return val - - if truth_value(options.get("ssl")) != truth_value(options.get("tls")): - err_msg = "Can not specify conflicting values for URI options %s and %s." - raise InvalidURI(err_msg % (options.cased_key("ssl"), options.cased_key("tls"))) - - return options - - -def _handle_option_deprecations(options: _CaseInsensitiveDictionary) -> _CaseInsensitiveDictionary: - """Issue appropriate warnings when deprecated options are present in the - options dictionary. Removes deprecated option key, value pairs if the - options dictionary is found to also have the renamed option. - - :param options: Instance of _CaseInsensitiveDictionary containing - MongoDB URI options. - """ - for optname in list(options): - if optname in URI_OPTIONS_DEPRECATION_MAP: - mode, message = URI_OPTIONS_DEPRECATION_MAP[optname] - if mode == "renamed": - newoptname = message - if newoptname in options: - warn_msg = "Deprecated option '%s' ignored in favor of '%s'." - warnings.warn( - warn_msg % (options.cased_key(optname), options.cased_key(newoptname)), - DeprecationWarning, - stacklevel=2, - ) - options.pop(optname) - continue - warn_msg = "Option '%s' is deprecated, use '%s' instead." - warnings.warn( - warn_msg % (options.cased_key(optname), newoptname), - DeprecationWarning, - stacklevel=2, - ) - elif mode == "removed": - warn_msg = "Option '%s' is deprecated. %s." - warnings.warn( - warn_msg % (options.cased_key(optname), message), - DeprecationWarning, - stacklevel=2, - ) - - return options - - -def _normalize_options(options: _CaseInsensitiveDictionary) -> _CaseInsensitiveDictionary: - """Normalizes option names in the options dictionary by converting them to - their internally-used names. - - :param options: Instance of _CaseInsensitiveDictionary containing - MongoDB URI options. - """ - # Expand the tlsInsecure option. - tlsinsecure = options.get("tlsinsecure") - if tlsinsecure is not None: - for opt in _IMPLICIT_TLSINSECURE_OPTS: - # Implicit options are logically the same as tlsInsecure. - options[opt] = tlsinsecure - - for optname in list(options): - intname = INTERNAL_URI_OPTION_NAME_MAP.get(optname, None) - if intname is not None: - options[intname] = options.pop(optname) - - return options - - -def validate_options(opts: Mapping[str, Any], warn: bool = False) -> MutableMapping[str, Any]: - """Validates and normalizes options passed in a MongoDB URI. - - Returns a new dictionary of validated and normalized options. If warn is - False then errors will be thrown for invalid options, otherwise they will - be ignored and a warning will be issued. - - :param opts: A dict of MongoDB URI options. - :param warn: If ``True`` then warnings will be logged and - invalid options will be ignored. Otherwise invalid options will - cause errors. - """ - return get_validated_options(opts, warn) - - -def split_options( - opts: str, validate: bool = True, warn: bool = False, normalize: bool = True -) -> MutableMapping[str, Any]: - """Takes the options portion of a MongoDB URI, validates each option - and returns the options in a dictionary. - - :param opt: A string representing MongoDB URI options. - :param validate: If ``True`` (the default), validate and normalize all - options. - :param warn: If ``False`` (the default), suppress all warnings raised - during validation of options. - :param normalize: If ``True`` (the default), renames all options to their - internally-used names. - """ - and_idx = opts.find("&") - semi_idx = opts.find(";") - try: - if and_idx >= 0 and semi_idx >= 0: - raise InvalidURI("Can not mix '&' and ';' for option separators") - elif and_idx >= 0: - options = _parse_options(opts, "&") - elif semi_idx >= 0: - options = _parse_options(opts, ";") - elif opts.find("=") != -1: - options = _parse_options(opts, None) - else: - raise ValueError - except ValueError: - raise InvalidURI("MongoDB URI options are key=value pairs") from None - - options = _handle_security_options(options) - - options = _handle_option_deprecations(options) - - if normalize: - options = _normalize_options(options) - - if validate: - options = cast(_CaseInsensitiveDictionary, validate_options(options, warn)) - if options.get("authsource") == "": - raise InvalidURI("the authSource database cannot be an empty string") - - return options - - -def split_hosts(hosts: str, default_port: Optional[int] = DEFAULT_PORT) -> list[_Address]: - """Takes a string of the form host1[:port],host2[:port]... and - splits it into (host, port) tuples. If [:port] isn't present the - default_port is used. - - Returns a set of 2-tuples containing the host name (or IP) followed by - port number. - - :param hosts: A string of the form host1[:port],host2[:port],... - :param default_port: The port number to use when one wasn't specified - for a host. - """ - nodes = [] - for entity in hosts.split(","): - if not entity: - raise ConfigurationError("Empty host (or extra comma in host list)") - port = default_port - # Unix socket entities don't have ports - if entity.endswith(".sock"): - port = None - nodes.append(parse_host(entity, port)) - return nodes - - -# Prohibited characters in database name. DB names also can't have ".", but for -# backward-compat we allow "db.collection" in URI. -_BAD_DB_CHARS = re.compile("[" + re.escape(r'/ "$') + "]") - -_ALLOWED_TXT_OPTS = frozenset( - ["authsource", "authSource", "replicaset", "replicaSet", "loadbalanced", "loadBalanced"] -) - - -def _check_options(nodes: Sized, options: Mapping[str, Any]) -> None: - # Ensure directConnection was not True if there are multiple seeds. - if len(nodes) > 1 and options.get("directconnection"): - raise ConfigurationError("Cannot specify multiple hosts with directConnection=true") - - if options.get("loadbalanced"): - if len(nodes) > 1: - raise ConfigurationError("Cannot specify multiple hosts with loadBalanced=true") - if options.get("directconnection"): - raise ConfigurationError("Cannot specify directConnection=true with loadBalanced=true") - if options.get("replicaset"): - raise ConfigurationError("Cannot specify replicaSet with loadBalanced=true") - - -def parse_uri( - uri: str, - default_port: Optional[int] = DEFAULT_PORT, - validate: bool = True, - warn: bool = False, - normalize: bool = True, - connect_timeout: Optional[float] = None, - srv_service_name: Optional[str] = None, - srv_max_hosts: Optional[int] = None, -) -> dict[str, Any]: - """Parse and validate a MongoDB URI. - - Returns a dict of the form:: - - { - 'nodelist': , - 'username': or None, - 'password': or None, - 'database': or None, - 'collection': or None, - 'options': , - 'fqdn': or None - } - - If the URI scheme is "mongodb+srv://" DNS SRV and TXT lookups will be done - to build nodelist and options. - - :param uri: The MongoDB URI to parse. - :param default_port: The port number to use when one wasn't specified - for a host in the URI. - :param validate: If ``True`` (the default), validate and - normalize all options. Default: ``True``. - :param warn: When validating, if ``True`` then will warn - the user then ignore any invalid options or values. If ``False``, - validation will error when options are unsupported or values are - invalid. Default: ``False``. - :param normalize: If ``True``, convert names of URI options - to their internally-used names. Default: ``True``. - :param connect_timeout: The maximum time in milliseconds to - wait for a response from the DNS server. - :param srv_service_name: A custom SRV service name - - .. versionchanged:: 4.6 - The delimiting slash (``/``) between hosts and connection options is now optional. - For example, "mongodb://example.com?tls=true" is now a valid URI. - - .. versionchanged:: 4.0 - To better follow RFC 3986, unquoted percent signs ("%") are no longer - supported. - - .. versionchanged:: 3.9 - Added the ``normalize`` parameter. - - .. versionchanged:: 3.6 - Added support for mongodb+srv:// URIs. - - .. versionchanged:: 3.5 - Return the original value of the ``readPreference`` MongoDB URI option - instead of the validated read preference mode. - - .. versionchanged:: 3.1 - ``warn`` added so invalid options can be ignored. - """ - if uri.startswith(SCHEME): - is_srv = False - scheme_free = uri[SCHEME_LEN:] - elif uri.startswith(SRV_SCHEME): - if not _have_dnspython(): - python_path = sys.executable or "python" - raise ConfigurationError( - 'The "dnspython" module must be ' - "installed to use mongodb+srv:// URIs. " - "To fix this error install pymongo again:\n " - "%s -m pip install pymongo>=4.3" % (python_path) - ) - is_srv = True - scheme_free = uri[SRV_SCHEME_LEN:] - else: - raise InvalidURI(f"Invalid URI scheme: URI must begin with '{SCHEME}' or '{SRV_SCHEME}'") - - if not scheme_free: - raise InvalidURI("Must provide at least one hostname or IP") - - user = None - passwd = None - dbase = None - collection = None - options = _CaseInsensitiveDictionary() - - host_plus_db_part, _, opts = scheme_free.partition("?") - if "/" in host_plus_db_part: - host_part, _, dbase = host_plus_db_part.partition("/") - else: - host_part = host_plus_db_part - - if dbase: - dbase = unquote_plus(dbase) - if "." in dbase: - dbase, collection = dbase.split(".", 1) - if _BAD_DB_CHARS.search(dbase): - raise InvalidURI('Bad database name "%s"' % dbase) - else: - dbase = None - - if opts: - options.update(split_options(opts, validate, warn, normalize)) - if srv_service_name is None: - srv_service_name = options.get("srvServiceName", SRV_SERVICE_NAME) - if "@" in host_part: - userinfo, _, hosts = host_part.rpartition("@") - user, passwd = parse_userinfo(userinfo) - else: - hosts = host_part - - if "/" in hosts: - raise InvalidURI("Any '/' in a unix domain socket must be percent-encoded: %s" % host_part) - - hosts = unquote_plus(hosts) - fqdn = None - srv_max_hosts = srv_max_hosts or options.get("srvMaxHosts") - if is_srv: - if options.get("directConnection"): - raise ConfigurationError(f"Cannot specify directConnection=true with {SRV_SCHEME} URIs") - nodes = split_hosts(hosts, default_port=None) - if len(nodes) != 1: - raise InvalidURI(f"{SRV_SCHEME} URIs must include one, and only one, hostname") - fqdn, port = nodes[0] - if port is not None: - raise InvalidURI(f"{SRV_SCHEME} URIs must not include a port number") - - # Use the connection timeout. connectTimeoutMS passed as a keyword - # argument overrides the same option passed in the connection string. - connect_timeout = connect_timeout or options.get("connectTimeoutMS") - dns_resolver = _SrvResolver(fqdn, connect_timeout, srv_service_name, srv_max_hosts) - nodes = dns_resolver.get_hosts() - dns_options = dns_resolver.get_options() - if dns_options: - parsed_dns_options = split_options(dns_options, validate, warn, normalize) - if set(parsed_dns_options) - _ALLOWED_TXT_OPTS: - raise ConfigurationError( - "Only authSource, replicaSet, and loadBalanced are supported from DNS" - ) - for opt, val in parsed_dns_options.items(): - if opt not in options: - options[opt] = val - if options.get("loadBalanced") and srv_max_hosts: - raise InvalidURI("You cannot specify loadBalanced with srvMaxHosts") - if options.get("replicaSet") and srv_max_hosts: - raise InvalidURI("You cannot specify replicaSet with srvMaxHosts") - if "tls" not in options and "ssl" not in options: - options["tls"] = True if validate else "true" - elif not is_srv and options.get("srvServiceName") is not None: - raise ConfigurationError( - "The srvServiceName option is only allowed with 'mongodb+srv://' URIs" - ) - elif not is_srv and srv_max_hosts: - raise ConfigurationError( - "The srvMaxHosts option is only allowed with 'mongodb+srv://' URIs" - ) - else: - nodes = split_hosts(hosts, default_port=default_port) - - _check_options(nodes, options) - - return { - "nodelist": nodes, - "username": user, - "password": passwd, - "database": dbase, - "collection": collection, - "options": options, - "fqdn": fqdn, - } - - -def _parse_kms_tls_options(kms_tls_options: Optional[Mapping[str, Any]]) -> dict[str, SSLContext]: - """Parse KMS TLS connection options.""" - if not kms_tls_options: - return {} - if not isinstance(kms_tls_options, dict): - raise TypeError("kms_tls_options must be a dict") - contexts = {} - for provider, options in kms_tls_options.items(): - if not isinstance(options, dict): - raise TypeError(f'kms_tls_options["{provider}"] must be a dict') - options.setdefault("tls", True) - opts = _CaseInsensitiveDictionary(options) - opts = _handle_security_options(opts) - opts = _normalize_options(opts) - opts = cast(_CaseInsensitiveDictionary, validate_options(opts)) - ssl_context, allow_invalid_hostnames = _parse_ssl_options(opts) - if ssl_context is None: - raise ConfigurationError("TLS is required for KMS providers") - if allow_invalid_hostnames: - raise ConfigurationError("Insecure TLS options prohibited") - - for n in [ - "tlsInsecure", - "tlsAllowInvalidCertificates", - "tlsAllowInvalidHostnames", - "tlsDisableCertificateRevocationCheck", - ]: - if n in opts: - raise ConfigurationError(f"Insecure TLS options prohibited: {n}") - contexts[provider] = ssl_context - return contexts +from pymongo.errors import InvalidURI +from pymongo.synchronous.uri_parser import * # noqa: F403 +from pymongo.synchronous.uri_parser import __doc__ as original_doc +from pymongo.uri_parser_shared import * # noqa: F403 + +__doc__ = original_doc +__all__ = [ # noqa: F405 + "parse_userinfo", + "parse_ipv6_literal_host", + "parse_host", + "validate_options", + "split_options", + "split_hosts", + "parse_uri", +] if __name__ == "__main__": import pprint try: - pprint.pprint(parse_uri(sys.argv[1])) # noqa: T203 + pprint.pprint(parse_uri(sys.argv[1])) # noqa: F405, T203 except InvalidURI as exc: print(exc) # noqa: T201 sys.exit(0) diff --git a/pymongo/uri_parser_shared.py b/pymongo/uri_parser_shared.py new file mode 100644 index 0000000000..e7ba4c9fb5 --- /dev/null +++ b/pymongo/uri_parser_shared.py @@ -0,0 +1,549 @@ +# Copyright 2011-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you +# may not use this file except in compliance with the License. You +# may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. + + +"""Tools to parse and validate a MongoDB URI. + +.. seealso:: This module is compatible with both the synchronous and asynchronous PyMongo APIs. +""" +from __future__ import annotations + +import re +import sys +import warnings +from typing import ( + TYPE_CHECKING, + Any, + Mapping, + MutableMapping, + Optional, + Sized, + Union, + cast, +) +from urllib.parse import unquote_plus + +from pymongo.asynchronous.srv_resolver import _have_dnspython +from pymongo.client_options import _parse_ssl_options +from pymongo.common import ( + INTERNAL_URI_OPTION_NAME_MAP, + URI_OPTIONS_DEPRECATION_MAP, + _CaseInsensitiveDictionary, + get_validated_options, +) +from pymongo.errors import ConfigurationError, InvalidURI +from pymongo.typings import _Address + +if TYPE_CHECKING: + from pymongo.pyopenssl_context import SSLContext + +SCHEME = "mongodb://" +SCHEME_LEN = len(SCHEME) +SRV_SCHEME = "mongodb+srv://" +SRV_SCHEME_LEN = len(SRV_SCHEME) +DEFAULT_PORT = 27017 + + +def _unquoted_percent(s: str) -> bool: + """Check for unescaped percent signs. + + :param s: A string. `s` can have things like '%25', '%2525', + and '%E2%85%A8' but cannot have unquoted percent like '%foo'. + """ + for i in range(len(s)): + if s[i] == "%": + sub = s[i : i + 3] + # If unquoting yields the same string this means there was an + # unquoted %. + if unquote_plus(sub) == sub: + return True + return False + + +def parse_userinfo(userinfo: str) -> tuple[str, str]: + """Validates the format of user information in a MongoDB URI. + Reserved characters that are gen-delimiters (":", "/", "?", "#", "[", + "]", "@") as per RFC 3986 must be escaped. + + Returns a 2-tuple containing the unescaped username followed + by the unescaped password. + + :param userinfo: A string of the form : + """ + if "@" in userinfo or userinfo.count(":") > 1 or _unquoted_percent(userinfo): + raise InvalidURI( + "Username and password must be escaped according to " + "RFC 3986, use urllib.parse.quote_plus" + ) + + user, _, passwd = userinfo.partition(":") + # No password is expected with GSSAPI authentication. + if not user: + raise InvalidURI("The empty string is not valid username") + + return unquote_plus(user), unquote_plus(passwd) + + +def parse_ipv6_literal_host( + entity: str, default_port: Optional[int] +) -> tuple[str, Optional[Union[str, int]]]: + """Validates an IPv6 literal host:port string. + + Returns a 2-tuple of IPv6 literal followed by port where + port is default_port if it wasn't specified in entity. + + :param entity: A string that represents an IPv6 literal enclosed + in braces (e.g. '[::1]' or '[::1]:27017'). + :param default_port: The port number to use when one wasn't + specified in entity. + """ + if entity.find("]") == -1: + raise ValueError( + "an IPv6 address literal must be enclosed in '[' and ']' according to RFC 2732." + ) + i = entity.find("]:") + if i == -1: + return entity[1:-1], default_port + return entity[1:i], entity[i + 2 :] + + +def parse_host(entity: str, default_port: Optional[int] = DEFAULT_PORT) -> _Address: + """Validates a host string + + Returns a 2-tuple of host followed by port where port is default_port + if it wasn't specified in the string. + + :param entity: A host or host:port string where host could be a + hostname or IP address. + :param default_port: The port number to use when one wasn't + specified in entity. + """ + host = entity + port: Optional[Union[str, int]] = default_port + if entity[0] == "[": + host, port = parse_ipv6_literal_host(entity, default_port) + elif entity.endswith(".sock"): + return entity, default_port + elif entity.find(":") != -1: + if entity.count(":") > 1: + raise ValueError( + "Reserved characters such as ':' must be " + "escaped according RFC 2396. An IPv6 " + "address literal must be enclosed in '[' " + "and ']' according to RFC 2732." + ) + host, port = host.split(":", 1) + if isinstance(port, str): + if not port.isdigit(): + # Special case check for mistakes like "mongodb://localhost:27017 ". + if all(c.isspace() or c.isdigit() for c in port): + for c in port: + if c.isspace(): + raise ValueError(f"Port contains whitespace character: {c!r}") + + # A non-digit port indicates that the URI is invalid, likely because the password + # or username were not escaped. + raise ValueError( + "Port contains non-digit characters. Hint: username and password must be escaped according to " + "RFC 3986, use urllib.parse.quote_plus" + ) + if int(port) > 65535 or int(port) <= 0: + raise ValueError("Port must be an integer between 0 and 65535") + port = int(port) + + # Normalize hostname to lowercase, since DNS is case-insensitive: + # https://tools.ietf.org/html/rfc4343 + # This prevents useless rediscovery if "foo.com" is in the seed list but + # "FOO.com" is in the hello response. + return host.lower(), port + + +# Options whose values are implicitly determined by tlsInsecure. +_IMPLICIT_TLSINSECURE_OPTS = { + "tlsallowinvalidcertificates", + "tlsallowinvalidhostnames", + "tlsdisableocspendpointcheck", +} + + +def _parse_options(opts: str, delim: Optional[str]) -> _CaseInsensitiveDictionary: + """Helper method for split_options which creates the options dict. + Also handles the creation of a list for the URI tag_sets/ + readpreferencetags portion, and the use of a unicode options string. + """ + options = _CaseInsensitiveDictionary() + for uriopt in opts.split(delim): + key, value = uriopt.split("=") + if key.lower() == "readpreferencetags": + options.setdefault(key, []).append(value) + else: + if key in options: + warnings.warn(f"Duplicate URI option '{key}'.", stacklevel=2) + if key.lower() == "authmechanismproperties": + val = value + else: + val = unquote_plus(value) + options[key] = val + + return options + + +def _handle_security_options(options: _CaseInsensitiveDictionary) -> _CaseInsensitiveDictionary: + """Raise appropriate errors when conflicting TLS options are present in + the options dictionary. + + :param options: Instance of _CaseInsensitiveDictionary containing + MongoDB URI options. + """ + # Implicitly defined options must not be explicitly specified. + tlsinsecure = options.get("tlsinsecure") + if tlsinsecure is not None: + for opt in _IMPLICIT_TLSINSECURE_OPTS: + if opt in options: + err_msg = "URI options %s and %s cannot be specified simultaneously." + raise InvalidURI( + err_msg % (options.cased_key("tlsinsecure"), options.cased_key(opt)) + ) + + # Handle co-occurence of OCSP & tlsAllowInvalidCertificates options. + tlsallowinvalidcerts = options.get("tlsallowinvalidcertificates") + if tlsallowinvalidcerts is not None: + if "tlsdisableocspendpointcheck" in options: + err_msg = "URI options %s and %s cannot be specified simultaneously." + raise InvalidURI( + err_msg + % ("tlsallowinvalidcertificates", options.cased_key("tlsdisableocspendpointcheck")) + ) + if tlsallowinvalidcerts is True: + options["tlsdisableocspendpointcheck"] = True + + # Handle co-occurence of CRL and OCSP-related options. + tlscrlfile = options.get("tlscrlfile") + if tlscrlfile is not None: + for opt in ("tlsinsecure", "tlsallowinvalidcertificates", "tlsdisableocspendpointcheck"): + if options.get(opt) is True: + err_msg = "URI option %s=True cannot be specified when CRL checking is enabled." + raise InvalidURI(err_msg % (opt,)) + + if "ssl" in options and "tls" in options: + + def truth_value(val: Any) -> Any: + if val in ("true", "false"): + return val == "true" + if isinstance(val, bool): + return val + return val + + if truth_value(options.get("ssl")) != truth_value(options.get("tls")): + err_msg = "Can not specify conflicting values for URI options %s and %s." + raise InvalidURI(err_msg % (options.cased_key("ssl"), options.cased_key("tls"))) + + return options + + +def _handle_option_deprecations(options: _CaseInsensitiveDictionary) -> _CaseInsensitiveDictionary: + """Issue appropriate warnings when deprecated options are present in the + options dictionary. Removes deprecated option key, value pairs if the + options dictionary is found to also have the renamed option. + + :param options: Instance of _CaseInsensitiveDictionary containing + MongoDB URI options. + """ + for optname in list(options): + if optname in URI_OPTIONS_DEPRECATION_MAP: + mode, message = URI_OPTIONS_DEPRECATION_MAP[optname] + if mode == "renamed": + newoptname = message + if newoptname in options: + warn_msg = "Deprecated option '%s' ignored in favor of '%s'." + warnings.warn( + warn_msg % (options.cased_key(optname), options.cased_key(newoptname)), + DeprecationWarning, + stacklevel=2, + ) + options.pop(optname) + continue + warn_msg = "Option '%s' is deprecated, use '%s' instead." + warnings.warn( + warn_msg % (options.cased_key(optname), newoptname), + DeprecationWarning, + stacklevel=2, + ) + elif mode == "removed": + warn_msg = "Option '%s' is deprecated. %s." + warnings.warn( + warn_msg % (options.cased_key(optname), message), + DeprecationWarning, + stacklevel=2, + ) + + return options + + +def _normalize_options(options: _CaseInsensitiveDictionary) -> _CaseInsensitiveDictionary: + """Normalizes option names in the options dictionary by converting them to + their internally-used names. + + :param options: Instance of _CaseInsensitiveDictionary containing + MongoDB URI options. + """ + # Expand the tlsInsecure option. + tlsinsecure = options.get("tlsinsecure") + if tlsinsecure is not None: + for opt in _IMPLICIT_TLSINSECURE_OPTS: + # Implicit options are logically the same as tlsInsecure. + options[opt] = tlsinsecure + + for optname in list(options): + intname = INTERNAL_URI_OPTION_NAME_MAP.get(optname, None) + if intname is not None: + options[intname] = options.pop(optname) + + return options + + +def validate_options(opts: Mapping[str, Any], warn: bool = False) -> MutableMapping[str, Any]: + """Validates and normalizes options passed in a MongoDB URI. + + Returns a new dictionary of validated and normalized options. If warn is + False then errors will be thrown for invalid options, otherwise they will + be ignored and a warning will be issued. + + :param opts: A dict of MongoDB URI options. + :param warn: If ``True`` then warnings will be logged and + invalid options will be ignored. Otherwise invalid options will + cause errors. + """ + return get_validated_options(opts, warn) + + +def split_options( + opts: str, validate: bool = True, warn: bool = False, normalize: bool = True +) -> MutableMapping[str, Any]: + """Takes the options portion of a MongoDB URI, validates each option + and returns the options in a dictionary. + + :param opt: A string representing MongoDB URI options. + :param validate: If ``True`` (the default), validate and normalize all + options. + :param warn: If ``False`` (the default), suppress all warnings raised + during validation of options. + :param normalize: If ``True`` (the default), renames all options to their + internally-used names. + """ + and_idx = opts.find("&") + semi_idx = opts.find(";") + try: + if and_idx >= 0 and semi_idx >= 0: + raise InvalidURI("Can not mix '&' and ';' for option separators") + elif and_idx >= 0: + options = _parse_options(opts, "&") + elif semi_idx >= 0: + options = _parse_options(opts, ";") + elif opts.find("=") != -1: + options = _parse_options(opts, None) + else: + raise ValueError + except ValueError: + raise InvalidURI("MongoDB URI options are key=value pairs") from None + + options = _handle_security_options(options) + + options = _handle_option_deprecations(options) + + if normalize: + options = _normalize_options(options) + + if validate: + options = cast(_CaseInsensitiveDictionary, validate_options(options, warn)) + if options.get("authsource") == "": + raise InvalidURI("the authSource database cannot be an empty string") + + return options + + +def split_hosts(hosts: str, default_port: Optional[int] = DEFAULT_PORT) -> list[_Address]: + """Takes a string of the form host1[:port],host2[:port]... and + splits it into (host, port) tuples. If [:port] isn't present the + default_port is used. + + Returns a set of 2-tuples containing the host name (or IP) followed by + port number. + + :param hosts: A string of the form host1[:port],host2[:port],... + :param default_port: The port number to use when one wasn't specified + for a host. + """ + nodes = [] + for entity in hosts.split(","): + if not entity: + raise ConfigurationError("Empty host (or extra comma in host list)") + port = default_port + # Unix socket entities don't have ports + if entity.endswith(".sock"): + port = None + nodes.append(parse_host(entity, port)) + return nodes + + +# Prohibited characters in database name. DB names also can't have ".", but for +# backward-compat we allow "db.collection" in URI. +_BAD_DB_CHARS = re.compile("[" + re.escape(r'/ "$') + "]") + +_ALLOWED_TXT_OPTS = frozenset( + ["authsource", "authSource", "replicaset", "replicaSet", "loadbalanced", "loadBalanced"] +) + + +def _check_options(nodes: Sized, options: Mapping[str, Any]) -> None: + # Ensure directConnection was not True if there are multiple seeds. + if len(nodes) > 1 and options.get("directconnection"): + raise ConfigurationError("Cannot specify multiple hosts with directConnection=true") + + if options.get("loadbalanced"): + if len(nodes) > 1: + raise ConfigurationError("Cannot specify multiple hosts with loadBalanced=true") + if options.get("directconnection"): + raise ConfigurationError("Cannot specify directConnection=true with loadBalanced=true") + if options.get("replicaset"): + raise ConfigurationError("Cannot specify replicaSet with loadBalanced=true") + + +def _parse_kms_tls_options(kms_tls_options: Optional[Mapping[str, Any]]) -> dict[str, SSLContext]: + """Parse KMS TLS connection options.""" + if not kms_tls_options: + return {} + if not isinstance(kms_tls_options, dict): + raise TypeError("kms_tls_options must be a dict") + contexts = {} + for provider, options in kms_tls_options.items(): + if not isinstance(options, dict): + raise TypeError(f'kms_tls_options["{provider}"] must be a dict') + options.setdefault("tls", True) + opts = _CaseInsensitiveDictionary(options) + opts = _handle_security_options(opts) + opts = _normalize_options(opts) + opts = cast(_CaseInsensitiveDictionary, validate_options(opts)) + ssl_context, allow_invalid_hostnames = _parse_ssl_options(opts) + if ssl_context is None: + raise ConfigurationError("TLS is required for KMS providers") + if allow_invalid_hostnames: + raise ConfigurationError("Insecure TLS options prohibited") + + for n in [ + "tlsInsecure", + "tlsAllowInvalidCertificates", + "tlsAllowInvalidHostnames", + "tlsDisableCertificateRevocationCheck", + ]: + if n in opts: + raise ConfigurationError(f"Insecure TLS options prohibited: {n}") + contexts[provider] = ssl_context + return contexts + + +def _validate_uri( + uri: str, + default_port: Optional[int] = DEFAULT_PORT, + validate: bool = True, + warn: bool = False, + normalize: bool = True, + srv_max_hosts: Optional[int] = None, +) -> dict[str, Any]: + if uri.startswith(SCHEME): + is_srv = False + scheme_free = uri[SCHEME_LEN:] + elif uri.startswith(SRV_SCHEME): + if not _have_dnspython(): + python_path = sys.executable or "python" + raise ConfigurationError( + 'The "dnspython" module must be ' + "installed to use mongodb+srv:// URIs. " + "To fix this error install pymongo again:\n " + "%s -m pip install pymongo>=4.3" % (python_path) + ) + is_srv = True + scheme_free = uri[SRV_SCHEME_LEN:] + else: + raise InvalidURI(f"Invalid URI scheme: URI must begin with '{SCHEME}' or '{SRV_SCHEME}'") + + if not scheme_free: + raise InvalidURI("Must provide at least one hostname or IP") + + user = None + passwd = None + dbase = None + collection = None + options = _CaseInsensitiveDictionary() + + host_plus_db_part, _, opts = scheme_free.partition("?") + if "/" in host_plus_db_part: + host_part, _, dbase = host_plus_db_part.partition("/") + else: + host_part = host_plus_db_part + + if dbase: + dbase = unquote_plus(dbase) + if "." in dbase: + dbase, collection = dbase.split(".", 1) + if _BAD_DB_CHARS.search(dbase): + raise InvalidURI('Bad database name "%s"' % dbase) + else: + dbase = None + + if opts: + options.update(split_options(opts, validate, warn, normalize)) + if "@" in host_part: + userinfo, _, hosts = host_part.rpartition("@") + user, passwd = parse_userinfo(userinfo) + else: + hosts = host_part + + if "/" in hosts: + raise InvalidURI("Any '/' in a unix domain socket must be percent-encoded: %s" % host_part) + + hosts = unquote_plus(hosts) + fqdn = None + srv_max_hosts = srv_max_hosts or options.get("srvMaxHosts") + if is_srv: + if options.get("directConnection"): + raise ConfigurationError(f"Cannot specify directConnection=true with {SRV_SCHEME} URIs") + nodes = split_hosts(hosts, default_port=None) + if len(nodes) != 1: + raise InvalidURI(f"{SRV_SCHEME} URIs must include one, and only one, hostname") + fqdn, port = nodes[0] + if port is not None: + raise InvalidURI(f"{SRV_SCHEME} URIs must not include a port number") + elif not is_srv and options.get("srvServiceName") is not None: + raise ConfigurationError( + "The srvServiceName option is only allowed with 'mongodb+srv://' URIs" + ) + elif not is_srv and srv_max_hosts: + raise ConfigurationError( + "The srvMaxHosts option is only allowed with 'mongodb+srv://' URIs" + ) + else: + nodes = split_hosts(hosts, default_port=default_port) + + _check_options(nodes, options) + + return { + "nodelist": nodes, + "username": user, + "password": passwd, + "database": dbase, + "collection": collection, + "options": options, + "fqdn": fqdn, + } diff --git a/test/__init__.py b/test/__init__.py index 307780271d..8e362de5ad 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -32,7 +32,7 @@ import warnings from asyncio import iscoroutinefunction -from pymongo.uri_parser import parse_uri +from pymongo.synchronous.uri_parser import parse_uri try: import ipaddress diff --git a/test/asynchronous/__init__.py b/test/asynchronous/__init__.py index f03fcf4eeb..b3f65e5d3c 100644 --- a/test/asynchronous/__init__.py +++ b/test/asynchronous/__init__.py @@ -32,7 +32,7 @@ import warnings from asyncio import iscoroutinefunction -from pymongo.uri_parser import parse_uri +from pymongo.asynchronous.uri_parser import parse_uri try: import ipaddress @@ -1027,7 +1027,7 @@ async def _unmanaged_async_mongo_client( auth_mech = kwargs.get("authMechanism", "") if async_client_context.auth_enabled and authenticate and auth_mech != "MONGODB-OIDC": # Only add the default username or password if one is not provided. - res = parse_uri(uri) + res = await parse_uri(uri) if ( not res["username"] and not res["password"] @@ -1058,7 +1058,7 @@ async def _async_mongo_client( auth_mech = kwargs.get("authMechanism", "") if async_client_context.auth_enabled and authenticate and auth_mech != "MONGODB-OIDC": # Only add the default username or password if one is not provided. - res = parse_uri(uri) + res = await parse_uri(uri) if ( not res["username"] and not res["password"] diff --git a/test/asynchronous/helpers.py b/test/asynchronous/helpers.py index 98e00e9385..7b021e8b44 100644 --- a/test/asynchronous/helpers.py +++ b/test/asynchronous/helpers.py @@ -47,7 +47,7 @@ from pymongo import common, message from pymongo.read_preferences import ReadPreference from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined] -from pymongo.uri_parser import parse_uri +from pymongo.synchronous.uri_parser import parse_uri if HAVE_SSL: import ssl diff --git a/test/asynchronous/test_client.py b/test/asynchronous/test_client.py index f529dcce14..7f70b84825 100644 --- a/test/asynchronous/test_client.py +++ b/test/asynchronous/test_client.py @@ -512,13 +512,13 @@ async def test_uri_option_precedence(self): async def test_connection_timeout_ms_propagates_to_DNS_resolver(self): # Patch the resolver. - from pymongo.srv_resolver import _resolve + from pymongo.asynchronous.srv_resolver import _resolve patched_resolver = FunctionCallRecorder(_resolve) - pymongo.srv_resolver._resolve = patched_resolver + pymongo.asynchronous.srv_resolver._resolve = patched_resolver def reset_resolver(): - pymongo.srv_resolver._resolve = _resolve + pymongo.asynchronous.srv_resolver._resolve = _resolve self.addCleanup(reset_resolver) @@ -607,7 +607,7 @@ def test_validate_suggestion(self): with self.assertRaisesRegex(ConfigurationError, expected): AsyncMongoClient(**{typo: "standard"}) # type: ignore[arg-type] - @patch("pymongo.srv_resolver._SrvResolver.get_hosts") + @patch("pymongo.asynchronous.srv_resolver._SrvResolver.get_hosts") def test_detected_environment_logging(self, mock_get_hosts): normal_hosts = [ "normal.host.com", @@ -629,7 +629,7 @@ def test_detected_environment_logging(self, mock_get_hosts): logs = [record.getMessage() for record in cm.records if record.name == "pymongo.client"] self.assertEqual(len(logs), 7) - @patch("pymongo.srv_resolver._SrvResolver.get_hosts") + @patch("pymongo.asynchronous.srv_resolver._SrvResolver.get_hosts") async def test_detected_environment_warning(self, mock_get_hosts): with self._caplog.at_level(logging.WARN): normal_hosts = [ @@ -933,6 +933,15 @@ async def test_repr(self): async with eval(the_repr) as client_two: self.assertEqual(client_two, client) + async def test_repr_srv_host(self): + client = AsyncMongoClient("mongodb+srv://test1.test.build.10gen.cc/", connect=False) + # before srv resolution + self.assertIn("host='mongodb+srv://test1.test.build.10gen.cc'", repr(client)) + await client.aconnect() + # after srv resolution + self.assertIn("host=['localhost.test.build.10gen.cc:", repr(client)) + await client.close() + async def test_getters(self): await async_wait_until( lambda: async_client_context.nodes == self.client.nodes, "find all nodes" @@ -1911,28 +1920,37 @@ async def test_service_name_from_kwargs(self): srvServiceName="customname", connect=False, ) + await client.aconnect() self.assertEqual(client._topology_settings.srv_service_name, "customname") + await client.close() client = AsyncMongoClient( "mongodb+srv://user:password@test22.test.build.10gen.cc" "/?srvServiceName=shouldbeoverriden", srvServiceName="customname", connect=False, ) + await client.aconnect() self.assertEqual(client._topology_settings.srv_service_name, "customname") + await client.close() client = AsyncMongoClient( "mongodb+srv://user:password@test22.test.build.10gen.cc/?srvServiceName=customname", connect=False, ) + await client.aconnect() self.assertEqual(client._topology_settings.srv_service_name, "customname") + await client.close() async def test_srv_max_hosts_kwarg(self): client = self.simple_client("mongodb+srv://test1.test.build.10gen.cc/") + await client.aconnect() self.assertGreater(len(client.topology_description.server_descriptions()), 1) client = self.simple_client("mongodb+srv://test1.test.build.10gen.cc/", srvmaxhosts=1) + await client.aconnect() self.assertEqual(len(client.topology_description.server_descriptions()), 1) client = self.simple_client( "mongodb+srv://test1.test.build.10gen.cc/?srvMaxHosts=1", srvmaxhosts=2 ) + await client.aconnect() self.assertEqual(len(client.topology_description.server_descriptions()), 2) @unittest.skipIf( diff --git a/test/asynchronous/test_discovery_and_monitoring.py b/test/asynchronous/test_discovery_and_monitoring.py index b3de2c5a4d..fa62b25dd1 100644 --- a/test/asynchronous/test_discovery_and_monitoring.py +++ b/test/asynchronous/test_discovery_and_monitoring.py @@ -54,6 +54,7 @@ from pymongo import common, monitoring from pymongo.asynchronous.settings import TopologySettings from pymongo.asynchronous.topology import Topology, _ErrorContext +from pymongo.asynchronous.uri_parser import parse_uri from pymongo.errors import ( AutoReconnect, ConfigurationError, @@ -66,7 +67,6 @@ from pymongo.monitoring import ServerHeartbeatFailedEvent, ServerHeartbeatStartedEvent from pymongo.server_description import SERVER_TYPE, ServerDescription from pymongo.topology_description import TOPOLOGY_TYPE -from pymongo.uri_parser import parse_uri _IS_SYNC = False @@ -81,7 +81,7 @@ async def create_mock_topology(uri, monitor_class=DummyMonitor): - parsed_uri = parse_uri(uri) + parsed_uri = await parse_uri(uri) replica_set_name = None direct_connection = None load_balanced = None diff --git a/test/asynchronous/test_dns.py b/test/asynchronous/test_dns.py index a622062fec..d0e801e123 100644 --- a/test/asynchronous/test_dns.py +++ b/test/asynchronous/test_dns.py @@ -31,9 +31,10 @@ ) from test.utils_shared import async_wait_until +from pymongo.asynchronous.uri_parser import parse_uri from pymongo.common import validate_read_preference_tags from pymongo.errors import ConfigurationError -from pymongo.uri_parser import parse_uri, split_hosts +from pymongo.uri_parser_shared import split_hosts _IS_SYNC = False @@ -109,7 +110,7 @@ async def run_test(self): hosts = frozenset(split_hosts(",".join(hosts))) if seeds or num_seeds: - result = parse_uri(uri, validate=True) + result = await parse_uri(uri, validate=True) if seeds is not None: self.assertEqual(sorted(result["nodelist"]), sorted(seeds)) if num_seeds is not None: @@ -161,7 +162,7 @@ async def run_test(self): # and re-run these assertions. else: try: - parse_uri(uri) + await parse_uri(uri) except (ConfigurationError, ValueError): pass else: @@ -185,35 +186,24 @@ def create_tests(cls): class TestParsingErrors(AsyncPyMongoTestCase): async def test_invalid_host(self): - self.assertRaisesRegex( - ConfigurationError, - "Invalid URI host: mongodb is not", - self.simple_client, - "mongodb+srv://mongodb", - ) - self.assertRaisesRegex( - ConfigurationError, - "Invalid URI host: mongodb.com is not", - self.simple_client, - "mongodb+srv://mongodb.com", - ) - self.assertRaisesRegex( - ConfigurationError, - "Invalid URI host: an IP address is not", - self.simple_client, - "mongodb+srv://127.0.0.1", - ) - self.assertRaisesRegex( - ConfigurationError, - "Invalid URI host: an IP address is not", - self.simple_client, - "mongodb+srv://[::1]", - ) + with self.assertRaisesRegex(ConfigurationError, "Invalid URI host: mongodb is not"): + client = self.simple_client("mongodb+srv://mongodb") + await client.aconnect() + with self.assertRaisesRegex(ConfigurationError, "Invalid URI host: mongodb.com is not"): + client = self.simple_client("mongodb+srv://mongodb.com") + await client.aconnect() + with self.assertRaisesRegex(ConfigurationError, "Invalid URI host: an IP address is not"): + client = self.simple_client("mongodb+srv://127.0.0.1") + await client.aconnect() + with self.assertRaisesRegex(ConfigurationError, "Invalid URI host: an IP address is not"): + client = self.simple_client("mongodb+srv://[::1]") + await client.aconnect() class IsolatedAsyncioTestCaseInsensitive(AsyncIntegrationTest): async def test_connect_case_insensitive(self): client = self.simple_client("mongodb+srv://TEST1.TEST.BUILD.10GEN.cc/") + await client.aconnect() self.assertGreater(len(client.topology_description.server_descriptions()), 1) diff --git a/test/asynchronous/test_srv_polling.py b/test/asynchronous/test_srv_polling.py index bf7807eb97..3dcd21ef1d 100644 --- a/test/asynchronous/test_srv_polling.py +++ b/test/asynchronous/test_srv_polling.py @@ -28,8 +28,8 @@ import pymongo from pymongo import common +from pymongo.asynchronous.srv_resolver import _have_dnspython from pymongo.errors import ConfigurationError -from pymongo.srv_resolver import _have_dnspython _IS_SYNC = False @@ -54,14 +54,16 @@ def __init__( def enable(self): self.old_min_srv_rescan_interval = common.MIN_SRV_RESCAN_INTERVAL - self.old_dns_resolver_response = pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl + self.old_dns_resolver_response = ( + pymongo.asynchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl + ) if self.min_srv_rescan_interval is not None: common.MIN_SRV_RESCAN_INTERVAL = self.min_srv_rescan_interval - def mock_get_hosts_and_min_ttl(resolver, *args): + async def mock_get_hosts_and_min_ttl(resolver, *args): assert self.old_dns_resolver_response is not None - nodes, ttl = self.old_dns_resolver_response(resolver) + nodes, ttl = await self.old_dns_resolver_response(resolver) if self.nodelist_callback is not None: nodes = self.nodelist_callback() if self.ttl_time is not None: @@ -74,14 +76,14 @@ def mock_get_hosts_and_min_ttl(resolver, *args): else: patch_func = mock_get_hosts_and_min_ttl - pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl = patch_func # type: ignore + pymongo.asynchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl = patch_func # type: ignore def __enter__(self): self.enable() def disable(self): common.MIN_SRV_RESCAN_INTERVAL = self.old_min_srv_rescan_interval # type: ignore - pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl = ( # type: ignore + pymongo.asynchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl = ( # type: ignore self.old_dns_resolver_response ) @@ -134,7 +136,10 @@ async def assert_nodelist_nochange(self, expected_nodelist, client, timeout=(100 def predicate(): if set(expected_nodelist) == set(self.get_nodelist(client)): - return pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count >= 1 + return ( + pymongo.asynchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count + >= 1 + ) return False await async_wait_until(predicate, "Node list equals expected nodelist", timeout=timeout) @@ -144,7 +149,7 @@ def predicate(): msg = "Client nodelist %s changed unexpectedly (expected %s)" raise self.fail(msg % (nodelist, expected_nodelist)) self.assertGreaterEqual( - pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count, # type: ignore + pymongo.asynchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count, # type: ignore 1, "resolver was never called", ) diff --git a/test/auth_aws/test_auth_aws.py b/test/auth_aws/test_auth_aws.py index a7660f2f67..9738694d85 100644 --- a/test/auth_aws/test_auth_aws.py +++ b/test/auth_aws/test_auth_aws.py @@ -32,7 +32,7 @@ from pymongo import MongoClient from pymongo.errors import OperationFailure -from pymongo.uri_parser import parse_uri +from pymongo.synchronous.uri_parser import parse_uri pytestmark = pytest.mark.auth_aws diff --git a/test/auth_oidc/test_auth_oidc.py b/test/auth_oidc/test_auth_oidc.py index 6dc36dc8a4..0c8431a1e8 100644 --- a/test/auth_oidc/test_auth_oidc.py +++ b/test/auth_oidc/test_auth_oidc.py @@ -49,7 +49,7 @@ OIDCCallbackResult, _get_authenticator, ) -from pymongo.uri_parser import parse_uri +from pymongo.synchronous.uri_parser import parse_uri ROOT = Path(__file__).parent.parent.resolve() TEST_PATH = ROOT / "auth" / "unified" diff --git a/test/helpers.py b/test/helpers.py index 627be182b5..12c55ade1b 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -47,7 +47,7 @@ from pymongo import common, message from pymongo.read_preferences import ReadPreference from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined] -from pymongo.uri_parser import parse_uri +from pymongo.synchronous.uri_parser import parse_uri if HAVE_SSL: import ssl diff --git a/test/test_client.py b/test/test_client.py index e445fa632a..cd4ceb3299 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -505,13 +505,13 @@ def test_uri_option_precedence(self): def test_connection_timeout_ms_propagates_to_DNS_resolver(self): # Patch the resolver. - from pymongo.srv_resolver import _resolve + from pymongo.synchronous.srv_resolver import _resolve patched_resolver = FunctionCallRecorder(_resolve) - pymongo.srv_resolver._resolve = patched_resolver + pymongo.synchronous.srv_resolver._resolve = patched_resolver def reset_resolver(): - pymongo.srv_resolver._resolve = _resolve + pymongo.synchronous.srv_resolver._resolve = _resolve self.addCleanup(reset_resolver) @@ -600,7 +600,7 @@ def test_validate_suggestion(self): with self.assertRaisesRegex(ConfigurationError, expected): MongoClient(**{typo: "standard"}) # type: ignore[arg-type] - @patch("pymongo.srv_resolver._SrvResolver.get_hosts") + @patch("pymongo.synchronous.srv_resolver._SrvResolver.get_hosts") def test_detected_environment_logging(self, mock_get_hosts): normal_hosts = [ "normal.host.com", @@ -622,7 +622,7 @@ def test_detected_environment_logging(self, mock_get_hosts): logs = [record.getMessage() for record in cm.records if record.name == "pymongo.client"] self.assertEqual(len(logs), 7) - @patch("pymongo.srv_resolver._SrvResolver.get_hosts") + @patch("pymongo.synchronous.srv_resolver._SrvResolver.get_hosts") def test_detected_environment_warning(self, mock_get_hosts): with self._caplog.at_level(logging.WARN): normal_hosts = [ @@ -908,6 +908,15 @@ def test_repr(self): with eval(the_repr) as client_two: self.assertEqual(client_two, client) + def test_repr_srv_host(self): + client = MongoClient("mongodb+srv://test1.test.build.10gen.cc/", connect=False) + # before srv resolution + self.assertIn("host='mongodb+srv://test1.test.build.10gen.cc'", repr(client)) + client._connect() + # after srv resolution + self.assertIn("host=['localhost.test.build.10gen.cc:", repr(client)) + client.close() + def test_getters(self): wait_until(lambda: client_context.nodes == self.client.nodes, "find all nodes") @@ -1868,28 +1877,37 @@ def test_service_name_from_kwargs(self): srvServiceName="customname", connect=False, ) + client._connect() self.assertEqual(client._topology_settings.srv_service_name, "customname") + client.close() client = MongoClient( "mongodb+srv://user:password@test22.test.build.10gen.cc" "/?srvServiceName=shouldbeoverriden", srvServiceName="customname", connect=False, ) + client._connect() self.assertEqual(client._topology_settings.srv_service_name, "customname") + client.close() client = MongoClient( "mongodb+srv://user:password@test22.test.build.10gen.cc/?srvServiceName=customname", connect=False, ) + client._connect() self.assertEqual(client._topology_settings.srv_service_name, "customname") + client.close() def test_srv_max_hosts_kwarg(self): client = self.simple_client("mongodb+srv://test1.test.build.10gen.cc/") + client._connect() self.assertGreater(len(client.topology_description.server_descriptions()), 1) client = self.simple_client("mongodb+srv://test1.test.build.10gen.cc/", srvmaxhosts=1) + client._connect() self.assertEqual(len(client.topology_description.server_descriptions()), 1) client = self.simple_client( "mongodb+srv://test1.test.build.10gen.cc/?srvMaxHosts=1", srvmaxhosts=2 ) + client._connect() self.assertEqual(len(client.topology_description.server_descriptions()), 2) @unittest.skipIf( diff --git a/test/test_discovery_and_monitoring.py b/test/test_discovery_and_monitoring.py index 00021310c9..07720473ca 100644 --- a/test/test_discovery_and_monitoring.py +++ b/test/test_discovery_and_monitoring.py @@ -65,8 +65,8 @@ from pymongo.server_description import SERVER_TYPE, ServerDescription from pymongo.synchronous.settings import TopologySettings from pymongo.synchronous.topology import Topology, _ErrorContext +from pymongo.synchronous.uri_parser import parse_uri from pymongo.topology_description import TOPOLOGY_TYPE -from pymongo.uri_parser import parse_uri _IS_SYNC = True diff --git a/test/test_dns.py b/test/test_dns.py index 71326ae49e..0290eb16d9 100644 --- a/test/test_dns.py +++ b/test/test_dns.py @@ -33,7 +33,8 @@ from pymongo.common import validate_read_preference_tags from pymongo.errors import ConfigurationError -from pymongo.uri_parser import parse_uri, split_hosts +from pymongo.synchronous.uri_parser import parse_uri +from pymongo.uri_parser_shared import split_hosts _IS_SYNC = True @@ -183,35 +184,24 @@ def create_tests(cls): class TestParsingErrors(PyMongoTestCase): def test_invalid_host(self): - self.assertRaisesRegex( - ConfigurationError, - "Invalid URI host: mongodb is not", - self.simple_client, - "mongodb+srv://mongodb", - ) - self.assertRaisesRegex( - ConfigurationError, - "Invalid URI host: mongodb.com is not", - self.simple_client, - "mongodb+srv://mongodb.com", - ) - self.assertRaisesRegex( - ConfigurationError, - "Invalid URI host: an IP address is not", - self.simple_client, - "mongodb+srv://127.0.0.1", - ) - self.assertRaisesRegex( - ConfigurationError, - "Invalid URI host: an IP address is not", - self.simple_client, - "mongodb+srv://[::1]", - ) + with self.assertRaisesRegex(ConfigurationError, "Invalid URI host: mongodb is not"): + client = self.simple_client("mongodb+srv://mongodb") + client._connect() + with self.assertRaisesRegex(ConfigurationError, "Invalid URI host: mongodb.com is not"): + client = self.simple_client("mongodb+srv://mongodb.com") + client._connect() + with self.assertRaisesRegex(ConfigurationError, "Invalid URI host: an IP address is not"): + client = self.simple_client("mongodb+srv://127.0.0.1") + client._connect() + with self.assertRaisesRegex(ConfigurationError, "Invalid URI host: an IP address is not"): + client = self.simple_client("mongodb+srv://[::1]") + client._connect() class TestCaseInsensitive(IntegrationTest): def test_connect_case_insensitive(self): client = self.simple_client("mongodb+srv://TEST1.TEST.BUILD.10GEN.cc/") + client._connect() self.assertGreater(len(client.topology_description.server_descriptions()), 1) diff --git a/test/test_srv_polling.py b/test/test_srv_polling.py index 6812465074..df802acb43 100644 --- a/test/test_srv_polling.py +++ b/test/test_srv_polling.py @@ -29,7 +29,7 @@ import pymongo from pymongo import common from pymongo.errors import ConfigurationError -from pymongo.srv_resolver import _have_dnspython +from pymongo.synchronous.srv_resolver import _have_dnspython _IS_SYNC = True @@ -54,7 +54,9 @@ def __init__( def enable(self): self.old_min_srv_rescan_interval = common.MIN_SRV_RESCAN_INTERVAL - self.old_dns_resolver_response = pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl + self.old_dns_resolver_response = ( + pymongo.synchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl + ) if self.min_srv_rescan_interval is not None: common.MIN_SRV_RESCAN_INTERVAL = self.min_srv_rescan_interval @@ -74,14 +76,14 @@ def mock_get_hosts_and_min_ttl(resolver, *args): else: patch_func = mock_get_hosts_and_min_ttl - pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl = patch_func # type: ignore + pymongo.synchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl = patch_func # type: ignore def __enter__(self): self.enable() def disable(self): common.MIN_SRV_RESCAN_INTERVAL = self.old_min_srv_rescan_interval # type: ignore - pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl = ( # type: ignore + pymongo.synchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl = ( # type: ignore self.old_dns_resolver_response ) @@ -134,7 +136,10 @@ def assert_nodelist_nochange(self, expected_nodelist, client, timeout=(100 * WAI def predicate(): if set(expected_nodelist) == set(self.get_nodelist(client)): - return pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count >= 1 + return ( + pymongo.synchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count + >= 1 + ) return False wait_until(predicate, "Node list equals expected nodelist", timeout=timeout) @@ -144,7 +149,7 @@ def predicate(): msg = "Client nodelist %s changed unexpectedly (expected %s)" raise self.fail(msg % (nodelist, expected_nodelist)) self.assertGreaterEqual( - pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count, # type: ignore + pymongo.synchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count, # type: ignore 1, "resolver was never called", ) diff --git a/test/test_uri_parser.py b/test/test_uri_parser.py index f95717e95f..0baefa0c3a 100644 --- a/test/test_uri_parser.py +++ b/test/test_uri_parser.py @@ -28,8 +28,8 @@ from bson.binary import JAVA_LEGACY from pymongo import ReadPreference from pymongo.errors import ConfigurationError, InvalidURI -from pymongo.uri_parser import ( - parse_uri, +from pymongo.synchronous.uri_parser import parse_uri +from pymongo.uri_parser_shared import ( parse_userinfo, split_hosts, split_options, diff --git a/test/test_uri_spec.py b/test/test_uri_spec.py index 29cde7e078..aeb0be94b5 100644 --- a/test/test_uri_spec.py +++ b/test/test_uri_spec.py @@ -29,7 +29,7 @@ from pymongo.common import INTERNAL_URI_OPTION_NAME_MAP, validate from pymongo.compression_support import _have_snappy -from pymongo.uri_parser import parse_uri +from pymongo.synchronous.uri_parser import parse_uri CONN_STRING_TEST_PATH = os.path.join( os.path.dirname(os.path.realpath(__file__)), os.path.join("connection_string", "test") diff --git a/tools/synchro.py b/tools/synchro.py index e65270733e..d8760b83bc 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -127,6 +127,7 @@ "async_create_barrier": "create_barrier", "async_barrier_wait": "barrier_wait", "async_joinall": "joinall", + "pymongo.asynchronous.srv_resolver._SrvResolver.get_hosts": "pymongo.synchronous.srv_resolver._SrvResolver.get_hosts", } docstring_replacements: dict[tuple[str, str], str] = {