Skip to content

Commit b2744c1

Browse files
lisroachfried
authored andcommitted
[3.8] bpo-38857: AsyncMock fix for awaitable values and StopIteration fix [3.8] (GH-17269) (#17304)
(cherry picked from commit 046442d) Co-authored-by: Jason Fried <fried@fb.com>
1 parent 9458c5c commit b2744c1

File tree

5 files changed

+103
-42
lines changed

5 files changed

+103
-42
lines changed

Doc/library/unittest.mock.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -873,7 +873,7 @@ object::
873873
exception,
874874
- if ``side_effect`` is an iterable, the async function will return the
875875
next value of the iterable, however, if the sequence of result is
876-
exhausted, ``StopIteration`` is raised immediately,
876+
exhausted, ``StopAsyncIteration`` is raised immediately,
877877
- if ``side_effect`` is not defined, the async function will return the
878878
value defined by ``return_value``, hence, by default, the async function
879879
returns a new :class:`AsyncMock` object.

Lib/unittest/mock.py

Lines changed: 37 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1125,8 +1125,8 @@ def _increment_mock_call(self, /, *args, **kwargs):
11251125
_new_parent = _new_parent._mock_new_parent
11261126

11271127
def _execute_mock_call(self, /, *args, **kwargs):
1128-
# seperate from _increment_mock_call so that awaited functions are
1129-
# executed seperately from their call
1128+
# separate from _increment_mock_call so that awaited functions are
1129+
# executed separately from their call, also AsyncMock overrides this method
11301130

11311131
effect = self.side_effect
11321132
if effect is not None:
@@ -2122,29 +2122,45 @@ def __init__(self, /, *args, **kwargs):
21222122
code_mock.co_flags = inspect.CO_COROUTINE
21232123
self.__dict__['__code__'] = code_mock
21242124

2125-
async def _mock_call(self, /, *args, **kwargs):
2126-
try:
2127-
result = super()._mock_call(*args, **kwargs)
2128-
except (BaseException, StopIteration) as e:
2129-
side_effect = self.side_effect
2130-
if side_effect is not None and not callable(side_effect):
2131-
raise
2132-
return await _raise(e)
2125+
async def _execute_mock_call(self, /, *args, **kwargs):
2126+
# This is nearly just like super(), except for sepcial handling
2127+
# of coroutines
21332128

21342129
_call = self.call_args
2130+
self.await_count += 1
2131+
self.await_args = _call
2132+
self.await_args_list.append(_call)
21352133

2136-
async def proxy():
2137-
try:
2138-
if inspect.isawaitable(result):
2139-
return await result
2140-
else:
2141-
return result
2142-
finally:
2143-
self.await_count += 1
2144-
self.await_args = _call
2145-
self.await_args_list.append(_call)
2134+
effect = self.side_effect
2135+
if effect is not None:
2136+
if _is_exception(effect):
2137+
raise effect
2138+
elif not _callable(effect):
2139+
try:
2140+
result = next(effect)
2141+
except StopIteration:
2142+
# It is impossible to propogate a StopIteration
2143+
# through coroutines because of PEP 479
2144+
raise StopAsyncIteration
2145+
if _is_exception(result):
2146+
raise result
2147+
elif asyncio.iscoroutinefunction(effect):
2148+
result = await effect(*args, **kwargs)
2149+
else:
2150+
result = effect(*args, **kwargs)
21462151

2147-
return await proxy()
2152+
if result is not DEFAULT:
2153+
return result
2154+
2155+
if self._mock_return_value is not DEFAULT:
2156+
return self.return_value
2157+
2158+
if self._mock_wraps is not None:
2159+
if asyncio.iscoroutinefunction(self._mock_wraps):
2160+
return await self._mock_wraps(*args, **kwargs)
2161+
return self._mock_wraps(*args, **kwargs)
2162+
2163+
return self.return_value
21482164

21492165
def assert_awaited(self):
21502166
"""
@@ -2852,10 +2868,6 @@ def seal(mock):
28522868
seal(m)
28532869

28542870

2855-
async def _raise(exception):
2856-
raise exception
2857-
2858-
28592871
class _AsyncIterator:
28602872
"""
28612873
Wraps an iterator in an asynchronous iterator.

Lib/unittest/test/testmock/testasync.py

Lines changed: 58 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -353,42 +353,84 @@ def test_magicmock_lambda_spec(self):
353353
self.assertIsInstance(cm, MagicMock)
354354

355355

356-
class AsyncArguments(unittest.TestCase):
357-
def test_add_return_value(self):
356+
class AsyncArguments(unittest.IsolatedAsyncioTestCase):
357+
async def test_add_return_value(self):
358358
async def addition(self, var):
359359
return var + 1
360360

361361
mock = AsyncMock(addition, return_value=10)
362-
output = asyncio.run(mock(5))
362+
output = await mock(5)
363363

364364
self.assertEqual(output, 10)
365365

366-
def test_add_side_effect_exception(self):
366+
async def test_add_side_effect_exception(self):
367367
async def addition(var):
368368
return var + 1
369369
mock = AsyncMock(addition, side_effect=Exception('err'))
370370
with self.assertRaises(Exception):
371-
asyncio.run(mock(5))
371+
await mock(5)
372372

373-
def test_add_side_effect_function(self):
373+
async def test_add_side_effect_function(self):
374374
async def addition(var):
375375
return var + 1
376376
mock = AsyncMock(side_effect=addition)
377-
result = asyncio.run(mock(5))
377+
result = await mock(5)
378378
self.assertEqual(result, 6)
379379

380-
def test_add_side_effect_iterable(self):
380+
async def test_add_side_effect_iterable(self):
381381
vals = [1, 2, 3]
382382
mock = AsyncMock(side_effect=vals)
383383
for item in vals:
384-
self.assertEqual(item, asyncio.run(mock()))
385-
386-
with self.assertRaises(RuntimeError) as e:
387-
asyncio.run(mock())
388-
self.assertEqual(
389-
e.exception,
390-
RuntimeError('coroutine raised StopIteration')
391-
)
384+
self.assertEqual(item, await mock())
385+
386+
with self.assertRaises(StopAsyncIteration) as e:
387+
await mock()
388+
389+
async def test_return_value_AsyncMock(self):
390+
value = AsyncMock(return_value=10)
391+
mock = AsyncMock(return_value=value)
392+
result = await mock()
393+
self.assertIs(result, value)
394+
395+
async def test_return_value_awaitable(self):
396+
fut = asyncio.Future()
397+
fut.set_result(None)
398+
mock = AsyncMock(return_value=fut)
399+
result = await mock()
400+
self.assertIsInstance(result, asyncio.Future)
401+
402+
async def test_side_effect_awaitable_values(self):
403+
fut = asyncio.Future()
404+
fut.set_result(None)
405+
406+
mock = AsyncMock(side_effect=[fut])
407+
result = await mock()
408+
self.assertIsInstance(result, asyncio.Future)
409+
410+
with self.assertRaises(StopAsyncIteration):
411+
await mock()
412+
413+
async def test_side_effect_is_AsyncMock(self):
414+
effect = AsyncMock(return_value=10)
415+
mock = AsyncMock(side_effect=effect)
416+
417+
result = await mock()
418+
self.assertEqual(result, 10)
419+
420+
async def test_wraps_coroutine(self):
421+
value = asyncio.Future()
422+
423+
ran = False
424+
async def inner():
425+
nonlocal ran
426+
ran = True
427+
return value
428+
429+
mock = AsyncMock(wraps=inner)
430+
result = await mock()
431+
self.assertEqual(result, value)
432+
mock.assert_awaited()
433+
self.assertTrue(ran)
392434

393435
class AsyncMagicMethods(unittest.TestCase):
394436
def test_async_magic_methods_return_async_mocks(self):
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
AsyncMock fix for return values that are awaitable types. This also covers
2+
side_effect iterable values that happend to be awaitable, and wraps
3+
callables that return an awaitable type. Before these awaitables were being
4+
awaited instead of being returned as is.
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
AsyncMock now returns StopAsyncIteration on the exaustion of a side_effects
2+
iterable. Since PEP-479 its Impossible to raise a StopIteration exception
3+
from a coroutine.

0 commit comments

Comments
 (0)