|
1 | 1 | from operator import attrgetter, itemgetter
|
2 | 2 | from typing import (
|
3 | 3 | Any,
|
4 |
| - Callable, |
5 | 4 | Collection,
|
6 | 5 | Dict,
|
7 | 6 | List,
|
|
17 | 16 | from ..language import (
|
18 | 17 | DirectiveNode,
|
19 | 18 | InputValueDefinitionNode,
|
| 19 | + InterfaceTypeExtensionNode, |
20 | 20 | NamedTypeNode,
|
21 | 21 | Node,
|
| 22 | + ObjectTypeExtensionNode, |
22 | 23 | OperationType,
|
23 |
| - OperationTypeDefinitionNode, |
| 24 | + SchemaExtensionNode, |
| 25 | + UnionTypeExtensionNode, |
24 | 26 | )
|
25 | 27 | from .definition import (
|
26 | 28 | GraphQLEnumType,
|
@@ -498,13 +500,12 @@ def validate_input_fields(self, input_obj: GraphQLInputObjectType) -> None:
|
498 | 500 | def get_operation_type_node(
|
499 | 501 | schema: GraphQLSchema, operation: OperationType
|
500 | 502 | ) -> Optional[Node]:
|
501 |
| - operation_nodes = cast( |
502 |
| - List[OperationTypeDefinitionNode], |
503 |
| - get_all_sub_nodes(schema, attrgetter("operation_types")), |
504 |
| - ) |
505 |
| - for node in operation_nodes: |
506 |
| - if node.operation == operation: |
507 |
| - return node.type |
| 503 | + for extension_node in get_all_nodes(schema): |
| 504 | + operation_types = cast(SchemaExtensionNode, extension_node).operation_types |
| 505 | + if operation_types: # pragma: no cover else |
| 506 | + for operation_type in operation_types: |
| 507 | + if operation_type.operation == operation: |
| 508 | + return operation_type.type |
508 | 509 | return None
|
509 | 510 |
|
510 | 511 |
|
@@ -580,39 +581,36 @@ def get_all_nodes(obj: SDLDefinedObject) -> List[Node]:
|
580 | 581 | return nodes
|
581 | 582 |
|
582 | 583 |
|
583 |
| -def get_all_sub_nodes( |
584 |
| - obj: SDLDefinedObject, getter: Callable[[Node], List[Node]] |
585 |
| -) -> List[Node]: |
586 |
| - result: List[Node] = [] |
587 |
| - for ast_node in get_all_nodes(obj): |
588 |
| - sub_nodes = getter(ast_node) |
589 |
| - if sub_nodes: # pragma: no cover |
590 |
| - result.extend(sub_nodes) |
591 |
| - return result |
592 |
| - |
593 |
| - |
594 | 584 | def get_all_implements_interface_nodes(
|
595 | 585 | type_: Union[GraphQLObjectType, GraphQLInterfaceType], iface: GraphQLInterfaceType
|
596 | 586 | ) -> List[NamedTypeNode]:
|
597 |
| - implements_nodes = cast( |
598 |
| - List[NamedTypeNode], get_all_sub_nodes(type_, attrgetter("interfaces")) |
599 |
| - ) |
600 |
| - return [ |
601 |
| - iface_node |
602 |
| - for iface_node in implements_nodes |
603 |
| - if iface_node.name.value == iface.name |
604 |
| - ] |
| 587 | + implements_nodes: List[NamedTypeNode] = [] |
| 588 | + for extension_node in get_all_nodes(type_): |
| 589 | + iface_nodes = cast( |
| 590 | + Union[ObjectTypeExtensionNode, InterfaceTypeExtensionNode], extension_node |
| 591 | + ).interfaces |
| 592 | + if iface_nodes: # pragma: no cover else |
| 593 | + implements_nodes.extend( |
| 594 | + iface_node |
| 595 | + for iface_node in iface_nodes |
| 596 | + if iface_node.name.value == iface.name |
| 597 | + ) |
| 598 | + return implements_nodes |
605 | 599 |
|
606 | 600 |
|
607 | 601 | def get_union_member_type_nodes(
|
608 | 602 | union: GraphQLUnionType, type_name: str
|
609 | 603 | ) -> Optional[List[NamedTypeNode]]:
|
610 |
| - union_nodes = cast( |
611 |
| - List[NamedTypeNode], get_all_sub_nodes(union, attrgetter("types")) |
612 |
| - ) |
613 |
| - return [ |
614 |
| - union_node for union_node in union_nodes if union_node.name.value == type_name |
615 |
| - ] |
| 604 | + member_type_nodes: List[NamedTypeNode] = [] |
| 605 | + for extension_node in get_all_nodes(union): |
| 606 | + type_nodes = cast(UnionTypeExtensionNode, extension_node).types |
| 607 | + if type_nodes: # pragma: no cover else |
| 608 | + member_type_nodes.extend( |
| 609 | + type_node |
| 610 | + for type_node in type_nodes |
| 611 | + if type_node.name.value == type_name |
| 612 | + ) |
| 613 | + return member_type_nodes |
616 | 614 |
|
617 | 615 |
|
618 | 616 | def get_deprecated_directive_node(
|
|
0 commit comments