Skip to content

chore: Refactor MultiCode Expression ✨ #146

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Nov 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 13 additions & 18 deletions src/graphql/language/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,16 +183,15 @@ def __init_subclass__(cls) -> None:
kind: Optional[str] = None
else:
attr, kind = attr_kind
if attr in ("enter", "leave"):
if kind:
name = snake_to_camel(kind) + "Node"
node_cls = getattr(ast, name, None)
if (
not node_cls
or not isinstance(node_cls, type)
or not issubclass(node_cls, Node)
):
raise TypeError(f"Invalid AST node kind: {kind}.")
if attr in ("enter", "leave") and kind:
name = snake_to_camel(kind) + "Node"
node_cls = getattr(ast, name, None)
if (
not node_cls
or not isinstance(node_cls, type)
or not issubclass(node_cls, Node)
):
raise TypeError(f"Invalid AST node kind: {kind}.")

def get_visit_fn(self, kind: str, is_leaving: bool = False) -> Callable:
"""Get the visit function for the given node kind and direction."""
Expand Down Expand Up @@ -256,22 +255,18 @@ def visit(root: Node, visitor: Visitor) -> Any:
node: Any = parent
parent = ancestors_pop() if ancestors else None
if is_edited:
if in_array:
node = node[:]
else:
node = copy(node)
node = node[:] if in_array else copy(node)
edit_offset = 0
for edit_key, edit_value in edits:
if in_array:
edit_key -= edit_offset
if in_array and (edit_value is REMOVE or edit_value is Ellipsis):
node.pop(edit_key)
edit_offset += 1
elif isinstance(node, list):
node[edit_key] = edit_value
else:
if isinstance(node, list):
node[edit_key] = edit_value
else:
setattr(node, edit_key, edit_value)
setattr(node, edit_key, edit_value)

idx = stack.idx
keys = stack.keys
Expand Down
2 changes: 1 addition & 1 deletion src/graphql/pyutils/suggestion_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def measure(self, option: str, threshold: int) -> Optional[int]:
return None

rows = self._rows
for j in range(0, b_len + 1):
for j in range(b_len + 1):
rows[0][j] = j

for i in range(1, a_len + 1):
Expand Down
40 changes: 20 additions & 20 deletions src/graphql/subscription/map_async_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@ async def __anext__(self) -> Any:
if not isasyncgen(self.iterator):
raise StopAsyncIteration
value = await self.iterator.__anext__()
result = self.callback(value)

else:
aclose = ensure_future(self._close_event.wait())
anext = ensure_future(self.iterator.__anext__())
Expand All @@ -61,7 +59,8 @@ async def __anext__(self) -> Any:
raise error

value = anext.result()
result = self.callback(value)

result = self.callback(value)

return await result if isawaitable(result) else result

Expand All @@ -72,23 +71,24 @@ async def athrow(
traceback: Optional[TracebackType] = None,
) -> None:
"""Throw an exception into the asynchronous iterator."""
if not self.is_closed:
athrow = getattr(self.iterator, "athrow", None)
if athrow:
await athrow(type_, value, traceback)
else:
await self.aclose()
if value is None:
if traceback is None:
raise type_
value = (
type_
if isinstance(value, BaseException)
else cast(Type[BaseException], type_)()
)
if traceback is not None:
value = value.with_traceback(traceback)
raise value
if self.is_closed:
return
athrow = getattr(self.iterator, "athrow", None)
if athrow:
await athrow(type_, value, traceback)
else:
await self.aclose()
if value is None:
if traceback is None:
raise type_
value = (
type_
if isinstance(value, BaseException)
else cast(Type[BaseException], type_)()
)
if traceback is not None:
value = value.with_traceback(traceback)
raise value

async def aclose(self) -> None:
"""Close the iterator."""
Expand Down
5 changes: 1 addition & 4 deletions src/graphql/utilities/print_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,7 @@ def is_schema_of_common_names(schema: GraphQLSchema) -> bool:
return False

subscription_type = schema.subscription_type
if subscription_type and subscription_type.name != "Subscription":
return False

return True
return not subscription_type or subscription_type.name == "Subscription"


def print_type(type_: GraphQLNamedType) -> str:
Expand Down
7 changes: 4 additions & 3 deletions src/graphql/utilities/strip_ignored_characters.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,10 @@ def strip_ignored_characters(source: Union[str, Source]) -> str:
# Also prevent case of non-punctuator token following by spread resulting
# in invalid token (e.g.`1...` is invalid Float token).
is_non_punctuator = not is_punctuator_token_kind(current_token.kind)
if was_last_added_token_non_punctuator:
if is_non_punctuator or current_token.kind == TokenKind.SPREAD:
stripped_body += " "
if was_last_added_token_non_punctuator and (
is_non_punctuator or current_token.kind == TokenKind.SPREAD
):
stripped_body += " "

token_body = body[current_token.start : current_token.end]
if token_kind == TokenKind.BLOCK_STRING:
Expand Down
3 changes: 1 addition & 2 deletions src/graphql/validation/rules/lone_anonymous_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,8 @@ def __init__(self, context: ASTValidationContext):

def enter_document(self, node: DocumentNode, *_args: Any) -> None:
self.operation_count = sum(
1
isinstance(definition, OperationDefinitionNode)
for definition in node.definitions
if isinstance(definition, OperationDefinitionNode)
)

def enter_operation_definition(
Expand Down
99 changes: 50 additions & 49 deletions src/graphql/validation/rules/single_field_subscriptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,62 +23,63 @@ class SingleFieldSubscriptionsRule(ValidationRule):
def enter_operation_definition(
self, node: OperationDefinitionNode, *_args: Any
) -> None:
if node.operation == OperationType.SUBSCRIPTION:
schema = self.context.schema
subscription_type = schema.subscription_type
if subscription_type:
operation_name = node.name.value if node.name else None
variable_values: Dict[str, Any] = {}
document = self.context.document
fragments: Dict[str, FragmentDefinitionNode] = {
definition.name.value: definition
for definition in document.definitions
if isinstance(definition, FragmentDefinitionNode)
}
fields = collect_fields(
schema,
fragments,
variable_values,
subscription_type,
node.selection_set,
{},
set(),
)
if len(fields) > 1:
field_selection_lists = list(fields.values())
extra_field_selection_lists = field_selection_lists[1:]
extra_field_selection = [
field
for fields in extra_field_selection_lists
for field in (
fields
if isinstance(fields, list)
else [cast(FieldNode, fields)]
if node.operation != OperationType.SUBSCRIPTION:
return
schema = self.context.schema
subscription_type = schema.subscription_type
if subscription_type:
operation_name = node.name.value if node.name else None
variable_values: Dict[str, Any] = {}
document = self.context.document
fragments: Dict[str, FragmentDefinitionNode] = {
definition.name.value: definition
for definition in document.definitions
if isinstance(definition, FragmentDefinitionNode)
}
fields = collect_fields(
schema,
fragments,
variable_values,
subscription_type,
node.selection_set,
{},
set(),
)
if len(fields) > 1:
field_selection_lists = list(fields.values())
extra_field_selection_lists = field_selection_lists[1:]
extra_field_selection = [
field
for fields in extra_field_selection_lists
for field in (
fields
if isinstance(fields, list)
else [cast(FieldNode, fields)]
)
]
self.report_error(
GraphQLError(
(
"Anonymous Subscription"
if operation_name is None
else f"Subscription '{operation_name}'"
)
]
+ " must select only one top level field.",
extra_field_selection,
)
)
for field_nodes in fields.values():
field = field_nodes[0]
field_name = field.name.value
if field_name.startswith("__"):
self.report_error(
GraphQLError(
(
"Anonymous Subscription"
if operation_name is None
else f"Subscription '{operation_name}'"
)
+ " must select only one top level field.",
extra_field_selection,
+ " must not select an introspection top level field.",
field_nodes,
)
)
for field_nodes in fields.values():
field = field_nodes[0]
field_name = field.name.value
if field_name.startswith("__"):
self.report_error(
GraphQLError(
(
"Anonymous Subscription"
if operation_name is None
else f"Subscription '{operation_name}'"
)
+ " must not select an introspection top level field.",
field_nodes,
)
)
10 changes: 6 additions & 4 deletions src/graphql/validation/validation_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,12 @@ def report_error(self, error: GraphQLError) -> None:
def get_fragment(self, name: str) -> Optional[FragmentDefinitionNode]:
fragments = self._fragments
if fragments is None:
fragments = {}
for statement in self.document.definitions:
if isinstance(statement, FragmentDefinitionNode):
fragments[statement.name.value] = statement
fragments = {
statement.name.value: statement
for statement in self.document.definitions
if isinstance(statement, FragmentDefinitionNode)
}

self._fragments = fragments
return fragments.get(name)

Expand Down