From 24bcd0c583e1e0adb1e91d8f661f0498e6bec77a Mon Sep 17 00:00:00 2001 From: Alex Chamberlain Date: Mon, 9 Nov 2020 17:41:48 +0000 Subject: [PATCH] Improve validation performance Improve the validation performance by memoising the hashes and visitor function lookups. This improves the `test_validate_introspection_query` benchmark --- src/graphql/language/ast.py | 9 ++++-- src/graphql/language/visitor.py | 31 +++++++++++++++----- src/graphql/utilities/type_info.py | 14 +++++++-- src/graphql/validation/rules/__init__.py | 1 + src/graphql/validation/validation_context.py | 1 + 5 files changed, 45 insertions(+), 11 deletions(-) diff --git a/src/graphql/language/ast.py b/src/graphql/language/ast.py index 188cfe6a..a17e0b19 100644 --- a/src/graphql/language/ast.py +++ b/src/graphql/language/ast.py @@ -221,15 +221,17 @@ class Node: """AST nodes""" # allow custom attributes and weak references (not used internally) - __slots__ = "__dict__", "__weakref__", "loc" + __slots__ = "__dict__", "__weakref__", "loc", "_hash" loc: Optional[Location] kind: str = "ast" # the kind of the node as a snake_case string keys = ["loc"] # the names of the attributes of this node + def __init__(self, **kwargs: Any) -> None: """Initialize the node with the given keyword arguments.""" + self._hash = None for key in self.keys: value = kwargs.get(key) if isinstance(value, list) and not isinstance(value, FrozenList): @@ -250,7 +252,10 @@ def __eq__(self, other: Any) -> bool: ) def __hash__(self) -> int: - return hash(tuple(getattr(self, key) for key in self.keys)) + if self._hash is None: + self._hash = hash(tuple(getattr(self, key) for key in self.keys)) + + return self._hash def __copy__(self) -> "Node": """Create a shallow copy of the node.""" diff --git a/src/graphql/language/visitor.py b/src/graphql/language/visitor.py index 2378e27b..b715e02e 100644 --- a/src/graphql/language/visitor.py +++ b/src/graphql/language/visitor.py @@ -173,6 +173,9 @@ def leave(self, node, key, parent, path, ancestors): # Provide special return values as attributes BREAK, SKIP, REMOVE, IDLE = BREAK, SKIP, REMOVE, IDLE + def __init__(self): + self._visit_fns = {} + def __init_subclass__(cls) -> None: """Verify that all defined handlers are valid.""" super().__init_subclass__() @@ -197,11 +200,12 @@ def __init_subclass__(cls) -> None: def get_visit_fn(self, kind: str, is_leaving: bool = False) -> Callable: """Get the visit function for the given node kind and direction.""" - method = "leave" if is_leaving else "enter" - visit_fn = getattr(self, f"{method}_{kind}", None) - if not visit_fn: - visit_fn = getattr(self, method, None) - return visit_fn + key = (kind, is_leaving) + if key not in self._visit_fns: + method = "leave" if is_leaving else "enter" + fn = getattr(self, f"{method}_{kind}", None) + self._visit_fns[key] = fn or getattr(self, method, None) + return self._visit_fns[key] class Stack(NamedTuple): @@ -367,14 +371,22 @@ class ParallelVisitor(Visitor): def __init__(self, visitors: Collection[Visitor]): """Create a new visitor from the given list of parallel visitors.""" + super().__init__() self.visitors = visitors self.skipping: List[Any] = [None] * len(visitors) + self._enter_visit_fns = {} + self._leave_visit_fns = {} def enter(self, node: Node, *args: Any) -> Optional[VisitorAction]: + visit_fns = self._enter_visit_fns.get(node.kind) + if visit_fns is None: + visit_fns = [v.get_visit_fn(node.kind) for v in self.visitors] + self._enter_visit_fns[node.kind] = visit_fns + skipping = self.skipping for i, visitor in enumerate(self.visitors): if not skipping[i]: - fn = visitor.get_visit_fn(node.kind) + fn = visit_fns[i] if fn: result = fn(node, *args) if result is SKIP or result is False: @@ -386,10 +398,15 @@ def enter(self, node: Node, *args: Any) -> Optional[VisitorAction]: return None def leave(self, node: Node, *args: Any) -> Optional[VisitorAction]: + visit_fns = self._leave_visit_fns.get(node.kind) + if visit_fns is None: + visit_fns = [v.get_visit_fn(node.kind, is_leaving=True) for v in self.visitors] + self._leave_visit_fns[node.kind] = visit_fns + skipping = self.skipping for i, visitor in enumerate(self.visitors): if not skipping[i]: - fn = visitor.get_visit_fn(node.kind, is_leaving=True) + fn = visit_fns[i] if fn: result = fn(node, *args) if result is BREAK or result is True: diff --git a/src/graphql/utilities/type_info.py b/src/graphql/utilities/type_info.py index 847e573d..293fbf26 100644 --- a/src/graphql/utilities/type_info.py +++ b/src/graphql/utilities/type_info.py @@ -88,6 +88,7 @@ def __init__( self._argument: Optional[GraphQLArgument] = None self._enum_value: Optional[GraphQLEnumValue] = None self._get_field_def = get_field_def_fn or get_field_def + self._visit_fns = {} if initial_type: if is_input_type(initial_type): self._input_type_stack.append(cast(GraphQLInputType, initial_type)) @@ -136,15 +137,23 @@ def get_enum_value(self) -> Optional[GraphQLEnumValue]: return self._enum_value def enter(self, node: Node) -> None: - method = getattr(self, "enter_" + node.kind, None) + method = self._get_method("enter", node.kind) if method: method(node) def leave(self, node: Node) -> None: - method = getattr(self, "leave_" + node.kind, None) + method = self._get_method("leave", node.kind) if method: method() + def _get_method(self, direction: str, kind: str) -> Optional[Callable[[], None]]: + key = (direction, kind) + if key not in self._visit_fns: + fn = getattr(self, f"{direction}_{kind}", None) + self._visit_fns[key] = fn + return self._visit_fns[key] + + # noinspection PyUnusedLocal def enter_selection_set(self, node: SelectionSetNode) -> None: named_type = get_named_type(self.get_type()) @@ -301,6 +310,7 @@ class TypeInfoVisitor(Visitor): """A visitor which maintains a provided TypeInfo.""" def __init__(self, type_info: "TypeInfo", visitor: Visitor): + super().__init__() self.type_info = type_info self.visitor = visitor diff --git a/src/graphql/validation/rules/__init__.py b/src/graphql/validation/rules/__init__.py index 54158742..1b0c5d57 100644 --- a/src/graphql/validation/rules/__init__.py +++ b/src/graphql/validation/rules/__init__.py @@ -17,6 +17,7 @@ class ASTValidationRule(Visitor): context: ASTValidationContext def __init__(self, context: ASTValidationContext): + super().__init__() self.context = context def report_error(self, error: GraphQLError) -> None: diff --git a/src/graphql/validation/validation_context.py b/src/graphql/validation/validation_context.py index 1db7b45a..ae690c3c 100644 --- a/src/graphql/validation/validation_context.py +++ b/src/graphql/validation/validation_context.py @@ -46,6 +46,7 @@ class VariableUsageVisitor(Visitor): usages: List[VariableUsage] def __init__(self, type_info: TypeInfo): + super().__init__() self.usages = [] self._append_usage = self.usages.append self._type_info = type_info