From 1944927eaceac9a1fd9c2a615eef0239a411f8d8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9E=97=E7=8E=AE?= Date: Sat, 16 May 2020 21:47:14 +0800 Subject: [PATCH] Fix contextvars not propagated from fixture to test --- pytest_asyncio/plugin.py | 39 +++++- ...st_async_gen_fixtures_within_context_37.py | 116 ++++++++++++++++++ tests/conftest.py | 3 + 3 files changed, 156 insertions(+), 2 deletions(-) create mode 100644 tests/async_fixtures/test_async_gen_fixtures_within_context_37.py diff --git a/pytest_asyncio/plugin.py b/pytest_asyncio/plugin.py index c0b65da2..cb8e0f91 100644 --- a/pytest_asyncio/plugin.py +++ b/pytest_asyncio/plugin.py @@ -18,6 +18,11 @@ def transfer_markers(*args, **kwargs): # noqa except ImportError: from inspect import isasyncgenfunction +try: + import contextvars +except ImportError: + contextvars = None + def _is_coroutine(obj): """Check to see if an object is really an asyncio coroutine.""" @@ -48,6 +53,14 @@ def pytest_pycollect_makeitem(collector, name, obj): return list(collector._genfunctions(name, obj)) +current_context = None + + +def apply_context(context): + for var in context: + var.set(context[var]) + + @pytest.hookimpl(hookwrapper=True) def pytest_fixture_setup(fixturedef, request): """Adjust the event loop policy when an event loop is produced.""" @@ -65,6 +78,10 @@ def pytest_fixture_setup(fixturedef, request): fixturedef.addfinalizer(lambda: policy.set_event_loop(old_loop)) return + if current_context: + # Apply the current context + apply_context(current_context) + if isasyncgenfunction(fixturedef.func): # This is an async generator function. Wrap it accordingly. generator = fixturedef.func @@ -83,7 +100,9 @@ def wrapper(*args, **kwargs): async def setup(): res = await gen_obj.__anext__() - return res + # return the current context + # that is maybe modified by async gen_obj + return res, contextvars and contextvars.copy_context() def finalizer(): """Yield again, to finalize.""" @@ -99,7 +118,15 @@ async def async_finalizer(): asyncio.get_event_loop().run_until_complete(async_finalizer()) request.addfinalizer(finalizer) - return asyncio.get_event_loop().run_until_complete(setup()) + + res, context = asyncio.get_event_loop().run_until_complete(setup()) + if context: + # Store the current context + global current_context + + current_context = context + + return res fixturedef.func = wrapper elif inspect.iscoroutinefunction(fixturedef.func): @@ -122,6 +149,11 @@ def pytest_pyfunc_call(pyfuncitem): Run asyncio marked test functions in an event loop instead of a normal function call. """ + global current_context + if current_context: + # Apply the current context + apply_context(current_context) + if 'asyncio' in pyfuncitem.keywords: if getattr(pyfuncitem.obj, 'is_hypothesis_test', False): pyfuncitem.obj.hypothesis.inner_test = wrap_in_sync( @@ -134,6 +166,8 @@ def pytest_pyfunc_call(pyfuncitem): _loop=pyfuncitem.funcargs['event_loop'] ) yield + # Cleanup the current context + current_context = None def wrap_in_sync(func, _loop): @@ -150,6 +184,7 @@ def inner(**kwargs): if 'no current event loop' not in str(exc): raise loop = _loop + task = asyncio.ensure_future(coro, loop=loop) try: loop.run_until_complete(task) diff --git a/tests/async_fixtures/test_async_gen_fixtures_within_context_37.py b/tests/async_fixtures/test_async_gen_fixtures_within_context_37.py new file mode 100644 index 00000000..8b7845a4 --- /dev/null +++ b/tests/async_fixtures/test_async_gen_fixtures_within_context_37.py @@ -0,0 +1,116 @@ +import unittest.mock + +import pytest + +START = object() +END = object() +RETVAL = object() + + +@pytest.fixture(scope="module") +def mock(): + return unittest.mock.Mock(return_value=RETVAL) + + +@pytest.fixture +def var(): + contextvars = pytest.importorskip("contextvars") + + return contextvars.ContextVar("var_1") + + +@pytest.fixture +async def async_gen_fixture_within_context(mock, var): + var.set(1) + try: + yield mock(START) + except Exception as e: + mock(e) + else: + mock(END) + + +@pytest.mark.asyncio +async def test_async_gen_fixture_within_context( + async_gen_fixture_within_context, mock, var +): + assert var.get() == 1 + assert mock.called + assert mock.call_args_list[-1] == unittest.mock.call(START) + assert async_gen_fixture_within_context is RETVAL + + +@pytest.mark.asyncio +async def test_async_gen_fixture_within_context_finalized(mock, var): + with pytest.raises(LookupError): + var.get() + + try: + assert mock.called + assert mock.call_args_list[-1] == unittest.mock.call(END) + finally: + mock.reset_mock() + + +@pytest.fixture +async def async_gen_fixture_1(var): + var.set(1) + yield + + +@pytest.fixture +async def async_gen_fixture_2(async_gen_fixture_1, var): + assert var.get() == 1 + var.set(2) + yield + + +@pytest.mark.asyncio +async def test_context_overwrited_by_another_async_gen_fixture( + async_gen_fixture_2, var +): + assert var.get() == 2 + + +@pytest.fixture +async def async_fixture_within_context(async_gen_fixture_1, var): + assert var.get() == 1 + + +@pytest.fixture +def fixture_within_context(async_gen_fixture_1, var): + assert var.get() == 1 + + +@pytest.mark.asyncio +async def test_context_propagated_from_gen_fixture_to_normal_fixture( + fixture_within_context, async_fixture_within_context +): + pass + + +@pytest.fixture +def var_2(): + contextvars = pytest.importorskip("contextvars") + + return contextvars.ContextVar("var_2") + + +@pytest.fixture +async def async_gen_fixture_set_var_1(var): + var.set(1) + yield + + +@pytest.fixture +async def async_gen_fixture_set_var_2(var_2): + var_2.set(2) + yield + + +@pytest.mark.asyncio +async def test_context_modified_by_different_fixtures( + async_gen_fixture_set_var_1, async_gen_fixture_set_var_2, var, var_2 +): + assert var.get() == 1 + assert var_2.get() == 2 diff --git a/tests/conftest.py b/tests/conftest.py index cc2ec163..f0fb78eb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,6 +8,9 @@ collect_ignore.append("async_fixtures/test_async_gen_fixtures_36.py") collect_ignore.append("async_fixtures/test_nested_36.py") +if sys.version_info[:2] < (3, 7): + collect_ignore.append("async_fixtures/test_async_gen_fixtures_within_context_37.py") + @pytest.fixture def dependent_fixture(event_loop):