Skip to content

Commit ac3caf6

Browse files
committed
visitor: Speed up ParallelVisitor
Contrary to GraphQL.js, we still support custom node types. However, we use a similar caching mechanism for the visit functions to speed up the visit process, particularly for the ParallelVisitor. Replicates graphql/graphql-js@f4efee9
1 parent 6099c96 commit ac3caf6

File tree

9 files changed

+177
-46
lines changed

9 files changed

+177
-46
lines changed

docs/conf.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,9 +133,11 @@
133133
enum.Enum
134134
traceback
135135
types.TracebackType
136+
EnterLeaveVisitor
136137
FormattedSourceLocation
137138
asyncio.events.AbstractEventLoop
138139
graphql.language.lexer.EscapeSequence
140+
graphql.language.visitor.EnterLeaveVisitor
139141
graphql.subscription.map_async_iterator.MapAsyncIterator
140142
graphql.type.schema.InterfaceImplementations
141143
graphql.validation.validation_context.VariableUsage

src/graphql/language/visitor.py

Lines changed: 97 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,13 @@ class VisitorActionEnum(Enum):
126126
}
127127

128128

129+
class EnterLeaveVisitor(NamedTuple):
130+
"""Visitor with functions for entering and leaving."""
131+
132+
enter: Optional[Callable[..., Optional[VisitorAction]]]
133+
leave: Optional[Callable[..., Optional[VisitorAction]]]
134+
135+
129136
class Visitor:
130137
"""Visitor that walks through an AST.
131138
@@ -170,6 +177,8 @@ def leave(self, node, key, parent, path, ancestors):
170177
# Provide special return values as attributes
171178
BREAK, SKIP, REMOVE, IDLE = BREAK, SKIP, REMOVE, IDLE
172179

180+
enter_leave_map: Dict[str, EnterLeaveVisitor]
181+
173182
def __init_subclass__(cls) -> None:
174183
"""Verify that all defined handlers are valid."""
175184
super().__init_subclass__()
@@ -191,13 +200,34 @@ def __init_subclass__(cls) -> None:
191200
):
192201
raise TypeError(f"Invalid AST node kind: {kind}.")
193202

194-
def get_visit_fn(self, kind: str, is_leaving: bool = False) -> Callable:
195-
"""Get the visit function for the given node kind and direction."""
196-
method = "leave" if is_leaving else "enter"
197-
visit_fn = getattr(self, f"{method}_{kind}", None)
198-
if not visit_fn:
199-
visit_fn = getattr(self, method, None)
200-
return visit_fn
203+
def __init__(self) -> None:
204+
self.enter_leave_map = {}
205+
206+
def get_enter_leave_for_kind(self, kind: str) -> EnterLeaveVisitor:
207+
"""Given a node kind, return the EnterLeaveVisitor for that kind."""
208+
try:
209+
return self.enter_leave_map[kind]
210+
except KeyError:
211+
enter_fn = getattr(self, f"enter_{kind}", None)
212+
if not enter_fn:
213+
enter_fn = getattr(self, "enter", None)
214+
leave_fn = getattr(self, f"leave_{kind}", None)
215+
if not leave_fn:
216+
leave_fn = getattr(self, "leave", None)
217+
enter_leave = EnterLeaveVisitor(enter_fn, leave_fn)
218+
self.enter_leave_map[kind] = enter_leave
219+
return enter_leave
220+
221+
def get_visit_fn(
222+
self, kind: str, is_leaving: bool = False
223+
) -> Optional[Callable[..., Optional[VisitorAction]]]:
224+
"""Get the visit function for the given node kind and direction.
225+
226+
.. deprecated:: 3.2
227+
Please use ``get_enter_leave_for_kind`` instead. Will be removed in v3.3.
228+
"""
229+
enter_leave = self.get_enter_leave_for_kind(kind)
230+
return enter_leave.leave if is_leaving else enter_leave.enter
201231

202232

203233
class Stack(NamedTuple):
@@ -237,6 +267,7 @@ def visit(
237267
raise TypeError(f"Not an AST Visitor: {inspect(visitor)}.")
238268
if visitor_keys is None:
239269
visitor_keys = QUERY_DOCUMENT_KEYS
270+
240271
stack: Any = None
241272
in_array = isinstance(root, list)
242273
keys: Tuple[Node, ...] = (root,)
@@ -299,7 +330,8 @@ def visit(
299330
else:
300331
if not isinstance(node, Node):
301332
raise TypeError(f"Invalid AST Node: {inspect(node)}.")
302-
visit_fn = visitor.get_visit_fn(node.kind, is_leaving)
333+
enter_leave = visitor.get_enter_leave_for_kind(node.kind)
334+
visit_fn = enter_leave.leave if is_leaving else enter_leave.enter
303335
if visit_fn:
304336
result = visit_fn(node, key, parent, path, ancestors)
305337

@@ -357,39 +389,63 @@ class ParallelVisitor(Visitor):
357389

358390
def __init__(self, visitors: Collection[Visitor]):
359391
"""Create a new visitor from the given list of parallel visitors."""
392+
super().__init__()
360393
self.visitors = visitors
361394
self.skipping: List[Any] = [None] * len(visitors)
362395

363-
def enter(self, node: Node, *args: Any) -> Optional[VisitorAction]:
364-
skipping = self.skipping
365-
for i, visitor in enumerate(self.visitors):
366-
if not skipping[i]:
367-
fn = visitor.get_visit_fn(node.kind)
368-
if fn:
369-
result = fn(node, *args)
370-
if result is SKIP or result is False:
371-
skipping[i] = node
372-
elif result is BREAK or result is True:
373-
skipping[i] = BREAK
374-
elif result is not None:
375-
return result
376-
return None
377-
378-
def leave(self, node: Node, *args: Any) -> Optional[VisitorAction]:
379-
skipping = self.skipping
380-
for i, visitor in enumerate(self.visitors):
381-
if not skipping[i]:
382-
fn = visitor.get_visit_fn(node.kind, is_leaving=True)
383-
if fn:
384-
result = fn(node, *args)
385-
if result is BREAK or result is True:
386-
skipping[i] = BREAK
387-
elif (
388-
result is not None
389-
and result is not SKIP
390-
and result is not False
391-
):
392-
return result
393-
elif skipping[i] is node:
394-
skipping[i] = None
395-
return None
396+
def get_enter_leave_for_kind(self, kind: str) -> EnterLeaveVisitor:
397+
"""Given a node kind, return the EnterLeaveVisitor for that kind."""
398+
try:
399+
return self.enter_leave_map[kind]
400+
except KeyError:
401+
has_visitor = False
402+
enter_list: List[Optional[Callable[..., Optional[VisitorAction]]]] = []
403+
leave_list: List[Optional[Callable[..., Optional[VisitorAction]]]] = []
404+
for visitor in self.visitors:
405+
enter, leave = visitor.get_enter_leave_for_kind(kind)
406+
if not has_visitor and (enter or leave):
407+
has_visitor = True
408+
enter_list.append(enter)
409+
leave_list.append(leave)
410+
411+
if has_visitor:
412+
413+
def enter(node: Node, *args: Any) -> Optional[VisitorAction]:
414+
skipping = self.skipping
415+
for i, fn in enumerate(enter_list):
416+
if not skipping[i]:
417+
if fn:
418+
result = fn(node, *args)
419+
if result is SKIP or result is False:
420+
skipping[i] = node
421+
elif result is BREAK or result is True:
422+
skipping[i] = BREAK
423+
elif result is not None:
424+
return result
425+
return None
426+
427+
def leave(node: Node, *args: Any) -> Optional[VisitorAction]:
428+
skipping = self.skipping
429+
for i, fn in enumerate(leave_list):
430+
if not skipping[i]:
431+
if fn:
432+
result = fn(node, *args)
433+
if result is BREAK or result is True:
434+
skipping[i] = BREAK
435+
elif (
436+
result is not None
437+
and result is not SKIP
438+
and result is not False
439+
):
440+
return result
441+
elif skipping[i] is node:
442+
skipping[i] = None
443+
return None
444+
445+
else:
446+
447+
enter = leave = None
448+
449+
enter_leave = EnterLeaveVisitor(enter, leave)
450+
self.enter_leave_map[kind] = enter_leave
451+
return enter_leave

src/graphql/utilities/type_info.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -299,12 +299,13 @@ class TypeInfoVisitor(Visitor):
299299
"""A visitor which maintains a provided TypeInfo."""
300300

301301
def __init__(self, type_info: "TypeInfo", visitor: Visitor):
302+
super().__init__()
302303
self.type_info = type_info
303304
self.visitor = visitor
304305

305306
def enter(self, node: Node, *args: Any) -> Any:
306307
self.type_info.enter(node)
307-
fn = self.visitor.get_visit_fn(node.kind)
308+
fn = self.visitor.get_enter_leave_for_kind(node.kind).enter
308309
if fn:
309310
result = fn(node, *args)
310311
if result is not None:
@@ -314,7 +315,7 @@ def enter(self, node: Node, *args: Any) -> Any:
314315
return result
315316

316317
def leave(self, node: Node, *args: Any) -> Any:
317-
fn = self.visitor.get_visit_fn(node.kind, is_leaving=True)
318+
fn = self.visitor.get_enter_leave_for_kind(node.kind).leave
318319
result = fn(node, *args) if fn else None
319320
self.type_info.leave(node)
320321
return result

src/graphql/validation/rules/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ class ASTValidationRule(Visitor):
1717
context: ASTValidationContext
1818

1919
def __init__(self, context: ASTValidationContext):
20+
super().__init__()
2021
self.context = context
2122

2223
def report_error(self, error: GraphQLError) -> None:

src/graphql/validation/validation_context.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ class VariableUsageVisitor(Visitor):
4747
usages: List[VariableUsage]
4848

4949
def __init__(self, type_info: TypeInfo):
50+
super().__init__()
5051
self.usages = []
5152
self._append_usage = self.usages.append
5253
self._type_info = type_info

tests/benchmarks/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
Benchmarks are disabled (only executed as tests) by default in setup.cfg.
44
You can enable them with --benchmark-enable if your want to execute them.
55
6-
E.g. in order to execute all the benchmarks with tox using Python 3.7::
6+
E.g. in order to execute all the benchmarks with tox using Python 3.9::
77
8-
tox -e py37 -- -k benchmarks --benchmark-enable
8+
tox -e py39 -- -k benchmarks --benchmark-enable
99
"""

tests/benchmarks/test_visit.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from graphql import parse
2+
from graphql.language import visit, Visitor, ParallelVisitor
3+
4+
from ..fixtures import big_schema_sdl # noqa: F401
5+
6+
7+
class TestVisitor(Visitor):
8+
@staticmethod
9+
def enter(*args):
10+
pass
11+
12+
@staticmethod
13+
def leave(*args):
14+
pass
15+
16+
17+
def test_visit_all_ast_nodes(benchmark, big_schema_sdl): # noqa: F811
18+
document_ast = parse(big_schema_sdl)
19+
visitor = TestVisitor()
20+
benchmark(lambda: visit(document_ast, visitor))
21+
22+
23+
def test_visit_all_ast_nodes_in_parallel(benchmark, big_schema_sdl): # noqa: F811
24+
document_ast = parse(big_schema_sdl)
25+
visitor = TestVisitor()
26+
parallel_visitor = ParallelVisitor([visitor] * 50)
27+
benchmark(lambda: visit(document_ast, parallel_visitor))

tests/language/test_visitor.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,41 @@ def leave_field(node, *args):
214214
"leave:document",
215215
]
216216

217+
def has_get_enter_leave_for_kind_method():
218+
class TestVisitor(Visitor):
219+
@staticmethod
220+
def enter(*args):
221+
pass
222+
223+
@staticmethod
224+
def enter_document(*args):
225+
pass
226+
227+
@staticmethod
228+
def leave(*args):
229+
pass
230+
231+
@staticmethod
232+
def leave_document(*args):
233+
pass
234+
235+
visitor = TestVisitor()
236+
237+
assert visitor.get_enter_leave_for_kind("document") == (
238+
visitor.enter_document,
239+
visitor.leave_document,
240+
)
241+
assert visitor.get_enter_leave_for_kind("field") == (
242+
visitor.enter,
243+
visitor.leave,
244+
)
245+
246+
# also test deprecated method
247+
assert visitor.get_visit_fn("document") == visitor.enter_document
248+
assert visitor.get_visit_fn("field") == visitor.enter
249+
assert visitor.get_visit_fn("document", True) == visitor.leave_document
250+
assert visitor.get_visit_fn("field", True) == visitor.leave
251+
217252
def validates_path_argument():
218253
ast = parse("{ a }", no_location=True)
219254
visited = []
@@ -540,7 +575,10 @@ def leave_selection_set(*args):
540575
["leave", "selection_set", None],
541576
]
542577

543-
def visit_nodes_with_unknown_kinds_but_does_not_traverse_deeper():
578+
def visit_nodes_with_custom_kinds_but_does_not_traverse_deeper():
579+
# GraphQL.js removed support for unknown node types,
580+
# but it is easy for us to add and support custom node types,
581+
# so we keep allowing this and test this feature here.
544582
custom_ast = parse("{ a }")
545583

546584
class CustomFieldNode(SelectionNode):
@@ -1129,6 +1167,7 @@ def allows_skipping_different_sub_trees(skip_action):
11291167

11301168
class TestVisitor(Visitor):
11311169
def __init__(self, name):
1170+
super().__init__()
11321171
self.name = name
11331172

11341173
def enter(self, *args):
@@ -1232,6 +1271,7 @@ def allows_early_exit_from_different_points(break_action):
12321271

12331272
class TestVisitor(Visitor):
12341273
def __init__(self, name):
1274+
super().__init__()
12351275
self.name = name
12361276

12371277
def enter(self, *args):
@@ -1323,6 +1363,7 @@ def allows_early_exit_from_leaving_different_points(break_action):
13231363

13241364
class TestVisitor(Visitor):
13251365
def __init__(self, name):
1366+
super().__init__()
13261367
self.name = name
13271368

13281369
def enter(self, *args):

tests/utilities/test_type_info.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ def supports_different_operation_types():
101101

102102
class TestVisitor(Visitor):
103103
def __init__(self):
104+
super().__init__()
104105
self.root_types = {}
105106

106107
def enter_operation_definition(self, node: OperationDefinitionNode, *_args):
@@ -121,6 +122,7 @@ def provide_exact_same_arguments_to_wrapped_visitor():
121122

122123
class TestVisitor(Visitor):
123124
def __init__(self):
125+
super().__init__()
124126
self.args = []
125127

126128
def enter(self, *args):

0 commit comments

Comments
 (0)