Skip to content

Commit aa39cc3

Browse files
committed
Refactor base executor
1 parent 2692232 commit aa39cc3

File tree

2 files changed

+307
-294
lines changed

2 files changed

+307
-294
lines changed

graphql/execution/base.py

Lines changed: 12 additions & 294 deletions
Original file line numberDiff line numberDiff line change
@@ -1,127 +1,15 @@
1-
# -*- coding: utf-8 -*-
2-
import logging
3-
from traceback import format_exception
4-
5-
from ..error import GraphQLError
6-
from ..language import ast
7-
from ..pyutils.default_ordered_dict import DefaultOrderedDict
8-
from ..type.definition import GraphQLInterfaceType, GraphQLUnionType
9-
from ..type.directives import GraphQLIncludeDirective, GraphQLSkipDirective
10-
from ..type.introspection import (SchemaMetaFieldDef, TypeMetaFieldDef,
11-
TypeNameMetaFieldDef)
12-
from ..utils.type_from_ast import type_from_ast
13-
from .values import get_argument_values, get_variable_values
14-
15-
logger = logging.getLogger(__name__)
16-
17-
18-
class ExecutionContext(object):
19-
"""Data that must be available at all points during query execution.
20-
21-
Namely, schema of the type system that is currently executing,
22-
and the fragments defined in the query document"""
23-
24-
__slots__ = 'schema', 'fragments', 'root_value', 'operation', 'variable_values', 'errors', 'context_value', \
25-
'argument_values_cache', 'executor', 'middleware', 'allow_subscriptions', '_subfields_cache'
26-
27-
def __init__(self, schema, document_ast, root_value, context_value, variable_values, operation_name, executor, middleware, allow_subscriptions):
28-
"""Constructs a ExecutionContext object from the arguments passed
29-
to execute, which we will pass throughout the other execution
30-
methods."""
31-
errors = []
32-
operation = None
33-
fragments = {}
34-
35-
for definition in document_ast.definitions:
36-
if isinstance(definition, ast.OperationDefinition):
37-
if not operation_name and operation:
38-
raise GraphQLError(
39-
'Must provide operation name if query contains multiple operations.')
40-
41-
if not operation_name or definition.name and definition.name.value == operation_name:
42-
operation = definition
43-
44-
elif isinstance(definition, ast.FragmentDefinition):
45-
fragments[definition.name.value] = definition
46-
47-
else:
48-
raise GraphQLError(
49-
u'GraphQL cannot execute a request containing a {}.'.format(
50-
definition.__class__.__name__),
51-
definition
52-
)
53-
54-
if not operation:
55-
if operation_name:
56-
raise GraphQLError(
57-
u'Unknown operation named "{}".'.format(operation_name))
58-
59-
else:
60-
raise GraphQLError('Must provide an operation.')
61-
62-
variable_values = get_variable_values(
63-
schema, operation.variable_definitions or [], variable_values)
64-
65-
self.schema = schema
66-
self.fragments = fragments
67-
self.root_value = root_value
68-
self.operation = operation
69-
self.variable_values = variable_values
70-
self.errors = errors
71-
self.context_value = context_value
72-
self.argument_values_cache = {}
73-
self.executor = executor
74-
self.middleware = middleware
75-
self.allow_subscriptions = allow_subscriptions
76-
self._subfields_cache = {}
77-
78-
def get_field_resolver(self, field_resolver):
79-
if not self.middleware:
80-
return field_resolver
81-
return self.middleware.get_field_resolver(field_resolver)
82-
83-
def get_argument_values(self, field_def, field_ast):
84-
k = field_def, field_ast
85-
result = self.argument_values_cache.get(k)
86-
if not result:
87-
result = self.argument_values_cache[k] = get_argument_values(field_def.args, field_ast.arguments,
88-
self.variable_values)
89-
90-
return result
91-
92-
def report_error(self, error, traceback=None):
93-
exception = format_exception(type(error), error, getattr(error, 'stack', None) or traceback)
94-
logger.error(''.join(exception))
95-
self.errors.append(error)
96-
97-
def get_sub_fields(self, return_type, field_asts):
98-
k = return_type, tuple(field_asts)
99-
if k not in self._subfields_cache:
100-
subfield_asts = DefaultOrderedDict(list)
101-
visited_fragment_names = set()
102-
for field_ast in field_asts:
103-
selection_set = field_ast.selection_set
104-
if selection_set:
105-
subfield_asts = collect_fields(
106-
self, return_type, selection_set,
107-
subfield_asts, visited_fragment_names
108-
)
109-
self._subfields_cache[k] = subfield_asts
110-
return self._subfields_cache[k]
111-
112-
113-
class SubscriberExecutionContext(object):
114-
__slots__ = 'exe_context', 'errors'
115-
116-
def __init__(self, exe_context):
117-
self.exe_context = exe_context
118-
self.errors = []
119-
120-
def reset(self):
121-
self.errors = []
122-
123-
def __getattr__(self, name):
124-
return getattr(self.exe_context, name)
1+
# We keep the following imports to preserve compatibility
2+
from .utils import (
3+
ExecutionContext,
4+
SubscriberExecutionContext,
5+
get_operation_root_type,
6+
collect_fields,
7+
should_include_node,
8+
does_fragment_condition_match,
9+
get_field_entry_key,
10+
default_resolve_fn,
11+
get_field_def
12+
)
12513

12614

12715
class ExecutionResult(object):
@@ -152,149 +40,6 @@ def __eq__(self, other):
15240
)
15341

15442

155-
def get_operation_root_type(schema, operation):
156-
op = operation.operation
157-
if op == 'query':
158-
return schema.get_query_type()
159-
160-
elif op == 'mutation':
161-
mutation_type = schema.get_mutation_type()
162-
163-
if not mutation_type:
164-
raise GraphQLError(
165-
'Schema is not configured for mutations',
166-
[operation]
167-
)
168-
169-
return mutation_type
170-
171-
elif op == 'subscription':
172-
subscription_type = schema.get_subscription_type()
173-
174-
if not subscription_type:
175-
raise GraphQLError(
176-
'Schema is not configured for subscriptions',
177-
[operation]
178-
)
179-
180-
return subscription_type
181-
182-
raise GraphQLError(
183-
'Can only execute queries, mutations and subscriptions',
184-
[operation]
185-
)
186-
187-
188-
def collect_fields(ctx, runtime_type, selection_set, fields, prev_fragment_names):
189-
"""
190-
Given a selectionSet, adds all of the fields in that selection to
191-
the passed in map of fields, and returns it at the end.
192-
193-
collect_fields requires the "runtime type" of an object. For a field which
194-
returns and Interface or Union type, the "runtime type" will be the actual
195-
Object type returned by that field.
196-
"""
197-
for selection in selection_set.selections:
198-
directives = selection.directives
199-
200-
if isinstance(selection, ast.Field):
201-
if not should_include_node(ctx, directives):
202-
continue
203-
204-
name = get_field_entry_key(selection)
205-
fields[name].append(selection)
206-
207-
elif isinstance(selection, ast.InlineFragment):
208-
if not should_include_node(
209-
ctx, directives) or not does_fragment_condition_match(
210-
ctx, selection, runtime_type):
211-
continue
212-
213-
collect_fields(ctx, runtime_type,
214-
selection.selection_set, fields, prev_fragment_names)
215-
216-
elif isinstance(selection, ast.FragmentSpread):
217-
frag_name = selection.name.value
218-
219-
if frag_name in prev_fragment_names or not should_include_node(ctx, directives):
220-
continue
221-
222-
prev_fragment_names.add(frag_name)
223-
fragment = ctx.fragments.get(frag_name)
224-
frag_directives = fragment.directives
225-
if not fragment or not \
226-
should_include_node(ctx, frag_directives) or not \
227-
does_fragment_condition_match(ctx, fragment, runtime_type):
228-
continue
229-
230-
collect_fields(ctx, runtime_type,
231-
fragment.selection_set, fields, prev_fragment_names)
232-
233-
return fields
234-
235-
236-
def should_include_node(ctx, directives):
237-
"""Determines if a field should be included based on the @include and
238-
@skip directives, where @skip has higher precidence than @include."""
239-
# TODO: Refactor based on latest code
240-
if directives:
241-
skip_ast = None
242-
243-
for directive in directives:
244-
if directive.name.value == GraphQLSkipDirective.name:
245-
skip_ast = directive
246-
break
247-
248-
if skip_ast:
249-
args = get_argument_values(
250-
GraphQLSkipDirective.args,
251-
skip_ast.arguments,
252-
ctx.variable_values,
253-
)
254-
if args.get('if') is True:
255-
return False
256-
257-
include_ast = None
258-
259-
for directive in directives:
260-
if directive.name.value == GraphQLIncludeDirective.name:
261-
include_ast = directive
262-
break
263-
264-
if include_ast:
265-
args = get_argument_values(
266-
GraphQLIncludeDirective.args,
267-
include_ast.arguments,
268-
ctx.variable_values,
269-
)
270-
271-
if args.get('if') is False:
272-
return False
273-
274-
return True
275-
276-
277-
def does_fragment_condition_match(ctx, fragment, type_):
278-
type_condition_ast = fragment.type_condition
279-
if not type_condition_ast:
280-
return True
281-
282-
conditional_type = type_from_ast(ctx.schema, type_condition_ast)
283-
if conditional_type.is_same_type(type_):
284-
return True
285-
286-
if isinstance(conditional_type, (GraphQLInterfaceType, GraphQLUnionType)):
287-
return ctx.schema.is_possible_type(conditional_type, type_)
288-
289-
return False
290-
291-
292-
def get_field_entry_key(node):
293-
"""Implements the logic to compute the key of a given field's entry"""
294-
if node.alias:
295-
return node.alias.value
296-
return node.name.value
297-
29843

29944
class ResolveInfo(object):
30045
__slots__ = ('field_name', 'field_asts', 'return_type', 'parent_type',
@@ -313,30 +58,3 @@ def __init__(self, field_name, field_asts, return_type, parent_type,
31358
self.variable_values = variable_values
31459
self.context = context
31560
self.path = path
316-
317-
318-
def default_resolve_fn(source, info, **args):
319-
"""If a resolve function is not given, then a default resolve behavior is used which takes the property of the source object
320-
of the same name as the field and returns it as the result, or if it's a function, returns the result of calling that function."""
321-
name = info.field_name
322-
property = getattr(source, name, None)
323-
if callable(property):
324-
return property()
325-
return property
326-
327-
328-
def get_field_def(schema, parent_type, field_name):
329-
"""This method looks up the field on the given type defintion.
330-
It has special casing for the two introspection fields, __schema
331-
and __typename. __typename is special because it can always be
332-
queried as a field, even in situations where no other fields
333-
are allowed, like on a Union. __schema could get automatically
334-
added to the query type, but that would require mutating type
335-
definitions, which would cause issues."""
336-
if field_name == '__schema' and schema.get_query_type() == parent_type:
337-
return SchemaMetaFieldDef
338-
elif field_name == '__type' and schema.get_query_type() == parent_type:
339-
return TypeMetaFieldDef
340-
elif field_name == '__typename':
341-
return TypeNameMetaFieldDef
342-
return parent_type.fields.get(field_name)

0 commit comments

Comments
 (0)