Skip to content

Better variable value coercion and input object containers support #132

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jul 23, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ language: python
sudo: false
python:
- 2.7
- 3.3
- 3.4
- 3.5
- 3.6
- pypy
before_install:
- |
Expand Down
4 changes: 4 additions & 0 deletions graphql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,9 @@

# Asserts a string is a valid GraphQL name.
assert_valid_name,

# Undefined const
Undefined,
)

__all__ = (
Expand Down Expand Up @@ -284,4 +287,5 @@
'type_from_ast',
'value_from_ast',
'get_version',
'Undefined',
)
3 changes: 1 addition & 2 deletions graphql/execution/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from ..error import GraphQLError
from ..language import ast
from ..pyutils.default_ordered_dict import DefaultOrderedDict
from ..type.definition import Undefined, GraphQLInterfaceType, GraphQLUnionType
from ..type.definition import GraphQLInterfaceType, GraphQLUnionType
from ..type.directives import GraphQLIncludeDirective, GraphQLSkipDirective
from ..type.introspection import (SchemaMetaFieldDef, TypeMetaFieldDef,
TypeNameMetaFieldDef)
Expand Down Expand Up @@ -75,7 +75,6 @@ def get_field_resolver(self, field_resolver):
def get_argument_values(self, field_def, field_ast):
k = field_def, field_ast
result = self.argument_values_cache.get(k)

if not result:
result = self.argument_values_cache[k] = get_argument_values(field_def.args, field_ast.arguments,
self.variable_values)
Expand Down
3 changes: 2 additions & 1 deletion graphql/execution/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@
from ..error import GraphQLError, GraphQLLocatedError
from ..pyutils.default_ordered_dict import DefaultOrderedDict
from ..pyutils.ordereddict import OrderedDict
from ..utils.undefined import Undefined
from ..type import (GraphQLEnumType, GraphQLInterfaceType, GraphQLList,
GraphQLNonNull, GraphQLObjectType, GraphQLScalarType,
GraphQLSchema, GraphQLUnionType)
from .base import (ExecutionContext, ExecutionResult, ResolveInfo, Undefined,
from .base import (ExecutionContext, ExecutionResult, ResolveInfo,
collect_fields, default_resolve_fn, get_field_def,
get_operation_root_type)
from .executors.sync import SyncExecutor
Expand Down
3 changes: 2 additions & 1 deletion graphql/execution/experimental/fragment.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@

from ...pyutils.cached_property import cached_property
from ...pyutils.default_ordered_dict import DefaultOrderedDict
from ...utils.undefined import Undefined
from ...type import (GraphQLInterfaceType, GraphQLList, GraphQLNonNull,
GraphQLObjectType, GraphQLUnionType)
from ..base import ResolveInfo, Undefined, collect_fields, get_field_def
from ..base import ResolveInfo, collect_fields, get_field_def
from ..values import get_argument_values
from ...error import GraphQLError
try:
Expand Down
7 changes: 4 additions & 3 deletions graphql/execution/experimental/tests/test_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,9 +413,10 @@ def test_passes_along_null_for_non_nullable_inputs_if_explcitly_set_in_the_query
'''

check(doc, {
'data': {
'fieldWithNonNullableStringInput': None
}
'errors': [{
'message': 'Argument "input" of required type String!" was not provided.'
}],
'data': None
})


Expand Down
33 changes: 31 additions & 2 deletions graphql/execution/tests/test_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from graphql.error import GraphQLError, format_error
from graphql.execution import execute
from graphql.language.parser import parse
from graphql.type import (GraphQLArgument, GraphQLField,
from graphql.type import (GraphQLArgument, GraphQLField, GraphQLBoolean,
GraphQLInputObjectField, GraphQLInputObjectType,
GraphQLList, GraphQLNonNull, GraphQLObjectType,
GraphQLScalarType, GraphQLSchema, GraphQLString)
Expand All @@ -18,13 +18,23 @@
parse_literal=lambda v: 'DeserializedValue' if v.value == 'SerializedValue' else None
)


class my_special_dict(dict):
pass


TestInputObject = GraphQLInputObjectType('TestInputObject', OrderedDict([
('a', GraphQLInputObjectField(GraphQLString)),
('b', GraphQLInputObjectField(GraphQLList(GraphQLString))),
('c', GraphQLInputObjectField(GraphQLNonNull(GraphQLString))),
('d', GraphQLInputObjectField(TestComplexScalar))
]))


TestCustomInputObject = GraphQLInputObjectType('TestCustomInputObject', OrderedDict([
('a', GraphQLInputObjectField(GraphQLString)),
]), container_type=my_special_dict)

stringify = lambda obj: json.dumps(obj, sort_keys=True)


Expand All @@ -47,6 +57,10 @@ def input_to_json(obj, args, context, info):
GraphQLString,
args={'input': GraphQLArgument(TestInputObject)},
resolver=input_to_json),
'fieldWithCustomObjectInput': GraphQLField(
GraphQLBoolean,
args={'input': GraphQLArgument(TestCustomInputObject)},
resolver=lambda root, args, context, info: isinstance(args.get('input'), my_special_dict)),
'fieldWithNullableStringInput': GraphQLField(
GraphQLString,
args={'input': GraphQLArgument(GraphQLString)},
Expand Down Expand Up @@ -412,9 +426,24 @@ def test_passes_along_null_for_non_nullable_inputs_if_explcitly_set_in_the_query
}
'''

check(doc, {
'errors': [{
'message': 'Argument "input" of required type String!" was not provided.'
}],
'data': None
})


def test_uses_objectinput_container():
doc = '''
{
fieldWithCustomObjectInput(input: {a: "b"})
}
'''

check(doc, {
'data': {
'fieldWithNonNullableStringInput': None
'fieldWithCustomObjectInput': True
}
})

Expand Down
153 changes: 84 additions & 69 deletions graphql/execution/values.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from six import string_types

from ..error import GraphQLError
from ..language import ast
from ..language.printer import print_ast
from ..type import (GraphQLEnumType, GraphQLInputObjectType, GraphQLList,
GraphQLNonNull, GraphQLScalarType, is_input_type)
Expand All @@ -23,8 +24,43 @@ def get_variable_values(schema, definition_asts, inputs):
values = {}
for def_ast in definition_asts:
var_name = def_ast.variable.name.value
value = get_variable_value(schema, def_ast, inputs.get(var_name))
values[var_name] = value
var_type = type_from_ast(schema, def_ast.type)
value = inputs.get(var_name)

if not is_input_type(var_type):
raise GraphQLError(
'Variable "${var_name}" expected value of type "{var_type}" which cannot be used as an input type.'.format(
var_name=var_name,
var_type=print_ast(def_ast.type),
),
[def_ast]
)
elif value is None:
if def_ast.default_value is not None:
values[var_name] = value_from_ast(def_ast.default_value, var_type)
if isinstance(var_type, GraphQLNonNull):
raise GraphQLError(
'Variable "${var_name}" of required type "{var_type}" was not provided.'.format(
var_name=var_name, var_type=var_type
), [def_ast]
)
else:
errors = is_valid_value(value, var_type)
if errors:
message = u'\n' + u'\n'.join(errors)
raise GraphQLError(
'Variable "${}" got invalid value {}.{}'.format(
var_name,
json.dumps(value, sort_keys=True),
message
),
[def_ast]
)
coerced_value = coerce_value(var_type, value)
if coerced_value is None:
raise Exception('Should have reported error.')

values[var_name] = coerced_value

return values

Expand All @@ -42,72 +78,52 @@ def get_argument_values(arg_defs, arg_asts, variables=None):

result = {}
for name, arg_def in arg_defs.items():
arg_type = arg_def.type
value_ast = arg_ast_map.get(name)
if value_ast:
value_ast = value_ast.value

value = value_from_ast(
value_ast,
arg_def.type,
variables
)
if name not in arg_ast_map:
if arg_def.default_value is not None:
result[arg_def.out_name or name] = arg_def.default_value
continue
elif isinstance(arg_type, GraphQLNonNull):
raise GraphQLError('Argument "{name}" of required type {arg_type}" was not provided.'.format(
name=name,
arg_type=arg_type
), arg_asts)
elif isinstance(value_ast.value, ast.Variable):
variable_name = value_ast.value.name.value
variable_value = variables.get(variable_name)
if variables and variable_name in variables:
result[arg_def.out_name or name] = variable_value
elif arg_def.default_value is not None:
result[arg_def.out_name or name] = arg_def.default_value
elif isinstance(arg_type, GraphQLNonNull):
raise GraphQLError('Argument "{name}" of required type {arg_type}" provided the variable "${variable_name}" which was not provided'.format(
name=name,
arg_type=arg_type,
variable_name=variable_name
), arg_asts)
continue

if value is None:
value = arg_def.default_value
else:
value_ast = value_ast.value

if value is not None:
# We use out_name as the output name for the
# dict if exists
result[arg_def.out_name or name] = value
value = value_from_ast(
value_ast,
arg_type,
variables
)
if value is None:
if arg_def.default_value is not None:
value = arg_def.default_value
result[arg_def.out_name or name] = value
else:
# We use out_name as the output name for the
# dict if exists
result[arg_def.out_name or name] = value

return result


def get_variable_value(schema, definition_ast, input):
"""Given a variable definition, and any value of input, return a value which adheres to the variable definition,
or throw an error."""
type = type_from_ast(schema, definition_ast.type)
variable = definition_ast.variable

if not type or not is_input_type(type):
raise GraphQLError(
'Variable "${}" expected value of type "{}" which cannot be used as an input type.'.format(
variable.name.value,
print_ast(definition_ast.type),
),
[definition_ast]
)

input_type = type
errors = is_valid_value(input, input_type)
if not errors:
if input is None:
default_value = definition_ast.default_value
if default_value:
return value_from_ast(default_value, input_type)

return coerce_value(input_type, input)

if input is None:
raise GraphQLError(
'Variable "${}" of required type "{}" was not provided.'.format(
variable.name.value,
print_ast(definition_ast.type)
),
[definition_ast]
)

message = (u'\n' + u'\n'.join(errors)) if errors else u''
raise GraphQLError(
'Variable "${}" got invalid value {}.{}'.format(
variable.name.value,
json.dumps(input, sort_keys=True),
message
),
[definition_ast]
)


def coerce_value(type, value):
"""Given a type and any value, return a runtime value coerced to match the type."""
if isinstance(type, GraphQLNonNull):
Expand All @@ -130,16 +146,15 @@ def coerce_value(type, value):
fields = type.fields
obj = {}
for field_name, field in fields.items():
field_value = coerce_value(field.type, value.get(field_name))
if field_value is None:
field_value = field.default_value

if field_value is not None:
# We use out_name as the output name for the
# dict if exists
if field_name not in value:
if field.default_value is not None:
field_value = field.default_value
obj[field.out_name or field_name] = field_value
else:
field_value = coerce_value(field.type, value.get(field_name))
obj[field.out_name or field_name] = field_value

return obj
return type.create_container(obj)

assert isinstance(type, (GraphQLScalarType, GraphQLEnumType)), \
'Must be input type'
Expand Down
3 changes: 1 addition & 2 deletions graphql/type/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@
is_leaf_type,
is_type,
get_nullable_type,
is_output_type,
Undefined
is_output_type
)
from .directives import (
# "Enum" of Directive locations
Expand Down
20 changes: 8 additions & 12 deletions graphql/type/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,6 @@
from ..utils.assert_valid_name import assert_valid_name


class _Undefined(object):
def __bool__(self):
return False

__nonzero__ = __bool__


Undefined = _Undefined()


def is_type(type):
return isinstance(type, (
GraphQLScalarType,
Expand Down Expand Up @@ -516,13 +506,19 @@ class GeoPoint(GraphQLInputObjectType):
default_value=0)
}
"""
def __init__(self, name, fields, description=None):
def __init__(self, name, fields, description=None, container_type=None):
assert name, 'Type must be named.'
self.name = name
self.description = description

if container_type is None:
container_type = dict
assert callable(container_type), "container_type must be callable"
self.container_type = container_type
self._fields = fields

def create_container(self, data):
return self.container_type(data)

@cached_property
def fields(self):
return self._define_field_map()
Expand Down
Loading