diff --git a/pytest_asyncio/plugin.py b/pytest_asyncio/plugin.py index 2fdc5f4e..0f9fc3fa 100644 --- a/pytest_asyncio/plugin.py +++ b/pytest_asyncio/plugin.py @@ -18,6 +18,11 @@ def transfer_markers(*args, **kwargs): # noqa except ImportError: from inspect import isasyncgenfunction +try: + import contextvars +except ImportError: + contextvars = None + def _is_coroutine(obj): """Check to see if an object is really an asyncio coroutine.""" @@ -48,6 +53,12 @@ def pytest_pycollect_makeitem(collector, name, obj): return list(collector._genfunctions(name, obj)) +current_context = None + + +def apply_context(context): + for var in context: + var.set(context[var]) class FixtureStripper: """Include additional Fixture, and then strip them""" REQUEST = "request" @@ -91,6 +102,10 @@ def pytest_fixture_setup(fixturedef, request): policy.set_event_loop(loop) return + if current_context: + # Apply the current context + apply_context(current_context) + if isasyncgenfunction(fixturedef.func): # This is an async generator function. Wrap it accordingly. generator = fixturedef.func @@ -108,7 +123,9 @@ def wrapper(*args, **kwargs): async def setup(): res = await gen_obj.__anext__() - return res + # return the current context + # that is maybe modified by async gen_obj + return res, contextvars and contextvars.copy_context() def finalizer(): """Yield again, to finalize.""" @@ -124,7 +141,15 @@ async def async_finalizer(): loop.run_until_complete(async_finalizer()) request.addfinalizer(finalizer) - return loop.run_until_complete(setup()) + + res, context = asyncio.get_event_loop().run_until_complete(setup()) + if context: + # Store the current context + global current_context + + current_context = context + + return res fixturedef.func = wrapper elif inspect.iscoroutinefunction(fixturedef.func): @@ -152,6 +177,11 @@ def pytest_pyfunc_call(pyfuncitem): Run asyncio marked test functions in an event loop instead of a normal function call. """ + global current_context + if current_context: + # Apply the current context + apply_context(current_context) + if 'asyncio' in pyfuncitem.keywords: if getattr(pyfuncitem.obj, 'is_hypothesis_test', False): pyfuncitem.obj.hypothesis.inner_test = wrap_in_sync( @@ -164,6 +194,8 @@ def pytest_pyfunc_call(pyfuncitem): _loop=pyfuncitem.funcargs['event_loop'] ) yield + # Cleanup the current context + current_context = None def wrap_in_sync(func, _loop): diff --git a/tests/async_fixtures/test_async_gen_fixtures_within_context_37.py b/tests/async_fixtures/test_async_gen_fixtures_within_context_37.py new file mode 100644 index 00000000..8b7845a4 --- /dev/null +++ b/tests/async_fixtures/test_async_gen_fixtures_within_context_37.py @@ -0,0 +1,116 @@ +import unittest.mock + +import pytest + +START = object() +END = object() +RETVAL = object() + + +@pytest.fixture(scope="module") +def mock(): + return unittest.mock.Mock(return_value=RETVAL) + + +@pytest.fixture +def var(): + contextvars = pytest.importorskip("contextvars") + + return contextvars.ContextVar("var_1") + + +@pytest.fixture +async def async_gen_fixture_within_context(mock, var): + var.set(1) + try: + yield mock(START) + except Exception as e: + mock(e) + else: + mock(END) + + +@pytest.mark.asyncio +async def test_async_gen_fixture_within_context( + async_gen_fixture_within_context, mock, var +): + assert var.get() == 1 + assert mock.called + assert mock.call_args_list[-1] == unittest.mock.call(START) + assert async_gen_fixture_within_context is RETVAL + + +@pytest.mark.asyncio +async def test_async_gen_fixture_within_context_finalized(mock, var): + with pytest.raises(LookupError): + var.get() + + try: + assert mock.called + assert mock.call_args_list[-1] == unittest.mock.call(END) + finally: + mock.reset_mock() + + +@pytest.fixture +async def async_gen_fixture_1(var): + var.set(1) + yield + + +@pytest.fixture +async def async_gen_fixture_2(async_gen_fixture_1, var): + assert var.get() == 1 + var.set(2) + yield + + +@pytest.mark.asyncio +async def test_context_overwrited_by_another_async_gen_fixture( + async_gen_fixture_2, var +): + assert var.get() == 2 + + +@pytest.fixture +async def async_fixture_within_context(async_gen_fixture_1, var): + assert var.get() == 1 + + +@pytest.fixture +def fixture_within_context(async_gen_fixture_1, var): + assert var.get() == 1 + + +@pytest.mark.asyncio +async def test_context_propagated_from_gen_fixture_to_normal_fixture( + fixture_within_context, async_fixture_within_context +): + pass + + +@pytest.fixture +def var_2(): + contextvars = pytest.importorskip("contextvars") + + return contextvars.ContextVar("var_2") + + +@pytest.fixture +async def async_gen_fixture_set_var_1(var): + var.set(1) + yield + + +@pytest.fixture +async def async_gen_fixture_set_var_2(var_2): + var_2.set(2) + yield + + +@pytest.mark.asyncio +async def test_context_modified_by_different_fixtures( + async_gen_fixture_set_var_1, async_gen_fixture_set_var_2, var, var_2 +): + assert var.get() == 1 + assert var_2.get() == 2 diff --git a/tests/conftest.py b/tests/conftest.py index cc2ec163..f0fb78eb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,6 +8,9 @@ collect_ignore.append("async_fixtures/test_async_gen_fixtures_36.py") collect_ignore.append("async_fixtures/test_nested_36.py") +if sys.version_info[:2] < (3, 7): + collect_ignore.append("async_fixtures/test_async_gen_fixtures_within_context_37.py") + @pytest.fixture def dependent_fixture(event_loop):