diff --git a/pymongo/_csot.py b/pymongo/_csot.py index 06c6b68ac9..c5681e345a 100644 --- a/pymongo/_csot.py +++ b/pymongo/_csot.py @@ -32,6 +32,12 @@ DEADLINE: ContextVar[float] = ContextVar("DEADLINE", default=float("inf")) +def reset_all() -> None: + TIMEOUT.set(None) + RTT.set(0.0) + DEADLINE.set(float("inf")) + + def get_timeout() -> Optional[float]: return TIMEOUT.get(None) diff --git a/pymongo/periodic_executor.py b/pymongo/periodic_executor.py index 323debdce2..ed369a2b21 100644 --- a/pymongo/periodic_executor.py +++ b/pymongo/periodic_executor.py @@ -23,6 +23,7 @@ import weakref from typing import Any, Optional +from pymongo import _csot from pymongo._asyncio_task import create_task from pymongo.lock import _create_lock @@ -93,6 +94,8 @@ def skip_sleep(self) -> None: self._skip_sleep = True async def _run(self) -> None: + # The CSOT contextvars must be cleared inside the executor task before execution begins + _csot.reset_all() while not self._stopped: if self._task and self._task.cancelling(): # type: ignore[unused-ignore, attr-defined] raise asyncio.CancelledError diff --git a/test/asynchronous/test_async_contextvars_reset.py b/test/asynchronous/test_async_contextvars_reset.py new file mode 100644 index 0000000000..9b0e2dc4dc --- /dev/null +++ b/test/asynchronous/test_async_contextvars_reset.py @@ -0,0 +1,43 @@ +# Copyright 2025-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test that AsyncPeriodicExecutors do not copy ContextVars from their parents.""" +from __future__ import annotations + +import asyncio +import sys +from test.asynchronous.utils import async_get_pool +from test.utils_shared import delay, one + +sys.path[0:0] = [""] + +from test.asynchronous import AsyncIntegrationTest + + +class TestAsyncContextVarsReset(AsyncIntegrationTest): + async def test_context_vars_are_reset_in_executor(self): + if sys.version_info < (3, 11): + self.skipTest("Test requires asyncio.Task.get_context (added in Python 3.11)") + + client = self.simple_client() + + await client.db.test.insert_one({"x": 1}) + for server in client._topology._servers.values(): + for context in [ + c + for c in server._monitor._executor._task.get_context() + if c.name in ["TIMEOUT", "RTT", "DEADLINE"] + ]: + self.assertIn(context.get(), [None, float("inf"), 0.0]) + await client.db.test.delete_many({}) diff --git a/tools/synchro.py b/tools/synchro.py index 1fa8c674a5..906bfd00da 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -185,6 +185,7 @@ def async_only_test(f: str) -> bool: "test_concurrency.py", "test_async_cancellation.py", "test_async_loop_safety.py", + "test_async_contextvars_reset.py", ]