Skip to content

Commit ab33d4c

Browse files
committed
polish: add tests for assert_equal_awaitables_or_values
Replicates graphql/graphql-js@a842678
1 parent b47c922 commit ab33d4c

File tree

6 files changed

+114
-34
lines changed

6 files changed

+114
-34
lines changed

tests/execution/test_subscribe.py

Lines changed: 3 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,5 @@
11
import asyncio
2-
from typing import (
3-
Any,
4-
AsyncIterable,
5-
Awaitable,
6-
Callable,
7-
Dict,
8-
List,
9-
Optional,
10-
TypeVar,
11-
Union,
12-
cast,
13-
)
2+
from typing import Any, AsyncIterable, Callable, Dict, List, Optional, TypeVar, Union
143

154
from pytest import mark, raises
165

@@ -33,6 +22,8 @@
3322
GraphQLString,
3423
)
3524

25+
from ..utils.assert_equal_awaitables_or_values import assert_equal_awaitables_or_values
26+
3627

3728
try:
3829
from typing import TypedDict
@@ -150,27 +141,6 @@ def transform(new_email):
150141
DummyQueryType = GraphQLObjectType("Query", {"dummy": GraphQLField(GraphQLString)})
151142

152143

153-
def assert_equal_awaitables_or_values(
154-
value1: AwaitableOrValue[T], value2: AwaitableOrValue[T]
155-
) -> AwaitableOrValue[T]:
156-
if is_awaitable(value1):
157-
awaitable1 = cast(Awaitable[T], value1)
158-
assert is_awaitable(value2)
159-
awaitable2 = cast(Awaitable[T], value2)
160-
161-
# noinspection PyShadowingNames
162-
async def awaited_equal_value():
163-
value1 = await awaitable1
164-
value2 = await awaitable2
165-
assert value1 == value2
166-
return value1
167-
168-
return awaited_equal_value()
169-
assert not is_awaitable(value2)
170-
assert value1 == value2
171-
return value1
172-
173-
174144
def subscribe_with_bad_fn(
175145
subscribe_fn: Callable,
176146
) -> AwaitableOrValue[Union[ExecutionResult, AsyncIterable[Any]]]:

tests/utils/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,14 @@
11
"""Test utilities"""
22

3+
from .assert_equal_awaitables_or_values import assert_equal_awaitables_or_values
4+
from .assert_matching_values import assert_matching_values
35
from .dedent import dedent
46
from .gen_fuzz_strings import gen_fuzz_strings
57

68

7-
__all__ = ["dedent", "gen_fuzz_strings"]
9+
__all__ = [
10+
"assert_matching_values",
11+
"assert_equal_awaitables_or_values",
12+
"dedent",
13+
"gen_fuzz_strings",
14+
]
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import asyncio
2+
from typing import Awaitable, Tuple, TypeVar, cast
3+
4+
from graphql.pyutils import is_awaitable
5+
6+
from .assert_matching_values import assert_matching_values
7+
8+
9+
__all__ = ["assert_equal_awaitables_or_values"]
10+
11+
T = TypeVar("T")
12+
13+
14+
def assert_equal_awaitables_or_values(*items: T) -> T:
15+
"""Check whether the items are the same and either all awaitables or all values."""
16+
if all(is_awaitable(item) for item in items):
17+
awaitable_items = cast(Tuple[Awaitable], items)
18+
19+
async def assert_matching_awaitables():
20+
return assert_matching_values(*(await asyncio.gather(*awaitable_items)))
21+
22+
return assert_matching_awaitables()
23+
24+
if all(not is_awaitable(item) for item in items):
25+
return assert_matching_values(*items)
26+
27+
assert False, "Received an invalid mixture of promises and values."

tests/utils/assert_matching_values.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from typing import TypeVar
2+
3+
4+
__all__ = ["assert_matching_values"]
5+
6+
T = TypeVar("T")
7+
8+
9+
def assert_matching_values(*values: T) -> T:
10+
"""Test that all values in the sequence are equal."""
11+
first_value, *remaining_values = values
12+
for value in remaining_values:
13+
assert value == first_value
14+
return first_value
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
from pytest import mark, raises
2+
3+
from . import assert_equal_awaitables_or_values
4+
5+
6+
def describe_assert_equal_awaitables_or_values():
7+
def throws_when_given_unequal_values():
8+
with raises(AssertionError):
9+
assert_equal_awaitables_or_values({}, {}, {"test": "test"})
10+
11+
def does_not_throw_when_given_equal_values():
12+
test_value = {"test": "test"}
13+
assert (
14+
assert_equal_awaitables_or_values(test_value, test_value, test_value)
15+
== test_value
16+
)
17+
18+
@mark.asyncio
19+
async def does_not_throw_when_given_equal_awaitables():
20+
async def test_value():
21+
return {"test": "test"}
22+
23+
assert (
24+
await assert_equal_awaitables_or_values(
25+
test_value(), test_value(), test_value()
26+
)
27+
== await test_value()
28+
)
29+
30+
@mark.asyncio
31+
async def throws_when_given_unequal_awaitables():
32+
async def test_value(value):
33+
return value
34+
35+
with raises(AssertionError):
36+
await assert_equal_awaitables_or_values(
37+
test_value({}), test_value({}), test_value({"test": "test"})
38+
)
39+
40+
@mark.asyncio
41+
async def throws_when_given_mixture_of_equal_values_and_awaitables():
42+
async def test_value():
43+
return {"test": "test"}
44+
45+
with raises(
46+
AssertionError,
47+
match=r"Received an invalid mixture of promises and values\.",
48+
):
49+
await assert_equal_awaitables_or_values(await test_value(), test_value())
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from pytest import raises
2+
3+
from . import assert_matching_values
4+
5+
6+
def describe_assert_matching_values():
7+
def throws_when_given_unequal_values():
8+
with raises(AssertionError):
9+
assert_matching_values({}, {}, {"test": "test"})
10+
11+
def does_not_throw_when_given_equal_values():
12+
test_value = {"test": "test"}
13+
assert assert_matching_values(test_value, test_value, test_value) == test_value

0 commit comments

Comments
 (0)