Skip to content

Commit a96eacd

Browse files
committed
Resolve types in parallel
1 parent 39b20ea commit a96eacd

File tree

4 files changed

+199
-112
lines changed

4 files changed

+199
-112
lines changed

graphql/execution/execute.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1119,25 +1119,26 @@ def default_resolve_type_fn(
11191119

11201120
# Otherwise, test each possible type.
11211121
possible_types = info.schema.get_possible_types(abstract_type)
1122-
is_type_of_results_async = []
1122+
awaitable_is_type_of_results: List[Awaitable] = []
1123+
append_awaitable_results = awaitable_is_type_of_results.append
1124+
awaitable_types: List[GraphQLObjectType] = []
1125+
append_awaitable_types = awaitable_types.append
11231126

11241127
for type_ in possible_types:
11251128
if type_.is_type_of:
11261129
is_type_of_result = type_.is_type_of(value, info)
11271130

11281131
if isawaitable(is_type_of_result):
1129-
is_type_of_results_async.append((is_type_of_result, type_))
1132+
append_awaitable_results(cast(Awaitable, is_type_of_result))
1133+
append_awaitable_types(type_)
11301134
elif is_type_of_result:
11311135
return type_
11321136

1133-
if is_type_of_results_async:
1137+
if awaitable_is_type_of_results:
11341138
# noinspection PyShadowingNames
11351139
async def get_type():
1136-
is_type_of_results = [
1137-
(await is_type_of_result, type_)
1138-
for is_type_of_result, type_ in is_type_of_results_async
1139-
]
1140-
for is_type_of_result, type_ in is_type_of_results:
1140+
is_type_of_results = await gather(*awaitable_is_type_of_results)
1141+
for is_type_of_result, type_ in zip(is_type_of_results, awaitable_types):
11411142
if is_type_of_result:
11421143
return type_
11431144

tests/execution/test_customize.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
from graphql.execution import execute, ExecutionContext
2+
from graphql.language import parse
3+
from graphql.type import GraphQLSchema, GraphQLObjectType, GraphQLString, GraphQLField
4+
5+
6+
def describe_customize_execution():
7+
def uses_a_custom_field_resolver():
8+
query = parse("{ foo }")
9+
10+
schema = GraphQLSchema(
11+
GraphQLObjectType("Query", {"foo": GraphQLField(GraphQLString)})
12+
)
13+
14+
# For the purposes of test, just return the name of the field!
15+
def custom_resolver(_source, info, **_args):
16+
return info.field_name
17+
18+
assert execute(schema, query, field_resolver=custom_resolver) == (
19+
{"foo": "foo"},
20+
None,
21+
)
22+
23+
def uses_a_custom_execution_context_class():
24+
query = parse("{ foo }")
25+
26+
schema = GraphQLSchema(
27+
GraphQLObjectType(
28+
"Query",
29+
{"foo": GraphQLField(GraphQLString, resolve=lambda *_args: "bar")},
30+
)
31+
)
32+
33+
class TestExecutionContext(ExecutionContext):
34+
def resolve_field(self, parent_type, source, field_nodes, path):
35+
result = super().resolve_field(parent_type, source, field_nodes, path)
36+
return result * 2
37+
38+
assert execute(schema, query, execution_context_class=TestExecutionContext) == (
39+
{"foo": "barbar"},
40+
None,
41+
)

tests/execution/test_executor.py

Lines changed: 1 addition & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from pytest import raises, mark
66

77
from graphql.error import GraphQLError
8-
from graphql.execution import execute, ExecutionContext
8+
from graphql.execution import execute
99
from graphql.language import parse, OperationDefinitionNode, FieldNode
1010
from graphql.type import (
1111
GraphQLSchema,
@@ -830,106 +830,3 @@ def executes_ignoring_invalid_non_executable_definitions():
830830
)
831831

832832
assert execute(schema, query) == ({"foo": None}, None)
833-
834-
835-
def describe_customize_execution():
836-
def uses_a_custom_field_resolver():
837-
query = parse("{ foo }")
838-
839-
schema = GraphQLSchema(
840-
GraphQLObjectType("Query", {"foo": GraphQLField(GraphQLString)})
841-
)
842-
843-
# For the purposes of test, just return the name of the field!
844-
def custom_resolver(_source, info, **_args):
845-
return info.field_name
846-
847-
assert execute(schema, query, field_resolver=custom_resolver) == (
848-
{"foo": "foo"},
849-
None,
850-
)
851-
852-
def uses_a_custom_execution_context_class():
853-
query = parse("{ foo }")
854-
855-
schema = GraphQLSchema(
856-
GraphQLObjectType(
857-
"Query",
858-
{"foo": GraphQLField(GraphQLString, resolve=lambda *_args: "bar")},
859-
)
860-
)
861-
862-
class TestExecutionContext(ExecutionContext):
863-
def resolve_field(self, parent_type, source, field_nodes, path):
864-
result = super().resolve_field(parent_type, source, field_nodes, path)
865-
return result * 2
866-
867-
assert execute(schema, query, execution_context_class=TestExecutionContext) == (
868-
{"foo": "barbar"},
869-
None,
870-
)
871-
872-
873-
def describe_parallel_execution():
874-
class Barrier:
875-
"""Barrier that makes progress only after a certain number of waits."""
876-
877-
def __init__(self, number: int) -> None:
878-
self.event = asyncio.Event()
879-
self.number = number
880-
881-
async def wait(self) -> bool:
882-
self.number -= 1
883-
if not self.number:
884-
self.event.set()
885-
return await self.event.wait()
886-
887-
@mark.asyncio
888-
async def resolve_fields_in_parallel():
889-
barrier = Barrier(2)
890-
891-
async def resolve(*_args):
892-
return await barrier.wait()
893-
894-
schema = GraphQLSchema(
895-
GraphQLObjectType(
896-
"Query",
897-
{
898-
"foo": GraphQLField(GraphQLBoolean, resolve=resolve),
899-
"bar": GraphQLField(GraphQLBoolean, resolve=resolve),
900-
},
901-
)
902-
)
903-
904-
ast = parse("{foo, bar}")
905-
# raises TimeoutError if not parallel
906-
result = await asyncio.wait_for(execute(schema, ast), 1.0)
907-
908-
assert result == ({"foo": True, "bar": True}, None)
909-
910-
@mark.asyncio
911-
async def resolve_list_in_parallel():
912-
barrier = Barrier(2)
913-
914-
async def resolve(*_args):
915-
return await barrier.wait()
916-
917-
async def resolve_list(*args):
918-
return [resolve(*args), resolve(*args)]
919-
920-
schema = GraphQLSchema(
921-
GraphQLObjectType(
922-
"Query",
923-
{
924-
"foo": GraphQLField(
925-
GraphQLList(GraphQLBoolean), resolve=resolve_list
926-
)
927-
},
928-
)
929-
)
930-
931-
ast = parse("{foo}")
932-
# raises TimeoutError if not parallel
933-
result = await asyncio.wait_for(execute(schema, ast), 1.0)
934-
935-
assert result == ({"foo": [True, True]}, None)

tests/execution/test_parallel.py

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
import asyncio
2+
3+
from pytest import mark
4+
5+
from graphql.execution import execute
6+
from graphql.language import parse
7+
from graphql.type import (
8+
GraphQLSchema,
9+
GraphQLObjectType,
10+
GraphQLField,
11+
GraphQLList,
12+
GraphQLInterfaceType,
13+
GraphQLBoolean,
14+
GraphQLInt,
15+
GraphQLString,
16+
)
17+
18+
19+
class Barrier:
20+
"""Barrier that makes progress only after a certain number of waits."""
21+
22+
def __init__(self, number: int) -> None:
23+
self.event = asyncio.Event()
24+
self.number = number
25+
26+
async def wait(self) -> bool:
27+
self.number -= 1
28+
if not self.number:
29+
self.event.set()
30+
return await self.event.wait()
31+
32+
33+
def describe_parallel_execution():
34+
@mark.asyncio
35+
async def resolve_fields_in_parallel():
36+
barrier = Barrier(2)
37+
38+
async def resolve(*_args):
39+
return await barrier.wait()
40+
41+
schema = GraphQLSchema(
42+
GraphQLObjectType(
43+
"Query",
44+
{
45+
"foo": GraphQLField(GraphQLBoolean, resolve=resolve),
46+
"bar": GraphQLField(GraphQLBoolean, resolve=resolve),
47+
},
48+
)
49+
)
50+
51+
ast = parse("{foo, bar}")
52+
53+
# raises TimeoutError if not parallel
54+
result = await asyncio.wait_for(execute(schema, ast), 1.0)
55+
56+
assert result == ({"foo": True, "bar": True}, None)
57+
58+
@mark.asyncio
59+
async def resolve_list_in_parallel():
60+
barrier = Barrier(2)
61+
62+
async def resolve(*_args):
63+
return await barrier.wait()
64+
65+
async def resolve_list(*args):
66+
return [resolve(*args), resolve(*args)]
67+
68+
schema = GraphQLSchema(
69+
GraphQLObjectType(
70+
"Query",
71+
{
72+
"foo": GraphQLField(
73+
GraphQLList(GraphQLBoolean), resolve=resolve_list
74+
)
75+
},
76+
)
77+
)
78+
79+
ast = parse("{foo}")
80+
81+
# raises TimeoutError if not parallel
82+
result = await asyncio.wait_for(execute(schema, ast), 1.0)
83+
84+
assert result == ({"foo": [True, True]}, None)
85+
86+
@mark.asyncio
87+
async def resolve_is_type_of_in_parallel():
88+
FooType = GraphQLInterfaceType("Foo", {"foo": GraphQLField(GraphQLString)})
89+
90+
barrier = Barrier(4)
91+
92+
async def is_type_of_bar(obj, *_args):
93+
await barrier.wait()
94+
return obj["foo"] == "bar"
95+
96+
BarType = GraphQLObjectType(
97+
"Bar",
98+
{"foo": GraphQLField(GraphQLString), "foobar": GraphQLField(GraphQLInt)},
99+
interfaces=[FooType],
100+
is_type_of=is_type_of_bar,
101+
)
102+
103+
async def is_type_of_baz(obj, *_args):
104+
await barrier.wait()
105+
return obj["foo"] == "baz"
106+
107+
BazType = GraphQLObjectType(
108+
"Baz",
109+
{"foo": GraphQLField(GraphQLString), "foobaz": GraphQLField(GraphQLInt)},
110+
interfaces=[FooType],
111+
is_type_of=is_type_of_baz,
112+
)
113+
114+
schema = GraphQLSchema(
115+
GraphQLObjectType(
116+
"Query",
117+
{
118+
"foo": GraphQLField(
119+
GraphQLList(FooType),
120+
resolve=lambda *_args: [
121+
{"foo": "bar", "foobar": 1},
122+
{"foo": "baz", "foobaz": 2},
123+
],
124+
)
125+
},
126+
),
127+
types=[BarType, BazType],
128+
)
129+
130+
ast = parse(
131+
"""
132+
{
133+
foo {
134+
foo
135+
... on Bar { foobar }
136+
... on Baz { foobaz }
137+
}
138+
}
139+
"""
140+
)
141+
142+
# raises TimeoutError if not parallel
143+
result = await asyncio.wait_for(execute(schema, ast), 1.0)
144+
145+
assert result == (
146+
{"foo": [{"foo": "bar", "foobar": 1}, {"foo": "baz", "foobaz": 2}]},
147+
None,
148+
)

0 commit comments

Comments
 (0)