From 249c2d9eb4eb2b48b410b73ee7afa0dd67788518 Mon Sep 17 00:00:00 2001 From: Yasser Tahiri Date: Fri, 5 Nov 2021 12:30:19 +0100 Subject: [PATCH 1/3] chore: Refactor MultiCode Expression --- src/graphql/language/block_string.py | 2 +- src/graphql/language/visitor.py | 31 +++--- src/graphql/pyutils/suggestion_list.py | 2 +- .../subscription/map_async_iterator.py | 39 ++++---- src/graphql/utilities/print_schema.py | 5 +- .../utilities/strip_ignored_characters.py | 42 +------- .../rules/lone_anonymous_operation.py | 3 +- .../rules/single_field_subscriptions.py | 99 ++++++++++--------- src/graphql/validation/validation_context.py | 10 +- 9 files changed, 97 insertions(+), 136 deletions(-) diff --git a/src/graphql/language/block_string.py b/src/graphql/language/block_string.py index 3f07479e..a8da18d9 100644 --- a/src/graphql/language/block_string.py +++ b/src/graphql/language/block_string.py @@ -38,7 +38,7 @@ def dedent_block_string_value(raw_string: str) -> str: def is_blank(s: str) -> bool: """Check whether string contains only space or tab characters.""" - return all(c == " " or c == "\t" for c in s) + return all(c in [" ", "\t"] for c in s) def get_block_string_indentation(value: str) -> int: diff --git a/src/graphql/language/visitor.py b/src/graphql/language/visitor.py index ce845949..3cb0f009 100644 --- a/src/graphql/language/visitor.py +++ b/src/graphql/language/visitor.py @@ -183,16 +183,15 @@ def __init_subclass__(cls) -> None: kind: Optional[str] = None else: attr, kind = attr_kind - if attr in ("enter", "leave"): - if kind: - name = snake_to_camel(kind) + "Node" - node_cls = getattr(ast, name, None) - if ( - not node_cls - or not isinstance(node_cls, type) - or not issubclass(node_cls, Node) - ): - raise TypeError(f"Invalid AST node kind: {kind}.") + if attr in ("enter", "leave") and kind: + name = snake_to_camel(kind) + "Node" + node_cls = getattr(ast, name, None) + if ( + not node_cls + or not isinstance(node_cls, type) + or not issubclass(node_cls, Node) + ): + raise TypeError(f"Invalid AST node kind: {kind}.") def get_visit_fn(self, kind: str, is_leaving: bool = False) -> Callable: """Get the visit function for the given node kind and direction.""" @@ -256,10 +255,7 @@ def visit(root: Node, visitor: Visitor) -> Any: node: Any = parent parent = ancestors_pop() if ancestors else None if is_edited: - if in_array: - node = node[:] - else: - node = copy(node) + node = node[:] if in_array else copy(node) edit_offset = 0 for edit_key, edit_value in edits: if in_array: @@ -267,11 +263,10 @@ def visit(root: Node, visitor: Visitor) -> Any: if in_array and (edit_value is REMOVE or edit_value is Ellipsis): node.pop(edit_key) edit_offset += 1 + elif isinstance(node, list): + node[edit_key] = edit_value else: - if isinstance(node, list): - node[edit_key] = edit_value - else: - setattr(node, edit_key, edit_value) + setattr(node, edit_key, edit_value) idx = stack.idx keys = stack.keys diff --git a/src/graphql/pyutils/suggestion_list.py b/src/graphql/pyutils/suggestion_list.py index de78ec10..0020b3ce 100644 --- a/src/graphql/pyutils/suggestion_list.py +++ b/src/graphql/pyutils/suggestion_list.py @@ -73,7 +73,7 @@ def measure(self, option: str, threshold: int) -> Optional[int]: return None rows = self._rows - for j in range(0, b_len + 1): + for j in range(b_len + 1): rows[0][j] = j for i in range(1, a_len + 1): diff --git a/src/graphql/subscription/map_async_iterator.py b/src/graphql/subscription/map_async_iterator.py index e24490d9..a86aea11 100644 --- a/src/graphql/subscription/map_async_iterator.py +++ b/src/graphql/subscription/map_async_iterator.py @@ -33,8 +33,6 @@ async def __anext__(self) -> Any: if not isasyncgen(self.iterator): raise StopAsyncIteration value = await self.iterator.__anext__() - result = self.callback(value) - else: aclose = ensure_future(self._close_event.wait()) anext = ensure_future(self.iterator.__anext__()) @@ -61,7 +59,7 @@ async def __anext__(self) -> Any: raise error value = anext.result() - result = self.callback(value) + result = self.callback(value) return await result if isawaitable(result) else result @@ -72,23 +70,24 @@ async def athrow( traceback: Optional[TracebackType] = None, ) -> None: """Throw an exception into the asynchronous iterator.""" - if not self.is_closed: - athrow = getattr(self.iterator, "athrow", None) - if athrow: - await athrow(type_, value, traceback) - else: - await self.aclose() - if value is None: - if traceback is None: - raise type_ - value = ( - type_ - if isinstance(value, BaseException) - else cast(Type[BaseException], type_)() - ) - if traceback is not None: - value = value.with_traceback(traceback) - raise value + if self.is_closed: + return + athrow = getattr(self.iterator, "athrow", None) + if athrow: + await athrow(type_, value, traceback) + else: + await self.aclose() + if value is None: + if traceback is None: + raise type_ + value = ( + type_ + if isinstance(value, BaseException) + else cast(Type[BaseException], type_)() + ) + if traceback is not None: + value = value.with_traceback(traceback) + raise value async def aclose(self) -> None: """Close the iterator.""" diff --git a/src/graphql/utilities/print_schema.py b/src/graphql/utilities/print_schema.py index d9c54919..685ab65a 100644 --- a/src/graphql/utilities/print_schema.py +++ b/src/graphql/utilities/print_schema.py @@ -108,10 +108,7 @@ def is_schema_of_common_names(schema: GraphQLSchema) -> bool: return False subscription_type = schema.subscription_type - if subscription_type and subscription_type.name != "Subscription": - return False - - return True + return not subscription_type or subscription_type.name == "Subscription" def print_type(type_: GraphQLNamedType) -> str: diff --git a/src/graphql/utilities/strip_ignored_characters.py b/src/graphql/utilities/strip_ignored_characters.py index af69dd3a..719dab6a 100644 --- a/src/graphql/utilities/strip_ignored_characters.py +++ b/src/graphql/utilities/strip_ignored_characters.py @@ -31,40 +31,7 @@ def strip_ignored_characters(source: Union[str, Source]) -> str: Warning: It is guaranteed that this function will always produce stable results. However, it's not guaranteed that it will stay the same between different releases due to bugfixes or changes in the GraphQL specification. - """ ''' - - Query example:: - - query SomeQuery($foo: String!, $bar: String) { - someField(foo: $foo, bar: $bar) { - a - b { - c - d - } - } - } - - Becomes:: - - query SomeQuery($foo:String!$bar:String){someField(foo:$foo bar:$bar){a b{c d}}} - - SDL example:: - - """ - Type description - """ - type Foo { - """ - Field description - """ - bar: String - } - - Becomes:: - - """Type description""" type Foo{"""Field description""" bar:String} - ''' + """ source = cast(Source, source) if is_source(source) else Source(cast(str, source)) body = source.body @@ -79,9 +46,10 @@ def strip_ignored_characters(source: Union[str, Source]) -> str: # Also prevent case of non-punctuator token following by spread resulting # in invalid token (e.g.`1...` is invalid Float token). is_non_punctuator = not is_punctuator_token_kind(current_token.kind) - if was_last_added_token_non_punctuator: - if is_non_punctuator or current_token.kind == TokenKind.SPREAD: - stripped_body += " " + if was_last_added_token_non_punctuator and ( + is_non_punctuator or current_token.kind == TokenKind.SPREAD + ): + stripped_body += " " token_body = body[current_token.start : current_token.end] if token_kind == TokenKind.BLOCK_STRING: diff --git a/src/graphql/validation/rules/lone_anonymous_operation.py b/src/graphql/validation/rules/lone_anonymous_operation.py index 3e945f8b..f14d5f5a 100644 --- a/src/graphql/validation/rules/lone_anonymous_operation.py +++ b/src/graphql/validation/rules/lone_anonymous_operation.py @@ -20,9 +20,8 @@ def __init__(self, context: ASTValidationContext): def enter_document(self, node: DocumentNode, *_args: Any) -> None: self.operation_count = sum( - 1 + isinstance(definition, OperationDefinitionNode) for definition in node.definitions - if isinstance(definition, OperationDefinitionNode) ) def enter_operation_definition( diff --git a/src/graphql/validation/rules/single_field_subscriptions.py b/src/graphql/validation/rules/single_field_subscriptions.py index 86b27122..d5e61178 100644 --- a/src/graphql/validation/rules/single_field_subscriptions.py +++ b/src/graphql/validation/rules/single_field_subscriptions.py @@ -23,39 +23,55 @@ class SingleFieldSubscriptionsRule(ValidationRule): def enter_operation_definition( self, node: OperationDefinitionNode, *_args: Any ) -> None: - if node.operation == OperationType.SUBSCRIPTION: - schema = self.context.schema - subscription_type = schema.subscription_type - if subscription_type: - operation_name = node.name.value if node.name else None - variable_values: Dict[str, Any] = {} - document = self.context.document - fragments: Dict[str, FragmentDefinitionNode] = { - definition.name.value: definition - for definition in document.definitions - if isinstance(definition, FragmentDefinitionNode) - } - fields = collect_fields( - schema, - fragments, - variable_values, - subscription_type, - node.selection_set, - {}, - set(), - ) - if len(fields) > 1: - field_selection_lists = list(fields.values()) - extra_field_selection_lists = field_selection_lists[1:] - extra_field_selection = [ - field - for fields in extra_field_selection_lists - for field in ( - fields - if isinstance(fields, list) - else [cast(FieldNode, fields)] + if node.operation != OperationType.SUBSCRIPTION: + return + schema = self.context.schema + subscription_type = schema.subscription_type + if subscription_type: + operation_name = node.name.value if node.name else None + variable_values: Dict[str, Any] = {} + document = self.context.document + fragments: Dict[str, FragmentDefinitionNode] = { + definition.name.value: definition + for definition in document.definitions + if isinstance(definition, FragmentDefinitionNode) + } + fields = collect_fields( + schema, + fragments, + variable_values, + subscription_type, + node.selection_set, + {}, + set(), + ) + if len(fields) > 1: + field_selection_lists = list(fields.values()) + extra_field_selection_lists = field_selection_lists[1:] + extra_field_selection = [ + field + for fields in extra_field_selection_lists + for field in ( + fields + if isinstance(fields, list) + else [cast(FieldNode, fields)] + ) + ] + self.report_error( + GraphQLError( + ( + "Anonymous Subscription" + if operation_name is None + else f"Subscription '{operation_name}'" ) - ] + + " must select only one top level field.", + extra_field_selection, + ) + ) + for field_nodes in fields.values(): + field = field_nodes[0] + field_name = field.name.value + if field_name.startswith("__"): self.report_error( GraphQLError( ( @@ -63,22 +79,7 @@ def enter_operation_definition( if operation_name is None else f"Subscription '{operation_name}'" ) - + " must select only one top level field.", - extra_field_selection, + + " must not select an introspection top level field.", + field_nodes, ) ) - for field_nodes in fields.values(): - field = field_nodes[0] - field_name = field.name.value - if field_name.startswith("__"): - self.report_error( - GraphQLError( - ( - "Anonymous Subscription" - if operation_name is None - else f"Subscription '{operation_name}'" - ) - + " must not select an introspection top level field.", - field_nodes, - ) - ) diff --git a/src/graphql/validation/validation_context.py b/src/graphql/validation/validation_context.py index 8ed24ef6..646ab5f3 100644 --- a/src/graphql/validation/validation_context.py +++ b/src/graphql/validation/validation_context.py @@ -97,10 +97,12 @@ def report_error(self, error: GraphQLError) -> None: def get_fragment(self, name: str) -> Optional[FragmentDefinitionNode]: fragments = self._fragments if fragments is None: - fragments = {} - for statement in self.document.definitions: - if isinstance(statement, FragmentDefinitionNode): - fragments[statement.name.value] = statement + fragments = { + statement.name.value: statement + for statement in self.document.definitions + if isinstance(statement, FragmentDefinitionNode) + } + self._fragments = fragments return fragments.get(name) From 13b37fbdc011bb25fb6d8e09d1f79c50bc844a94 Mon Sep 17 00:00:00 2001 From: Yasser Tahiri Date: Sun, 7 Nov 2021 02:01:10 +0100 Subject: [PATCH 2/3] Revert some minor changes --- src/graphql/language/block_string.py | 2 +- .../utilities/strip_ignored_characters.py | 35 ++++++++++++++++++- 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/src/graphql/language/block_string.py b/src/graphql/language/block_string.py index a8da18d9..3f07479e 100644 --- a/src/graphql/language/block_string.py +++ b/src/graphql/language/block_string.py @@ -38,7 +38,7 @@ def dedent_block_string_value(raw_string: str) -> str: def is_blank(s: str) -> bool: """Check whether string contains only space or tab characters.""" - return all(c in [" ", "\t"] for c in s) + return all(c == " " or c == "\t" for c in s) def get_block_string_indentation(value: str) -> int: diff --git a/src/graphql/utilities/strip_ignored_characters.py b/src/graphql/utilities/strip_ignored_characters.py index 719dab6a..860893b2 100644 --- a/src/graphql/utilities/strip_ignored_characters.py +++ b/src/graphql/utilities/strip_ignored_characters.py @@ -31,7 +31,40 @@ def strip_ignored_characters(source: Union[str, Source]) -> str: Warning: It is guaranteed that this function will always produce stable results. However, it's not guaranteed that it will stay the same between different releases due to bugfixes or changes in the GraphQL specification. - """ + """ ''' + + Query example:: + + query SomeQuery($foo: String!, $bar: String) { + someField(foo: $foo, bar: $bar) { + a + b { + c + d + } + } + } + + Becomes:: + + query SomeQuery($foo:String!$bar:String){someField(foo:$foo bar:$bar){a b{c d}}} + + SDL example:: + + """ + Type description + """ + type Foo { + """ + Field description + """ + bar: String + } + + Becomes:: + + """Type description""" type Foo{"""Field description""" bar:String} + ''' source = cast(Source, source) if is_source(source) else Source(cast(str, source)) body = source.body From 35aa9b261c850aa5f0c335c2405956fd41ed5ca2 Mon Sep 17 00:00:00 2001 From: Yasser Tahiri Date: Sun, 7 Nov 2021 02:02:58 +0100 Subject: [PATCH 3/3] add new line --- src/graphql/subscription/map_async_iterator.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/graphql/subscription/map_async_iterator.py b/src/graphql/subscription/map_async_iterator.py index a86aea11..43400fd3 100644 --- a/src/graphql/subscription/map_async_iterator.py +++ b/src/graphql/subscription/map_async_iterator.py @@ -59,6 +59,7 @@ async def __anext__(self) -> Any: raise error value = anext.result() + result = self.callback(value) return await result if isawaitable(result) else result