diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index 72755263c9..6390bcdd67 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -70,6 +70,7 @@ from pymongo.asynchronous.settings import TopologySettings from pymongo.asynchronous.topology import Topology, _ErrorContext from pymongo.client_options import ClientOptions +from pymongo.driver_info import DriverInfo from pymongo.errors import ( AutoReconnect, BulkWriteError, @@ -1040,6 +1041,23 @@ async def target() -> bool: self._kill_cursors_executor = executor self._opened = False + def append_metadata(self, driver_info: DriverInfo) -> None: + """ + Appends the given metadata to existing driver metadata. + """ + metadata = self._options.pool_options.metadata + for k, v in driver_info._asdict().items(): + if v is None: + continue + if k in metadata: + metadata[k] = f"{metadata[k]}|{v}" + elif k in metadata["driver"]: + metadata["driver"][k] = "{}|{}".format( + metadata["driver"][k], + v, + ) + self._options.pool_options._set_metadata(metadata) + def _should_pin_cursor(self, session: Optional[AsyncClientSession]) -> Optional[bool]: return self._options.load_balanced and not (session and session.in_transaction) diff --git a/pymongo/pool_options.py b/pymongo/pool_options.py index a2e309cc56..33cd97978f 100644 --- a/pymongo/pool_options.py +++ b/pymongo/pool_options.py @@ -522,3 +522,6 @@ def server_api(self) -> Optional[ServerApi]: def load_balanced(self) -> Optional[bool]: """True if this Pool is configured in load balanced mode.""" return self.__load_balanced + + def _set_metadata(self, new_data: dict[str, Any]) -> None: + self.__metadata = new_data diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 99a517e5c1..f36ae491d6 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -62,6 +62,7 @@ from bson.timestamp import Timestamp from pymongo import _csot, common, helpers_shared, periodic_executor from pymongo.client_options import ClientOptions +from pymongo.driver_info import DriverInfo from pymongo.errors import ( AutoReconnect, BulkWriteError, @@ -1040,6 +1041,23 @@ def target() -> bool: self._kill_cursors_executor = executor self._opened = False + def append_metadata(self, driver_info: DriverInfo) -> None: + """ + Appends the given metadata to existing driver metadata. + """ + metadata = self._options.pool_options.metadata + for k, v in driver_info._asdict().items(): + if v is None: + continue + if k in metadata: + metadata[k] = f"{metadata[k]}|{v}" + elif k in metadata["driver"]: + metadata["driver"][k] = "{}|{}".format( + metadata["driver"][k], + v, + ) + self._options.pool_options._set_metadata(metadata) + def _should_pin_cursor(self, session: Optional[ClientSession]) -> Optional[bool]: return self._options.load_balanced and not (session and session.in_transaction) diff --git a/test/mockupdb/test_client_metadata.py b/test/mockupdb/test_client_metadata.py new file mode 100644 index 0000000000..27eb0fdeea --- /dev/null +++ b/test/mockupdb/test_client_metadata.py @@ -0,0 +1,210 @@ +# Copyright 2013-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import time +import unittest +from test.utils_shared import CMAPListener +from typing import Any, Optional + +import pytest + +from pymongo import MongoClient +from pymongo.driver_info import DriverInfo +from pymongo.monitoring import ConnectionClosedEvent + +try: + from mockupdb import MockupDB, OpMsgReply + + _HAVE_MOCKUPDB = True +except ImportError: + _HAVE_MOCKUPDB = False + +pytestmark = pytest.mark.mockupdb + + +def _get_handshake_driver_info(request): + assert "client" in request + return request["client"] + + +class TestClientMetadataProse(unittest.TestCase): + def setUp(self): + self.server = MockupDB() + # there are two handshake requests, i believe one is from the monitor, and the other is from the client + self.handshake_req: Optional[dict] = None + + def respond(r): + if "ismaster" in r: + # then this is a handshake request + self.handshake_req = r + return r.reply(OpMsgReply(minWireVersion=0, maxWireVersion=13)) + + self.server.autoresponds(respond) + self.server.run() + self.addCleanup(self.server.stop) + + def send_ping_and_get_metadata( + self, client: MongoClient, is_handshake: bool + ) -> tuple[str, Optional[str], Optional[str], dict[str, Any]]: + # reset + if is_handshake: + self.handshake_req: Optional[dict] = None + client.admin.command("ping") + metadata = _get_handshake_driver_info(self.handshake_req) + driver_metadata = metadata["driver"] + name, version, platform = ( + driver_metadata["name"], + driver_metadata["version"], + metadata["platform"], + ) + return name, version, platform, metadata + + def check_metadata_added( + self, + client: MongoClient, + add_name: str, + add_version: Optional[str], + add_platform: Optional[str], + ) -> None: + # send initial metadata + name, version, platform, metadata = self.send_ping_and_get_metadata(client, True) + time.sleep(0.005) + + # add new metadata + client.append_metadata(DriverInfo(add_name, add_version, add_platform)) + new_name, new_version, new_platform, new_metadata = self.send_ping_and_get_metadata( + client, True + ) + print("IN SEND PING AND GET METADATA") + print(name, version, platform) + print(metadata) + print(new_name, new_version, new_platform) + print(new_metadata) + self.assertEqual(new_name, f"{name}|{add_name}" if add_name is not None else name) + self.assertEqual( + new_version, + f"{version}|{add_version}" if add_version is not None else version, + ) + self.assertEqual( + new_platform, + f"{platform}|{add_platform}" if add_platform is not None else platform, + ) + + metadata.pop("driver") + metadata.pop("platform") + new_metadata.pop("driver") + new_metadata.pop("platform") + self.assertEqual(metadata, new_metadata) + + def test_append_metadata(self): + client = MongoClient( + "mongodb://" + self.server.address_string, + maxIdleTimeMS=1, + driver=DriverInfo("library", "1.2", "Library Platform"), + ) + self.check_metadata_added(client, "framework", "2.0", "Framework Platform") + client.close() + + def test_append_metadata_platform_none(self): + client = MongoClient( + "mongodb://" + self.server.address_string, + maxIdleTimeMS=1, + driver=DriverInfo("library", "1.2", "Library Platform"), + ) + self.check_metadata_added(client, "framework", "2.0", None) + client.close() + + def test_append_metadata_version_none(self): + client = MongoClient( + "mongodb://" + self.server.address_string, + maxIdleTimeMS=1, + driver=DriverInfo("library", "1.2", "Library Platform"), + ) + self.check_metadata_added(client, "framework", None, "Framework Platform") + client.close() + + def test_append_metadata_platform_version_none(self): + client = MongoClient( + "mongodb://" + self.server.address_string, + maxIdleTimeMS=1, + driver=DriverInfo("library", "1.2", "Library Platform"), + ) + self.check_metadata_added(client, "framework", None, None) + client.close() + + def test_multiple_successive_metadata_updates(self): + client = MongoClient( + "mongodb://" + self.server.address_string, maxIdleTimeMS=1, connect=False + ) + client.append_metadata(DriverInfo("library", "1.2", "Library Platform")) + self.check_metadata_added(client, "framework", "2.0", "Framework Platform") + client.close() + + def test_multiple_successive_metadata_updates_platform_none(self): + client = MongoClient( + "mongodb://" + self.server.address_string, + maxIdleTimeMS=1, + ) + client.append_metadata(DriverInfo("library", "1.2", "Library Platform")) + self.check_metadata_added(client, "framework", "2.0", None) + client.close() + + def test_multiple_successive_metadata_updates_version_none(self): + client = MongoClient( + "mongodb://" + self.server.address_string, + maxIdleTimeMS=1, + ) + client.append_metadata(DriverInfo("library", "1.2", "Library Platform")) + self.check_metadata_added(client, "framework", None, "Framework Platform") + client.close() + + def test_multiple_successive_metadata_updates_platform_version_none(self): + client = MongoClient( + "mongodb://" + self.server.address_string, + maxIdleTimeMS=1, + ) + client.append_metadata(DriverInfo("library", "1.2", "Library Platform")) + self.check_metadata_added(client, "framework", None, None) + client.close() + + def test_doesnt_update_established_connections(self): + listener = CMAPListener() + client = MongoClient( + "mongodb://" + self.server.address_string, + maxIdleTimeMS=1, + driver=DriverInfo("library", "1.2", "Library Platform"), + event_listeners=[listener], + ) + + # send initial metadata + name, version, platform, metadata = self.send_ping_and_get_metadata(client, True) + self.assertIsNotNone(name) + self.assertIsNotNone(version) + self.assertIsNotNone(platform) + + # add data + add_name, add_version, add_platform = "framework", "2.0", "Framework Platform" + client.append_metadata(DriverInfo(add_name, add_version, add_platform)) + # check new data isn't sent + self.handshake_req: Optional[dict] = None + client.admin.command("ping") + self.assertIsNone(self.handshake_req) + self.assertEqual(listener.event_count(ConnectionClosedEvent), 0) + + client.close() + + +if __name__ == "__main__": + unittest.main()