Skip to content

Commit d9af9af

Browse files
committed
validate_schema: inline get_all_nodes function
Replicates graphql/graphql-js@302f4b9
1 parent e820226 commit d9af9af

File tree

1 file changed

+32
-34
lines changed

1 file changed

+32
-34
lines changed

src/graphql/type/validate.py

Lines changed: 32 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from operator import attrgetter, itemgetter
22
from typing import (
33
Any,
4-
Callable,
54
Collection,
65
Dict,
76
List,
@@ -17,10 +16,13 @@
1716
from ..language import (
1817
DirectiveNode,
1918
InputValueDefinitionNode,
19+
InterfaceTypeExtensionNode,
2020
NamedTypeNode,
2121
Node,
22+
ObjectTypeExtensionNode,
2223
OperationType,
23-
OperationTypeDefinitionNode,
24+
SchemaExtensionNode,
25+
UnionTypeExtensionNode,
2426
)
2527
from .definition import (
2628
GraphQLEnumType,
@@ -498,13 +500,12 @@ def validate_input_fields(self, input_obj: GraphQLInputObjectType) -> None:
498500
def get_operation_type_node(
499501
schema: GraphQLSchema, operation: OperationType
500502
) -> Optional[Node]:
501-
operation_nodes = cast(
502-
List[OperationTypeDefinitionNode],
503-
get_all_sub_nodes(schema, attrgetter("operation_types")),
504-
)
505-
for node in operation_nodes:
506-
if node.operation == operation:
507-
return node.type
503+
for extension_node in get_all_nodes(schema):
504+
operation_types = cast(SchemaExtensionNode, extension_node).operation_types
505+
if operation_types: # pragma: no cover else
506+
for operation_type in operation_types:
507+
if operation_type.operation == operation:
508+
return operation_type.type
508509
return None
509510

510511

@@ -580,39 +581,36 @@ def get_all_nodes(obj: SDLDefinedObject) -> List[Node]:
580581
return nodes
581582

582583

583-
def get_all_sub_nodes(
584-
obj: SDLDefinedObject, getter: Callable[[Node], List[Node]]
585-
) -> List[Node]:
586-
result: List[Node] = []
587-
for ast_node in get_all_nodes(obj):
588-
sub_nodes = getter(ast_node)
589-
if sub_nodes: # pragma: no cover
590-
result.extend(sub_nodes)
591-
return result
592-
593-
594584
def get_all_implements_interface_nodes(
595585
type_: Union[GraphQLObjectType, GraphQLInterfaceType], iface: GraphQLInterfaceType
596586
) -> List[NamedTypeNode]:
597-
implements_nodes = cast(
598-
List[NamedTypeNode], get_all_sub_nodes(type_, attrgetter("interfaces"))
599-
)
600-
return [
601-
iface_node
602-
for iface_node in implements_nodes
603-
if iface_node.name.value == iface.name
604-
]
587+
implements_nodes: List[NamedTypeNode] = []
588+
for extension_node in get_all_nodes(type_):
589+
iface_nodes = cast(
590+
Union[ObjectTypeExtensionNode, InterfaceTypeExtensionNode], extension_node
591+
).interfaces
592+
if iface_nodes: # pragma: no cover else
593+
implements_nodes.extend(
594+
iface_node
595+
for iface_node in iface_nodes
596+
if iface_node.name.value == iface.name
597+
)
598+
return implements_nodes
605599

606600

607601
def get_union_member_type_nodes(
608602
union: GraphQLUnionType, type_name: str
609603
) -> Optional[List[NamedTypeNode]]:
610-
union_nodes = cast(
611-
List[NamedTypeNode], get_all_sub_nodes(union, attrgetter("types"))
612-
)
613-
return [
614-
union_node for union_node in union_nodes if union_node.name.value == type_name
615-
]
604+
member_type_nodes: List[NamedTypeNode] = []
605+
for extension_node in get_all_nodes(union):
606+
type_nodes = cast(UnionTypeExtensionNode, extension_node).types
607+
if type_nodes: # pragma: no cover else
608+
member_type_nodes.extend(
609+
type_node
610+
for type_node in type_nodes
611+
if type_node.name.value == type_name
612+
)
613+
return member_type_nodes
616614

617615

618616
def get_deprecated_directive_node(

0 commit comments

Comments
 (0)