Skip to content

validation.validate slowness #117

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions src/graphql/language/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,15 +221,17 @@ class Node:
"""AST nodes"""

# allow custom attributes and weak references (not used internally)
__slots__ = "__dict__", "__weakref__", "loc"
__slots__ = "__dict__", "__weakref__", "loc", "_hash"

loc: Optional[Location]

kind: str = "ast" # the kind of the node as a snake_case string
keys = ["loc"] # the names of the attributes of this node


def __init__(self, **kwargs: Any) -> None:
"""Initialize the node with the given keyword arguments."""
self._hash = None
for key in self.keys:
value = kwargs.get(key)
if isinstance(value, list) and not isinstance(value, FrozenList):
Expand All @@ -250,7 +252,10 @@ def __eq__(self, other: Any) -> bool:
)

def __hash__(self) -> int:
return hash(tuple(getattr(self, key) for key in self.keys))
if self._hash is None:
self._hash = hash(tuple(getattr(self, key) for key in self.keys))

return self._hash

def __copy__(self) -> "Node":
"""Create a shallow copy of the node."""
Expand Down
31 changes: 24 additions & 7 deletions src/graphql/language/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,9 @@ def leave(self, node, key, parent, path, ancestors):
# Provide special return values as attributes
BREAK, SKIP, REMOVE, IDLE = BREAK, SKIP, REMOVE, IDLE

def __init__(self):
self._visit_fns = {}

def __init_subclass__(cls) -> None:
"""Verify that all defined handlers are valid."""
super().__init_subclass__()
Expand All @@ -197,11 +200,12 @@ def __init_subclass__(cls) -> None:

def get_visit_fn(self, kind: str, is_leaving: bool = False) -> Callable:
"""Get the visit function for the given node kind and direction."""
method = "leave" if is_leaving else "enter"
visit_fn = getattr(self, f"{method}_{kind}", None)
if not visit_fn:
visit_fn = getattr(self, method, None)
return visit_fn
key = (kind, is_leaving)
if key not in self._visit_fns:
method = "leave" if is_leaving else "enter"
fn = getattr(self, f"{method}_{kind}", None)
self._visit_fns[key] = fn or getattr(self, method, None)
return self._visit_fns[key]


class Stack(NamedTuple):
Expand Down Expand Up @@ -367,14 +371,22 @@ class ParallelVisitor(Visitor):

def __init__(self, visitors: Collection[Visitor]):
"""Create a new visitor from the given list of parallel visitors."""
super().__init__()
self.visitors = visitors
self.skipping: List[Any] = [None] * len(visitors)
self._enter_visit_fns = {}
self._leave_visit_fns = {}

def enter(self, node: Node, *args: Any) -> Optional[VisitorAction]:
visit_fns = self._enter_visit_fns.get(node.kind)
if visit_fns is None:
visit_fns = [v.get_visit_fn(node.kind) for v in self.visitors]
self._enter_visit_fns[node.kind] = visit_fns

skipping = self.skipping
for i, visitor in enumerate(self.visitors):
if not skipping[i]:
fn = visitor.get_visit_fn(node.kind)
fn = visit_fns[i]
if fn:
result = fn(node, *args)
if result is SKIP or result is False:
Expand All @@ -386,10 +398,15 @@ def enter(self, node: Node, *args: Any) -> Optional[VisitorAction]:
return None

def leave(self, node: Node, *args: Any) -> Optional[VisitorAction]:
visit_fns = self._leave_visit_fns.get(node.kind)
if visit_fns is None:
visit_fns = [v.get_visit_fn(node.kind, is_leaving=True) for v in self.visitors]
self._leave_visit_fns[node.kind] = visit_fns

skipping = self.skipping
for i, visitor in enumerate(self.visitors):
if not skipping[i]:
fn = visitor.get_visit_fn(node.kind, is_leaving=True)
fn = visit_fns[i]
if fn:
result = fn(node, *args)
if result is BREAK or result is True:
Expand Down
14 changes: 12 additions & 2 deletions src/graphql/utilities/type_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def __init__(
self._argument: Optional[GraphQLArgument] = None
self._enum_value: Optional[GraphQLEnumValue] = None
self._get_field_def = get_field_def_fn or get_field_def
self._visit_fns = {}
if initial_type:
if is_input_type(initial_type):
self._input_type_stack.append(cast(GraphQLInputType, initial_type))
Expand Down Expand Up @@ -136,15 +137,23 @@ def get_enum_value(self) -> Optional[GraphQLEnumValue]:
return self._enum_value

def enter(self, node: Node) -> None:
method = getattr(self, "enter_" + node.kind, None)
method = self._get_method("enter", node.kind)
if method:
method(node)

def leave(self, node: Node) -> None:
method = getattr(self, "leave_" + node.kind, None)
method = self._get_method("leave", node.kind)
if method:
method()

def _get_method(self, direction: str, kind: str) -> Optional[Callable[[], None]]:
key = (direction, kind)
if key not in self._visit_fns:
fn = getattr(self, f"{direction}_{kind}", None)
self._visit_fns[key] = fn
return self._visit_fns[key]


# noinspection PyUnusedLocal
def enter_selection_set(self, node: SelectionSetNode) -> None:
named_type = get_named_type(self.get_type())
Expand Down Expand Up @@ -301,6 +310,7 @@ class TypeInfoVisitor(Visitor):
"""A visitor which maintains a provided TypeInfo."""

def __init__(self, type_info: "TypeInfo", visitor: Visitor):
super().__init__()
self.type_info = type_info
self.visitor = visitor

Expand Down
1 change: 1 addition & 0 deletions src/graphql/validation/rules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class ASTValidationRule(Visitor):
context: ASTValidationContext

def __init__(self, context: ASTValidationContext):
super().__init__()
self.context = context

def report_error(self, error: GraphQLError) -> None:
Expand Down
1 change: 1 addition & 0 deletions src/graphql/validation/validation_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class VariableUsageVisitor(Visitor):
usages: List[VariableUsage]

def __init__(self, type_info: TypeInfo):
super().__init__()
self.usages = []
self._append_usage = self.usages.append
self._type_info = type_info
Expand Down