Skip to content

Commit 2792f59

Browse files
committed
Add new GraphQLSchema.get_field method
Replicates graphql/graphql-js@69e1554
1 parent aa77082 commit 2792f59

File tree

6 files changed

+187
-57
lines changed

6 files changed

+187
-57
lines changed

src/graphql/execution/execute.py

Lines changed: 2 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,6 @@
5050
GraphQLResolveInfo,
5151
GraphQLSchema,
5252
GraphQLTypeResolver,
53-
SchemaMetaFieldDef,
54-
TypeMetaFieldDef,
55-
TypeNameMetaFieldDef,
5653
assert_valid_schema,
5754
is_abstract_type,
5855
is_leaf_type,
@@ -71,7 +68,6 @@
7168
"default_type_resolver",
7269
"execute",
7370
"execute_sync",
74-
"get_field_def",
7571
"ExecutionResult",
7672
"ExecutionContext",
7773
"FormattedExecutionResult",
@@ -501,7 +497,8 @@ def execute_field(
501497
calling its resolve function, then calls complete_value to await coroutine
502498
objects, serialize scalars, or execute the sub-selection-set for objects.
503499
"""
504-
field_def = get_field_def(self.schema, parent_type, field_nodes[0])
500+
field_name = field_nodes[0].name.value
501+
field_def = self.schema.get_field(parent_type, field_name)
505502
if not field_def:
506503
return Undefined
507504

@@ -1130,31 +1127,6 @@ def assert_valid_execution_arguments(
11301127
)
11311128

11321129

1133-
def get_field_def(
1134-
schema: GraphQLSchema, parent_type: GraphQLObjectType, field_node: FieldNode
1135-
) -> GraphQLField:
1136-
"""Get field definition.
1137-
1138-
This method looks up the field on the given type definition. It has special casing
1139-
for the three introspection fields, ``__schema``, ``__type`, and ``__typename``.
1140-
``__typename`` is special because it can always be queried as a field, even in
1141-
situations where no other fields are allowed, like on a Union. ``__schema`` and
1142-
``__type`` could get automatically added to the query type, but that would require
1143-
mutating type definitions, which would cause issues.
1144-
1145-
For internal use only.
1146-
"""
1147-
field_name = field_node.name.value
1148-
1149-
if field_name == "__schema" and schema.query_type == parent_type:
1150-
return SchemaMetaFieldDef
1151-
elif field_name == "__type" and schema.query_type == parent_type:
1152-
return TypeMetaFieldDef
1153-
elif field_name == "__typename":
1154-
return TypeNameMetaFieldDef
1155-
return parent_type.fields.get(field_name)
1156-
1157-
11581130
def invalid_return_type_error(
11591131
return_type: GraphQLObjectType, result: Any, field_nodes: List[FieldNode]
11601132
) -> GraphQLError:

src/graphql/execution/subscribe.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
ExecutionResult,
99
assert_valid_execution_arguments,
1010
execute,
11-
get_field_def,
1211
)
1312
from ..execution.values import get_argument_values
1413
from ..language import DocumentNode
@@ -179,10 +178,10 @@ async def execute_subscription(context: ExecutionContext) -> AsyncIterable[Any]:
179178
context.operation.selection_set,
180179
)
181180
response_name, field_nodes = next(iter(root_fields.items()))
182-
field_def = get_field_def(schema, root_type, field_nodes[0])
181+
field_name = field_nodes[0].name.value
182+
field_def = schema.get_field(root_type, field_name)
183183

184184
if not field_def:
185-
field_name = field_nodes[0].name.value
186185
raise GraphQLError(
187186
f"The subscription field '{field_name}' is not defined.", field_nodes
188187
)

src/graphql/type/schema.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from ..pyutils import inspect, is_collection, is_description
99
from .definition import (
1010
GraphQLAbstractType,
11+
GraphQLCompositeType,
12+
GraphQLField,
1113
GraphQLInterfaceType,
1214
GraphQLNamedType,
1315
GraphQLObjectType,
@@ -20,7 +22,12 @@
2022
is_wrapping_type,
2123
)
2224
from .directives import GraphQLDirective, is_directive, specified_directives
23-
from .introspection import introspection_types
25+
from .introspection import (
26+
SchemaMetaFieldDef,
27+
TypeMetaFieldDef,
28+
TypeNameMetaFieldDef,
29+
introspection_types,
30+
)
2431

2532

2633
try:
@@ -387,6 +394,35 @@ def get_directive(self, name: str) -> Optional[GraphQLDirective]:
387394
return directive
388395
return None
389396

397+
def get_field(
398+
self, parent_type: GraphQLCompositeType, field_name: str
399+
) -> Optional[GraphQLField]:
400+
"""Get field of a given type with the given name.
401+
402+
This method looks up the field on the given type definition.
403+
It has special casing for the three introspection fields, `__schema`,
404+
`__type` and `__typename`.
405+
406+
`__typename` is special because it can always be queried as a field, even
407+
in situations where no other fields are allowed, like on a Union.
408+
409+
`__schema` and `__type` could get automatically added to the query type,
410+
but that would require mutating type definitions, which would cause issues.
411+
"""
412+
if field_name == "__schema":
413+
return SchemaMetaFieldDef if self.query_type is parent_type else None
414+
if field_name == "__type":
415+
return TypeMetaFieldDef if self.query_type is parent_type else None
416+
if field_name == "__typename":
417+
return TypeNameMetaFieldDef
418+
419+
# this function is part of a "hot" path inside executor and to assume presence
420+
# of 'fields' is faster than to use `not is_union_type`
421+
try:
422+
return parent_type.fields[field_name] # type: ignore
423+
except (AttributeError, KeyError):
424+
return None
425+
390426
@property
391427
def validation_errors(self) -> Optional[List[GraphQLError]]:
392428
return self._validation_errors

src/graphql/utilities/type_info.py

Lines changed: 3 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -27,16 +27,12 @@
2727
GraphQLOutputType,
2828
GraphQLSchema,
2929
GraphQLType,
30-
SchemaMetaFieldDef,
31-
TypeMetaFieldDef,
32-
TypeNameMetaFieldDef,
3330
get_named_type,
3431
get_nullable_type,
3532
is_composite_type,
3633
is_enum_type,
3734
is_input_object_type,
3835
is_input_type,
39-
is_interface_type,
4036
is_list_type,
4137
is_object_type,
4238
is_output_type,
@@ -54,7 +50,7 @@
5450

5551

5652
GetFieldDefFn: TypeAlias = Callable[
57-
[GraphQLSchema, GraphQLType, FieldNode], Optional[GraphQLField]
53+
[GraphQLSchema, GraphQLCompositeType, FieldNode], Optional[GraphQLField]
5854
]
5955

6056

@@ -264,24 +260,9 @@ def leave_enum_value(self) -> None:
264260

265261

266262
def get_field_def(
267-
schema: GraphQLSchema, parent_type: GraphQLType, field_node: FieldNode
263+
schema: GraphQLSchema, parent_type: GraphQLCompositeType, field_node: FieldNode
268264
) -> Optional[GraphQLField]:
269-
"""Get field definition.
270-
271-
Not exactly the same as the executor's definition of
272-
:func:`graphql.execution.get_field_def`, in this statically evaluated environment
273-
we do not always have an Object type, and need to handle Interface and Union types.
274-
"""
275-
name = field_node.name.value
276-
if name == "__schema" and schema.query_type is parent_type:
277-
return SchemaMetaFieldDef
278-
if name == "__type" and schema.query_type is parent_type:
279-
return TypeMetaFieldDef
280-
if name == "__typename" and is_composite_type(parent_type):
281-
return TypeNameMetaFieldDef
282-
if is_object_type(parent_type) or is_interface_type(parent_type):
283-
return parent_type.fields.get(name)
284-
return None
265+
return schema.get_field(parent_type, field_node.name.value)
285266

286267

287268
class TypeInfoVisitor(Visitor):

tests/type/test_schema.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@
2626
GraphQLSchema,
2727
GraphQLString,
2828
GraphQLType,
29+
GraphQLUnionType,
30+
SchemaMetaFieldDef,
31+
TypeMetaFieldDef,
32+
TypeNameMetaFieldDef,
2933
specified_directives,
3034
)
3135
from graphql.utilities import build_schema, lexicographic_sort_schema, print_schema
@@ -293,6 +297,60 @@ def preserves_the_order_of_user_provided_types():
293297
copy_schema = GraphQLSchema(**schema.to_kwargs())
294298
assert list(copy_schema.type_map) == type_names
295299

300+
def describe_get_field():
301+
pet_type = GraphQLInterfaceType("Pet", {"name": GraphQLField(GraphQLString)})
302+
cat_type = GraphQLObjectType(
303+
"Cat", {"name": GraphQLField(GraphQLString)}, [pet_type]
304+
)
305+
dog_type = GraphQLObjectType(
306+
"Dog", {"name": GraphQLField(GraphQLString)}, [pet_type]
307+
)
308+
cat_or_dog = GraphQLUnionType("CatOrDog", [cat_type, dog_type])
309+
query_type = GraphQLObjectType("Query", {"catOrDog": GraphQLField(cat_or_dog)})
310+
mutation_type = GraphQLObjectType("Mutation", {})
311+
subscription_type = GraphQLObjectType("Subscription", {})
312+
schema = GraphQLSchema(query_type, mutation_type, subscription_type)
313+
314+
_get_field = schema.get_field
315+
316+
def returns_known_field():
317+
assert _get_field(pet_type, "name") == pet_type.fields["name"]
318+
assert _get_field(cat_type, "name") == cat_type.fields["name"]
319+
320+
assert _get_field(query_type, "catOrDog") == query_type.fields["catOrDog"]
321+
322+
def returns_none_for_unknown_fields():
323+
assert _get_field(cat_or_dog, "name") is None
324+
325+
assert _get_field(query_type, "unknown") is None
326+
assert _get_field(pet_type, "unknown") is None
327+
assert _get_field(cat_type, "unknown") is None
328+
assert _get_field(cat_or_dog, "unknown") is None
329+
330+
def handles_introspection_fields():
331+
assert _get_field(query_type, "__typename") == TypeNameMetaFieldDef
332+
assert _get_field(mutation_type, "__typename") == TypeNameMetaFieldDef
333+
assert _get_field(subscription_type, "__typename") == TypeNameMetaFieldDef
334+
335+
assert _get_field(pet_type, "__typename") is TypeNameMetaFieldDef
336+
assert _get_field(cat_type, "__typename") is TypeNameMetaFieldDef
337+
assert _get_field(dog_type, "__typename") is TypeNameMetaFieldDef
338+
assert _get_field(cat_or_dog, "__typename") is TypeNameMetaFieldDef
339+
340+
assert _get_field(query_type, "__type") is TypeMetaFieldDef
341+
assert _get_field(query_type, "__schema") is SchemaMetaFieldDef
342+
343+
def returns_non_for_introspection_fields_in_wrong_location():
344+
assert _get_field(pet_type, "__type") is None
345+
assert _get_field(dog_type, "__type") is None
346+
assert _get_field(mutation_type, "__type") is None
347+
assert _get_field(subscription_type, "__type") is None
348+
349+
assert _get_field(pet_type, "__schema") is None
350+
assert _get_field(dog_type, "__schema") is None
351+
assert _get_field(mutation_type, "__schema") is None
352+
assert _get_field(subscription_type, "__schema") is None
353+
296354
def describe_validity():
297355
def describe_when_not_assumed_valid():
298356
def configures_the_schema_to_still_needing_validation():

tests/utilities/test_type_info.py

Lines changed: 85 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import List, Optional, Tuple
2+
13
from graphql.language import (
24
FieldNode,
35
NameNode,
@@ -10,7 +12,14 @@
1012
print_ast,
1113
visit,
1214
)
13-
from graphql.type import GraphQLSchema, get_named_type, is_composite_type
15+
from graphql.type import (
16+
GraphQLSchema,
17+
SchemaMetaFieldDef,
18+
TypeMetaFieldDef,
19+
TypeNameMetaFieldDef,
20+
get_named_type,
21+
is_composite_type,
22+
)
1423
from graphql.utilities import TypeInfo, TypeInfoVisitor, build_schema
1524

1625
from ..fixtures import kitchen_sink_query # noqa: F401
@@ -39,9 +48,13 @@
3948
name(surname: Boolean): String
4049
}
4150
51+
union HumanOrAlien = Human | Alien
52+
4253
type QueryRoot {
4354
human(id: ID): Human
4455
alien: Alien
56+
humanOrAlien: HumanOrAlien
57+
pet: Pet
4558
}
4659
4760
schema {
@@ -140,6 +153,77 @@ def leave(self, *args):
140153

141154
assert test_visitor.args == wrapped_visitor.args
142155

156+
def supports_introspection_fields():
157+
type_info = TypeInfo(test_schema)
158+
159+
ast = parse(
160+
"""
161+
{
162+
__typename
163+
__type(name: "Cat") { __typename }
164+
__schema {
165+
__typename # in object type
166+
}
167+
humanOrAlien {
168+
__typename # in union type
169+
}
170+
pet {
171+
__typename # in interface type
172+
}
173+
someUnknownType {
174+
__typename # unknown
175+
}
176+
pet {
177+
__type # unknown
178+
__schema # unknown
179+
}
180+
}
181+
"""
182+
)
183+
184+
visited_fields: List[Tuple[Optional[str], Optional[str]]] = []
185+
186+
class TestVisitor(Visitor):
187+
@staticmethod
188+
def enter_field(self, node: OperationDefinitionNode, *_args):
189+
parent_type = type_info.get_parent_type()
190+
type_name = getattr(type_info.get_parent_type(), "name", None)
191+
field_def = type_info.get_field_def()
192+
fields = getattr(parent_type, "fields", {})
193+
fields = dict(
194+
**fields,
195+
__type=TypeMetaFieldDef,
196+
__typename=TypeNameMetaFieldDef,
197+
__schema=SchemaMetaFieldDef,
198+
)
199+
for name, field in fields.items():
200+
if field is field_def:
201+
field_name = name
202+
break
203+
else:
204+
field_name = None
205+
visited_fields.append((type_name, field_name))
206+
207+
test_visitor = TestVisitor()
208+
assert visit(ast, TypeInfoVisitor(type_info, test_visitor))
209+
210+
assert visited_fields == [
211+
("QueryRoot", "__typename"),
212+
("QueryRoot", "__type"),
213+
("__Type", "__typename"),
214+
("QueryRoot", "__schema"),
215+
("__Schema", "__typename"),
216+
("QueryRoot", "humanOrAlien"),
217+
("HumanOrAlien", "__typename"),
218+
("QueryRoot", "pet"),
219+
("Pet", "__typename"),
220+
("QueryRoot", None),
221+
(None, None),
222+
("QueryRoot", "pet"),
223+
("Pet", None),
224+
("Pet", None),
225+
]
226+
143227
def maintains_type_info_during_visit():
144228
visited = []
145229

0 commit comments

Comments
 (0)