diff --git a/src/graphql/language/visitor.py b/src/graphql/language/visitor.py index ce845949..3cb0f009 100644 --- a/src/graphql/language/visitor.py +++ b/src/graphql/language/visitor.py @@ -183,16 +183,15 @@ def __init_subclass__(cls) -> None: kind: Optional[str] = None else: attr, kind = attr_kind - if attr in ("enter", "leave"): - if kind: - name = snake_to_camel(kind) + "Node" - node_cls = getattr(ast, name, None) - if ( - not node_cls - or not isinstance(node_cls, type) - or not issubclass(node_cls, Node) - ): - raise TypeError(f"Invalid AST node kind: {kind}.") + if attr in ("enter", "leave") and kind: + name = snake_to_camel(kind) + "Node" + node_cls = getattr(ast, name, None) + if ( + not node_cls + or not isinstance(node_cls, type) + or not issubclass(node_cls, Node) + ): + raise TypeError(f"Invalid AST node kind: {kind}.") def get_visit_fn(self, kind: str, is_leaving: bool = False) -> Callable: """Get the visit function for the given node kind and direction.""" @@ -256,10 +255,7 @@ def visit(root: Node, visitor: Visitor) -> Any: node: Any = parent parent = ancestors_pop() if ancestors else None if is_edited: - if in_array: - node = node[:] - else: - node = copy(node) + node = node[:] if in_array else copy(node) edit_offset = 0 for edit_key, edit_value in edits: if in_array: @@ -267,11 +263,10 @@ def visit(root: Node, visitor: Visitor) -> Any: if in_array and (edit_value is REMOVE or edit_value is Ellipsis): node.pop(edit_key) edit_offset += 1 + elif isinstance(node, list): + node[edit_key] = edit_value else: - if isinstance(node, list): - node[edit_key] = edit_value - else: - setattr(node, edit_key, edit_value) + setattr(node, edit_key, edit_value) idx = stack.idx keys = stack.keys diff --git a/src/graphql/pyutils/suggestion_list.py b/src/graphql/pyutils/suggestion_list.py index de78ec10..0020b3ce 100644 --- a/src/graphql/pyutils/suggestion_list.py +++ b/src/graphql/pyutils/suggestion_list.py @@ -73,7 +73,7 @@ def measure(self, option: str, threshold: int) -> Optional[int]: return None rows = self._rows - for j in range(0, b_len + 1): + for j in range(b_len + 1): rows[0][j] = j for i in range(1, a_len + 1): diff --git a/src/graphql/subscription/map_async_iterator.py b/src/graphql/subscription/map_async_iterator.py index e24490d9..43400fd3 100644 --- a/src/graphql/subscription/map_async_iterator.py +++ b/src/graphql/subscription/map_async_iterator.py @@ -33,8 +33,6 @@ async def __anext__(self) -> Any: if not isasyncgen(self.iterator): raise StopAsyncIteration value = await self.iterator.__anext__() - result = self.callback(value) - else: aclose = ensure_future(self._close_event.wait()) anext = ensure_future(self.iterator.__anext__()) @@ -61,7 +59,8 @@ async def __anext__(self) -> Any: raise error value = anext.result() - result = self.callback(value) + + result = self.callback(value) return await result if isawaitable(result) else result @@ -72,23 +71,24 @@ async def athrow( traceback: Optional[TracebackType] = None, ) -> None: """Throw an exception into the asynchronous iterator.""" - if not self.is_closed: - athrow = getattr(self.iterator, "athrow", None) - if athrow: - await athrow(type_, value, traceback) - else: - await self.aclose() - if value is None: - if traceback is None: - raise type_ - value = ( - type_ - if isinstance(value, BaseException) - else cast(Type[BaseException], type_)() - ) - if traceback is not None: - value = value.with_traceback(traceback) - raise value + if self.is_closed: + return + athrow = getattr(self.iterator, "athrow", None) + if athrow: + await athrow(type_, value, traceback) + else: + await self.aclose() + if value is None: + if traceback is None: + raise type_ + value = ( + type_ + if isinstance(value, BaseException) + else cast(Type[BaseException], type_)() + ) + if traceback is not None: + value = value.with_traceback(traceback) + raise value async def aclose(self) -> None: """Close the iterator.""" diff --git a/src/graphql/utilities/print_schema.py b/src/graphql/utilities/print_schema.py index d9c54919..685ab65a 100644 --- a/src/graphql/utilities/print_schema.py +++ b/src/graphql/utilities/print_schema.py @@ -108,10 +108,7 @@ def is_schema_of_common_names(schema: GraphQLSchema) -> bool: return False subscription_type = schema.subscription_type - if subscription_type and subscription_type.name != "Subscription": - return False - - return True + return not subscription_type or subscription_type.name == "Subscription" def print_type(type_: GraphQLNamedType) -> str: diff --git a/src/graphql/utilities/strip_ignored_characters.py b/src/graphql/utilities/strip_ignored_characters.py index af69dd3a..860893b2 100644 --- a/src/graphql/utilities/strip_ignored_characters.py +++ b/src/graphql/utilities/strip_ignored_characters.py @@ -79,9 +79,10 @@ def strip_ignored_characters(source: Union[str, Source]) -> str: # Also prevent case of non-punctuator token following by spread resulting # in invalid token (e.g.`1...` is invalid Float token). is_non_punctuator = not is_punctuator_token_kind(current_token.kind) - if was_last_added_token_non_punctuator: - if is_non_punctuator or current_token.kind == TokenKind.SPREAD: - stripped_body += " " + if was_last_added_token_non_punctuator and ( + is_non_punctuator or current_token.kind == TokenKind.SPREAD + ): + stripped_body += " " token_body = body[current_token.start : current_token.end] if token_kind == TokenKind.BLOCK_STRING: diff --git a/src/graphql/validation/rules/lone_anonymous_operation.py b/src/graphql/validation/rules/lone_anonymous_operation.py index 3e945f8b..f14d5f5a 100644 --- a/src/graphql/validation/rules/lone_anonymous_operation.py +++ b/src/graphql/validation/rules/lone_anonymous_operation.py @@ -20,9 +20,8 @@ def __init__(self, context: ASTValidationContext): def enter_document(self, node: DocumentNode, *_args: Any) -> None: self.operation_count = sum( - 1 + isinstance(definition, OperationDefinitionNode) for definition in node.definitions - if isinstance(definition, OperationDefinitionNode) ) def enter_operation_definition( diff --git a/src/graphql/validation/rules/single_field_subscriptions.py b/src/graphql/validation/rules/single_field_subscriptions.py index 86b27122..d5e61178 100644 --- a/src/graphql/validation/rules/single_field_subscriptions.py +++ b/src/graphql/validation/rules/single_field_subscriptions.py @@ -23,39 +23,55 @@ class SingleFieldSubscriptionsRule(ValidationRule): def enter_operation_definition( self, node: OperationDefinitionNode, *_args: Any ) -> None: - if node.operation == OperationType.SUBSCRIPTION: - schema = self.context.schema - subscription_type = schema.subscription_type - if subscription_type: - operation_name = node.name.value if node.name else None - variable_values: Dict[str, Any] = {} - document = self.context.document - fragments: Dict[str, FragmentDefinitionNode] = { - definition.name.value: definition - for definition in document.definitions - if isinstance(definition, FragmentDefinitionNode) - } - fields = collect_fields( - schema, - fragments, - variable_values, - subscription_type, - node.selection_set, - {}, - set(), - ) - if len(fields) > 1: - field_selection_lists = list(fields.values()) - extra_field_selection_lists = field_selection_lists[1:] - extra_field_selection = [ - field - for fields in extra_field_selection_lists - for field in ( - fields - if isinstance(fields, list) - else [cast(FieldNode, fields)] + if node.operation != OperationType.SUBSCRIPTION: + return + schema = self.context.schema + subscription_type = schema.subscription_type + if subscription_type: + operation_name = node.name.value if node.name else None + variable_values: Dict[str, Any] = {} + document = self.context.document + fragments: Dict[str, FragmentDefinitionNode] = { + definition.name.value: definition + for definition in document.definitions + if isinstance(definition, FragmentDefinitionNode) + } + fields = collect_fields( + schema, + fragments, + variable_values, + subscription_type, + node.selection_set, + {}, + set(), + ) + if len(fields) > 1: + field_selection_lists = list(fields.values()) + extra_field_selection_lists = field_selection_lists[1:] + extra_field_selection = [ + field + for fields in extra_field_selection_lists + for field in ( + fields + if isinstance(fields, list) + else [cast(FieldNode, fields)] + ) + ] + self.report_error( + GraphQLError( + ( + "Anonymous Subscription" + if operation_name is None + else f"Subscription '{operation_name}'" ) - ] + + " must select only one top level field.", + extra_field_selection, + ) + ) + for field_nodes in fields.values(): + field = field_nodes[0] + field_name = field.name.value + if field_name.startswith("__"): self.report_error( GraphQLError( ( @@ -63,22 +79,7 @@ def enter_operation_definition( if operation_name is None else f"Subscription '{operation_name}'" ) - + " must select only one top level field.", - extra_field_selection, + + " must not select an introspection top level field.", + field_nodes, ) ) - for field_nodes in fields.values(): - field = field_nodes[0] - field_name = field.name.value - if field_name.startswith("__"): - self.report_error( - GraphQLError( - ( - "Anonymous Subscription" - if operation_name is None - else f"Subscription '{operation_name}'" - ) - + " must not select an introspection top level field.", - field_nodes, - ) - ) diff --git a/src/graphql/validation/validation_context.py b/src/graphql/validation/validation_context.py index 8ed24ef6..646ab5f3 100644 --- a/src/graphql/validation/validation_context.py +++ b/src/graphql/validation/validation_context.py @@ -97,10 +97,12 @@ def report_error(self, error: GraphQLError) -> None: def get_fragment(self, name: str) -> Optional[FragmentDefinitionNode]: fragments = self._fragments if fragments is None: - fragments = {} - for statement in self.document.definitions: - if isinstance(statement, FragmentDefinitionNode): - fragments[statement.name.value] = statement + fragments = { + statement.name.value: statement + for statement in self.document.definitions + if isinstance(statement, FragmentDefinitionNode) + } + self._fragments = fragments return fragments.get(name)