diff --git a/.gitignore b/.gitignore index 04f0af54..74b41897 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,7 @@ # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] +venv # C extensions *.so @@ -61,6 +62,7 @@ target/ # IntelliJ .idea +*.iml # OS X .DS_Store diff --git a/graphql/__init__.py b/graphql/__init__.py index a2dd03ac..94575a58 100644 --- a/graphql/__init__.py +++ b/graphql/__init__.py @@ -61,6 +61,12 @@ GraphQLInputObjectField, GraphQLArgument, + # "Enum" of Type Kinds + TypeKind, + + # "Enum" of Directive locations + DirectiveLocation, + # Scalars GraphQLInt, GraphQLFloat, @@ -68,6 +74,33 @@ GraphQLBoolean, GraphQLID, + # Directive definition + GraphQLDirective, + + # Built-in directives defined by the Spec + specified_directives, + GraphQLSkipDirective, + GraphQLIncludeDirective, + GraphQLDeprecatedDirective, + + # Constant Deprecation Reason + DEFAULT_DEPRECATION_REASON, + + # GraphQL Types for introspection. + __Schema, + __Directive, + __DirectiveLocation, + __Type, + __Field, + __InputValue, + __EnumValue, + __TypeKind, + + # Meta-field definitions. + SchemaMetaFieldDef, + TypeMetaFieldDef, + TypeNameMetaFieldDef, + # Predicates is_type, is_input_type, @@ -190,6 +223,25 @@ 'GraphQLSchema', 'GraphQLString', 'GraphQLUnionType', + 'GraphQLDirective', + 'specified_directives', + 'GraphQLSkipDirective', + 'GraphQLIncludeDirective', + 'GraphQLDeprecatedDirective', + 'DEFAULT_DEPRECATION_REASON', + 'TypeKind', + 'DirectiveLocation', + '__Schema', + '__Directive', + '__DirectiveLocation', + '__Type', + '__Field', + '__InputValue', + '__EnumValue', + '__TypeKind', + 'SchemaMetaFieldDef', + 'TypeMetaFieldDef', + 'TypeNameMetaFieldDef', 'get_named_type', 'get_nullable_type', 'is_abstract_type', diff --git a/graphql/execution/executor.py b/graphql/execution/executor.py index 76b3c530..5a01a724 100644 --- a/graphql/execution/executor.py +++ b/graphql/execution/executor.py @@ -321,16 +321,18 @@ def complete_abstract_value(exe_context, return_type, field_asts, info, result): else: runtime_type = get_default_resolve_type_fn(result, exe_context.context_value, info, return_type) - assert isinstance(runtime_type, GraphQLObjectType), ( - 'Abstract type {} must resolve to an Object type at runtime ' + - 'for field {}.{} with value "{}", received "{}".' - ).format( - return_type, - info.parent_type, - info.field_name, - result, - runtime_type, - ) + if not isinstance(runtime_type, GraphQLObjectType): + raise GraphQLError( + ('Abstract type {} must resolve to an Object type at runtime ' + + 'for field {}.{} with value "{}", received "{}".').format( + return_type, + info.parent_type, + info.field_name, + result, + runtime_type, + ), + field_asts + ) if not exe_context.schema.is_possible_type(return_type, runtime_type): raise GraphQLError( diff --git a/graphql/execution/tests/test_executor.py b/graphql/execution/tests/test_executor.py index 233ee09b..9c9db4cd 100644 --- a/graphql/execution/tests/test_executor.py +++ b/graphql/execution/tests/test_executor.py @@ -285,21 +285,6 @@ class Data(object): assert result.data == {'second': 'b'} -def test_uses_the_named_operation_if_operation_name_is_provided(): - doc = 'query Example { first: a } query OtherExample { second: a }' - - class Data(object): - a = 'b' - - ast = parse(doc) - Type = GraphQLObjectType('Type', { - 'a': GraphQLField(GraphQLString) - }) - result = execute(GraphQLSchema(Type), ast, Data(), operation_name='OtherExample') - assert not result.errors - assert result.data == {'second': 'b'} - - def test_raises_if_no_operation_is_provided(): doc = 'fragment Example on Type { a }' diff --git a/graphql/execution/values.py b/graphql/execution/values.py index ba61a192..56008463 100644 --- a/graphql/execution/values.py +++ b/graphql/execution/values.py @@ -29,7 +29,7 @@ def get_variable_values(schema, definition_asts, inputs): return values -def get_argument_values(arg_defs, arg_asts, variables): +def get_argument_values(arg_defs, arg_asts, variables=None): """Prepares an object map of argument values given a list of argument definitions and list of argument AST nodes.""" if not arg_defs: diff --git a/graphql/language/ast.py b/graphql/language/ast.py index 2bae8aa5..6fffae84 100644 --- a/graphql/language/ast.py +++ b/graphql/language/ast.py @@ -845,30 +845,34 @@ class TypeSystemDefinition(TypeDefinition): class SchemaDefinition(TypeSystemDefinition): - __slots__ = ('loc', 'operation_types',) + __slots__ = ('loc', 'directives', 'operation_types',) _fields = ('operation_types',) - def __init__(self, operation_types, loc=None): + def __init__(self, operation_types, loc=None, directives=None): self.operation_types = operation_types self.loc = loc + self.directives = directives def __eq__(self, other): return ( self is other or ( isinstance(other, SchemaDefinition) and - self.operation_types == other.operation_types + self.operation_types == other.operation_types and + self.directives == other.directives ) ) def __repr__(self): return ('SchemaDefinition(' 'operation_types={self.operation_types!r}' + ', directives={self.directives!r}' ')').format(self=self) def __copy__(self): return type(self)( self.operation_types, - self.loc + self.loc, + self.directives, ) def __hash__(self): @@ -911,14 +915,15 @@ def __hash__(self): class ObjectTypeDefinition(TypeDefinition): - __slots__ = ('loc', 'name', 'interfaces', 'fields',) + __slots__ = ('loc', 'name', 'interfaces', 'directives', 'fields',) _fields = ('name', 'interfaces', 'fields',) - def __init__(self, name, fields, interfaces=None, loc=None): + def __init__(self, name, fields, interfaces=None, loc=None, directives=None): self.loc = loc self.name = name self.interfaces = interfaces self.fields = fields + self.directives = directives def __eq__(self, other): return ( @@ -927,7 +932,8 @@ def __eq__(self, other): # self.loc == other.loc and self.name == other.name and self.interfaces == other.interfaces and - self.fields == other.fields + self.fields == other.fields and + self.directives == other.directives ) ) @@ -936,6 +942,7 @@ def __repr__(self): 'name={self.name!r}' ', interfaces={self.interfaces!r}' ', fields={self.fields!r}' + ', directives={self.directives!r}' ')').format(self=self) def __copy__(self): @@ -943,7 +950,8 @@ def __copy__(self): self.name, self.fields, self.interfaces, - self.loc + self.loc, + self.directives, ) def __hash__(self): @@ -951,14 +959,15 @@ def __hash__(self): class FieldDefinition(Node): - __slots__ = ('loc', 'name', 'arguments', 'type',) + __slots__ = ('loc', 'name', 'arguments', 'type', 'directives',) _fields = ('name', 'arguments', 'type',) - def __init__(self, name, arguments, type, loc=None): + def __init__(self, name, arguments, type, loc=None, directives=None): self.loc = loc self.name = name self.arguments = arguments self.type = type + self.directives = directives def __eq__(self, other): return ( @@ -967,7 +976,8 @@ def __eq__(self, other): # self.loc == other.loc and self.name == other.name and self.arguments == other.arguments and - self.type == other.type + self.type == other.type and + self.directives == other.directives ) ) @@ -983,7 +993,8 @@ def __copy__(self): self.name, self.arguments, self.type, - self.loc + self.loc, + self.directives, ) def __hash__(self): @@ -991,14 +1002,16 @@ def __hash__(self): class InputValueDefinition(Node): - __slots__ = ('loc', 'name', 'type', 'default_value',) + __slots__ = ('loc', 'name', 'type', 'default_value', 'directives') _fields = ('name', 'type', 'default_value',) - def __init__(self, name, type, default_value=None, loc=None): + def __init__(self, name, type, default_value=None, loc=None, + directives=None): self.loc = loc self.name = name self.type = type self.default_value = default_value + self.directives = directives def __eq__(self, other): return ( @@ -1007,7 +1020,8 @@ def __eq__(self, other): # self.loc == other.loc and self.name == other.name and self.type == other.type and - self.default_value == other.default_value + self.default_value == other.default_value and + self.directives == other.directives ) ) @@ -1016,6 +1030,7 @@ def __repr__(self): 'name={self.name!r}' ', type={self.type!r}' ', default_value={self.default_value!r}' + ', directives={self.directives!r}' ')').format(self=self) def __copy__(self): @@ -1023,7 +1038,8 @@ def __copy__(self): self.name, self.type, self.default_value, - self.loc + self.loc, + self.directives, ) def __hash__(self): @@ -1031,13 +1047,14 @@ def __hash__(self): class InterfaceTypeDefinition(TypeDefinition): - __slots__ = ('loc', 'name', 'fields',) + __slots__ = ('loc', 'name', 'fields', 'directives',) _fields = ('name', 'fields',) - def __init__(self, name, fields, loc=None): + def __init__(self, name, fields, loc=None, directives=None): self.loc = loc self.name = name self.fields = fields + self.directives = directives def __eq__(self, other): return ( @@ -1045,7 +1062,8 @@ def __eq__(self, other): isinstance(other, InterfaceTypeDefinition) and # self.loc == other.loc and self.name == other.name and - self.fields == other.fields + self.fields == other.fields and + self.directives == other.directives ) ) @@ -1053,13 +1071,15 @@ def __repr__(self): return ('InterfaceTypeDefinition(' 'name={self.name!r}' ', fields={self.fields!r}' + ', directives={self.directives!r}' ')').format(self=self) def __copy__(self): return type(self)( self.name, self.fields, - self.loc + self.loc, + self.directives, ) def __hash__(self): @@ -1067,13 +1087,14 @@ def __hash__(self): class UnionTypeDefinition(TypeDefinition): - __slots__ = ('loc', 'name', 'types',) + __slots__ = ('loc', 'name', 'types', 'directives',) _fields = ('name', 'types',) - def __init__(self, name, types, loc=None): + def __init__(self, name, types, loc=None, directives=None): self.loc = loc self.name = name self.types = types + self.directives = directives def __eq__(self, other): return ( @@ -1081,7 +1102,8 @@ def __eq__(self, other): isinstance(other, UnionTypeDefinition) and # self.loc == other.loc and self.name == other.name and - self.types == other.types + self.types == other.types and + self.directives == other.directives ) ) @@ -1089,13 +1111,15 @@ def __repr__(self): return ('UnionTypeDefinition(' 'name={self.name!r}' ', types={self.types!r}' + ', directives={self.directives!r}' ')').format(self=self) def __copy__(self): return type(self)( self.name, self.types, - self.loc + self.loc, + self.directives, ) def __hash__(self): @@ -1103,31 +1127,35 @@ def __hash__(self): class ScalarTypeDefinition(TypeDefinition): - __slots__ = ('loc', 'name',) + __slots__ = ('loc', 'name', 'directives',) _fields = ('name',) - def __init__(self, name, loc=None): + def __init__(self, name, loc=None, directives=None): self.loc = loc self.name = name + self.directives = directives def __eq__(self, other): return ( self is other or ( isinstance(other, ScalarTypeDefinition) and # self.loc == other.loc and - self.name == other.name + self.name == other.name and + self.directives == other.directives ) ) def __repr__(self): return ('ScalarTypeDefinition(' 'name={self.name!r}' + 'directives={self.directives!r}' ')').format(self=self) def __copy__(self): return type(self)( self.name, - self.loc + self.loc, + self.directives ) def __hash__(self): @@ -1135,13 +1163,14 @@ def __hash__(self): class EnumTypeDefinition(TypeDefinition): - __slots__ = ('loc', 'name', 'values',) + __slots__ = ('loc', 'name', 'values', 'directives',) _fields = ('name', 'values',) - def __init__(self, name, values, loc=None): + def __init__(self, name, values, loc=None, directives=None): self.loc = loc self.name = name self.values = values + self.directives = directives def __eq__(self, other): return ( @@ -1149,7 +1178,8 @@ def __eq__(self, other): isinstance(other, EnumTypeDefinition) and # self.loc == other.loc and self.name == other.name and - self.values == other.values + self.values == other.values and + self.directives == other.directives ) ) @@ -1157,13 +1187,15 @@ def __repr__(self): return ('EnumTypeDefinition(' 'name={self.name!r}' ', values={self.values!r}' + ', directives={self.directives!r}' ')').format(self=self) def __copy__(self): return type(self)( self.name, self.values, - self.loc + self.loc, + self.directives, ) def __hash__(self): @@ -1171,31 +1203,35 @@ def __hash__(self): class EnumValueDefinition(Node): - __slots__ = ('loc', 'name',) + __slots__ = ('loc', 'name', 'directives',) _fields = ('name',) - def __init__(self, name, loc=None): + def __init__(self, name, loc=None, directives=None): self.loc = loc self.name = name + self.directives = directives def __eq__(self, other): return ( self is other or ( isinstance(other, EnumValueDefinition) and # self.loc == other.loc and - self.name == other.name + self.name == other.name and + self.directives == other.directives ) ) def __repr__(self): return ('EnumValueDefinition(' 'name={self.name!r}' + ', directives={self.directives!r}' ')').format(self=self) def __copy__(self): return type(self)( self.name, - self.loc + self.loc, + self.directives, ) def __hash__(self): @@ -1203,13 +1239,14 @@ def __hash__(self): class InputObjectTypeDefinition(TypeDefinition): - __slots__ = ('loc', 'name', 'fields',) + __slots__ = ('loc', 'name', 'fields', 'directives',) _fields = ('name', 'fields',) - def __init__(self, name, fields, loc=None): + def __init__(self, name, fields, loc=None, directives=None): self.loc = loc self.name = name self.fields = fields + self.directives = directives def __eq__(self, other): return ( @@ -1217,7 +1254,8 @@ def __eq__(self, other): isinstance(other, InputObjectTypeDefinition) and # self.loc == other.loc and self.name == other.name and - self.fields == other.fields + self.fields == other.fields and + self.directives == other.directives ) ) @@ -1225,13 +1263,15 @@ def __repr__(self): return ('InputObjectTypeDefinition(' 'name={self.name!r}' ', fields={self.fields!r}' + ', directives={self.directives!r}' ')').format(self=self) def __copy__(self): return type(self)( self.name, self.fields, - self.loc + self.loc, + self.directives, ) def __hash__(self): diff --git a/graphql/language/parser.py b/graphql/language/parser.py index 87a29b5b..44ba58d8 100644 --- a/graphql/language/parser.py +++ b/graphql/language/parser.py @@ -476,7 +476,6 @@ def parse_directives(parser): directives = [] while peek(parser, TokenKind.AT): directives.append(parse_directive(parser)) - return directives @@ -519,6 +518,21 @@ def parse_named_type(parser): def parse_type_system_definition(parser): + ''' + TypeSystemDefinition : + - SchemaDefinition + - TypeDefinition + - TypeExtensionDefinition + - DirectiveDefinition + + TypeDefinition : + - ScalarTypeDefinition + - ObjectTypeDefinition + - InterfaceTypeDefinition + - UnionTypeDefinition + - EnumTypeDefinition + - InputObjectTypeDefinition + ''' if not peek(parser, TokenKind.NAME): raise unexpected(parser) @@ -557,6 +571,7 @@ def parse_type_system_definition(parser): def parse_schema_definition(parser): start = parser.token.start expect_keyword(parser, 'schema') + directives = parse_directives(parser) operation_types = many( parser, TokenKind.BRACE_L, @@ -565,6 +580,7 @@ def parse_schema_definition(parser): ) return ast.SchemaDefinition( + directives=directives, operation_types=operation_types, loc=loc(parser, start) ) @@ -588,7 +604,8 @@ def parse_scalar_type_definition(parser): return ast.ScalarTypeDefinition( name=parse_name(parser), - loc=loc(parser, start) + directives=parse_directives(parser), + loc=loc(parser, start), ) @@ -598,13 +615,14 @@ def parse_object_type_definition(parser): return ast.ObjectTypeDefinition( name=parse_name(parser), interfaces=parse_implements_interfaces(parser), + directives=parse_directives(parser), fields=any( parser, TokenKind.BRACE_L, parse_field_definition, TokenKind.BRACE_R ), - loc=loc(parser, start) + loc=loc(parser, start), ) @@ -616,7 +634,7 @@ def parse_implements_interfaces(parser): while True: types.append(parse_named_type(parser)) - if peek(parser, TokenKind.BRACE_L): + if not peek(parser, TokenKind.NAME): break return types @@ -629,7 +647,8 @@ def parse_field_definition(parser): name=parse_name(parser), arguments=parse_argument_defs(parser), type=expect(parser, TokenKind.COLON) and parse_type(parser), - loc=loc(parser, start) + directives=parse_directives(parser), + loc=loc(parser, start), ) @@ -643,12 +662,14 @@ def parse_argument_defs(parser): def parse_input_value_def(parser): start = parser.token.start - return ast.InputValueDefinition( + bla = ast.InputValueDefinition( name=parse_name(parser), type=expect(parser, TokenKind.COLON) and parse_type(parser), default_value=parse_const_value(parser) if skip(parser, TokenKind.EQUALS) else None, - loc=loc(parser, start) + directives=parse_directives(parser), + loc=loc(parser, start), ) + return bla def parse_interface_type_definition(parser): @@ -657,8 +678,9 @@ def parse_interface_type_definition(parser): return ast.InterfaceTypeDefinition( name=parse_name(parser), + directives=parse_directives(parser), fields=any(parser, TokenKind.BRACE_L, parse_field_definition, TokenKind.BRACE_R), - loc=loc(parser, start) + loc=loc(parser, start), ) @@ -668,8 +690,9 @@ def parse_union_type_definition(parser): return ast.UnionTypeDefinition( name=parse_name(parser), + directives=parse_directives(parser), types=expect(parser, TokenKind.EQUALS) and parse_union_members(parser), - loc=loc(parser, start) + loc=loc(parser, start), ) @@ -691,8 +714,9 @@ def parse_enum_type_definition(parser): return ast.EnumTypeDefinition( name=parse_name(parser), + directives=parse_directives(parser), values=many(parser, TokenKind.BRACE_L, parse_enum_value_definition, TokenKind.BRACE_R), - loc=loc(parser, start) + loc=loc(parser, start), ) @@ -701,7 +725,8 @@ def parse_enum_value_definition(parser): return ast.EnumValueDefinition( name=parse_name(parser), - loc=loc(parser, start) + directives=parse_directives(parser), + loc=loc(parser, start), ) @@ -711,8 +736,9 @@ def parse_input_object_type_definition(parser): return ast.InputObjectTypeDefinition( name=parse_name(parser), + directives=parse_directives(parser), fields=any(parser, TokenKind.BRACE_L, parse_input_value_def, TokenKind.BRACE_R), - loc=loc(parser, start) + loc=loc(parser, start), ) diff --git a/graphql/language/printer.py b/graphql/language/printer.py index 2f7e1985..a1a6dd36 100644 --- a/graphql/language/printer.py +++ b/graphql/language/printer.py @@ -112,41 +112,53 @@ def leave_NonNullType(self, node, *args): # Type Definitions: def leave_SchemaDefinition(self, node, *args): - return 'schema ' + block(node.operation_types) + return join([ + 'schema', + join(node.directives, ' '), + block(node.operation_types), + ], ' ') def leave_OperationTypeDefinition(self, node, *args): return '{}: {}'.format(node.operation, node.type) def leave_ScalarTypeDefinition(self, node, *args): - return 'scalar ' + node.name + return 'scalar ' + node.name + wrap(' ', join(node.directives, ' ')) def leave_ObjectTypeDefinition(self, node, *args): - return ( - 'type ' + node.name + ' ' + - wrap('implements ', join(node.interfaces, ', '), ' ') + + return join([ + 'type', + node.name, + wrap('implements ', join(node.interfaces, ', ')), + join(node.directives, ' '), block(node.fields) - ) + ], ' ') def leave_FieldDefinition(self, node, *args): - return node.name + wrap('(', join(node.arguments, ', '), ')') + ': ' + node.type + return ( + node.name + + wrap('(', join(node.arguments, ', '), ')') + + ': ' + + node.type + + wrap(' ', join(node.directives, ' ')) + ) def leave_InputValueDefinition(self, node, *args): - return node.name + ': ' + node.type + wrap(' = ', node.default_value) + return node.name + ': ' + node.type + wrap(' = ', node.default_value) + wrap(' ', join(node.directives, ' ')) def leave_InterfaceTypeDefinition(self, node, *args): - return 'interface ' + node.name + ' ' + block(node.fields) + return 'interface ' + node.name + wrap(' ', join(node.directives, ' ')) + ' ' + block(node.fields) def leave_UnionTypeDefinition(self, node, *args): - return 'union ' + node.name + ' = ' + join(node.types, ' | ') + return 'union ' + node.name + wrap(' ', join(node.directives, ' ')) + ' = ' + join(node.types, ' | ') def leave_EnumTypeDefinition(self, node, *args): - return 'enum ' + node.name + ' ' + block(node.values) + return 'enum ' + node.name + wrap(' ', join(node.directives, ' ')) + ' ' + block(node.values) def leave_EnumValueDefinition(self, node, *args): - return node.name + return node.name + wrap(' ', join(node.directives, ' ')) def leave_InputObjectTypeDefinition(self, node, *args): - return 'input ' + node.name + ' ' + block(node.fields) + return 'input ' + node.name + wrap(' ', join(node.directives, ' ')) + ' ' + block(node.fields) def leave_TypeExtensionDefinition(self, node, *args): return 'extend ' + node.definition @@ -162,10 +174,11 @@ def join(maybe_list, separator=''): return '' -def block(maybe_list): - if maybe_list: - return indent('{\n' + join(maybe_list, '\n')) + '\n}' - return '' +def block(_list): + '''Given a list, print each item on its own line, wrapped in an indented "{ }" block.''' + if _list: + return indent('{\n' + join(_list, '\n')) + '\n}' + return '{}' def wrap(start, maybe_str, end=''): diff --git a/graphql/language/tests/fixtures.py b/graphql/language/tests/fixtures.py index 774a8311..b16653c4 100644 --- a/graphql/language/tests/fixtures.py +++ b/graphql/language/tests/fixtures.py @@ -81,29 +81,54 @@ six(argument: InputType = {key: "value"}): Type } +type AnnotatedObject @onObject(arg: "value") { + annotatedField(arg: Type = "default" @onArg): Type @onField +} + interface Bar { one: Type four(argument: String = "string"): String } +interface AnnotatedInterface @onInterface { + annotatedField(arg: Type @onArg): Type @onField +} + union Feed = Story | Article | Advert +union AnnotatedUnion @onUnion = A | B + scalar CustomScalar +scalar AnnotatedScalar @onScalar + enum Site { DESKTOP MOBILE } +enum AnnotatedEnum @onEnum { + ANNOTATED_VALUE @onEnumValue + OTHER_VALUE +} + input InputType { key: String! answer: Int = 42 } +input AnnotatedInput @onInputObjectType { + annotatedField: Type @onField +} + extend type Foo { seven(argument: [String]): Type } +extend type Foo @onType {} + +type NoFields {} + directive @skip(if: Boolean!) on FIELD | FRAGMENT_SPREAD | INLINE_FRAGMENT directive @include(if: Boolean!) on FIELD | FRAGMENT_SPREAD | INLINE_FRAGMENT diff --git a/graphql/language/tests/test_schema_parser.py b/graphql/language/tests/test_schema_parser.py index f9a853a5..59c690c7 100644 --- a/graphql/language/tests/test_schema_parser.py +++ b/graphql/language/tests/test_schema_parser.py @@ -28,6 +28,7 @@ def test_parses_simple_type(): loc=loc(6, 11) ), interfaces=[], + directives=[], fields=[ ast.FieldDefinition( name=ast.Name( @@ -42,6 +43,7 @@ def test_parses_simple_type(): ), loc=loc(23, 29) ), + directives=[], loc=loc(16, 29) ) ], @@ -50,7 +52,6 @@ def test_parses_simple_type(): ], loc=loc(1, 31) ) - assert doc == expected @@ -71,6 +72,7 @@ def test_parses_simple_extension(): loc=loc(13, 18) ), interfaces=[], + directives=[], fields=[ ast.FieldDefinition( name=ast.Name( @@ -85,6 +87,7 @@ def test_parses_simple_extension(): ), loc=loc(30, 36) ), + directives=[], loc=loc(23, 36) ) ], @@ -115,6 +118,7 @@ def test_simple_non_null_type(): loc=loc(6, 11) ), interfaces=[], + directives=[], fields=[ ast.FieldDefinition( name=ast.Name( @@ -132,6 +136,7 @@ def test_simple_non_null_type(): ), loc=loc(23, 30) ), + directives=[], loc=loc(16, 30) ) ], @@ -163,6 +168,7 @@ def test_parses_simple_type_inheriting_interface(): loc=loc(22, 27) ) ], + directives=[], fields=[], loc=loc(0, 31) ) @@ -200,6 +206,7 @@ def test_parses_simple_type_inheriting_multiple_interfaces(): loc=loc(26, 29) ) ], + directives=[], fields=[], loc=loc(0, 33) ) @@ -220,12 +227,14 @@ def test_parses_single_value_enum(): value='Hello', loc=loc(5, 10) ), + directives=[], values=[ ast.EnumValueDefinition( name=ast.Name( value='WORLD', loc=loc(13, 18) ), + directives=[], loc=loc(13, 18) ) ], @@ -249,12 +258,14 @@ def test_parses_double_value_enum(): value='Hello', loc=loc(5, 10) ), + directives=[], values=[ ast.EnumValueDefinition( name=ast.Name( value='WO', loc=loc(13, 15) ), + directives=[], loc=loc(13, 15) ), ast.EnumValueDefinition( @@ -262,6 +273,7 @@ def test_parses_double_value_enum(): value='RLD', loc=loc(17, 20) ), + directives=[], loc=loc(17, 20) ) ], @@ -289,6 +301,7 @@ def test_parses_simple_interface(): value='Hello', loc=loc(11, 16) ), + directives=[], fields=[ ast.FieldDefinition( name=ast.Name( @@ -303,6 +316,7 @@ def test_parses_simple_interface(): ), loc=loc(28, 34) ), + directives=[], loc=loc(21, 34) ) ], @@ -330,6 +344,7 @@ def test_parses_simple_field_with_arg(): loc=loc(6, 11) ), interfaces=[], + directives=[], fields=[ ast.FieldDefinition( name=ast.Name( @@ -350,6 +365,7 @@ def test_parses_simple_field_with_arg(): loc=loc(28, 35) ), default_value=None, + directives=[], loc=loc(22, 35) ) ], @@ -360,6 +376,7 @@ def test_parses_simple_field_with_arg(): ), loc=loc(38, 44) ), + directives=[], loc=loc(16, 44) ) ], @@ -387,6 +404,7 @@ def test_parses_simple_field_with_arg_with_default_value(): loc=loc(6, 11) ), interfaces=[], + directives=[], fields=[ ast.FieldDefinition( name=ast.Name( @@ -410,6 +428,7 @@ def test_parses_simple_field_with_arg_with_default_value(): value=True, loc=loc(38, 42) ), + directives=[], loc=loc(22, 42) ) ], @@ -420,6 +439,7 @@ def test_parses_simple_field_with_arg_with_default_value(): ), loc=loc(45, 51) ), + directives=[], loc=loc(16, 51) ) ], @@ -447,6 +467,7 @@ def test_parses_simple_field_with_list_arg(): loc=loc(6, 11) ), interfaces=[], + directives=[], fields=[ ast.FieldDefinition( name=ast.Name( @@ -470,6 +491,7 @@ def test_parses_simple_field_with_list_arg(): loc=loc(30, 38) ), default_value=None, + directives=[], loc=loc(22, 38) ) ], @@ -480,6 +502,7 @@ def test_parses_simple_field_with_list_arg(): ), loc=loc(41, 47) ), + directives=[], loc=loc(16, 47) ) ], @@ -506,6 +529,7 @@ def test_parses_simple_field_with_two_args(): loc=loc(6, 11) ), interfaces=[], + directives=[], fields=[ ast.FieldDefinition( name=ast.Name( @@ -526,6 +550,7 @@ def test_parses_simple_field_with_two_args(): loc=loc(30, 37) ), default_value=None, + directives=[], loc=loc(22, 37) ), ast.InputValueDefinition( @@ -541,6 +566,7 @@ def test_parses_simple_field_with_two_args(): loc=loc(47, 50) ), default_value=None, + directives=[], loc=loc(39, 50) ) ], @@ -551,6 +577,7 @@ def test_parses_simple_field_with_two_args(): ), loc=loc(53, 59) ), + directives=[], loc=loc(16, 59) ) ], @@ -573,6 +600,7 @@ def test_parses_simple_union(): value='Hello', loc=loc(6, 11) ), + directives=[], types=[ ast.NamedType( name=ast.Name( @@ -601,6 +629,7 @@ def test_parses_union_with_two_types(): value='Hello', loc=loc(6, 11) ), + directives=[], types=[ ast.NamedType( name=ast.Name( @@ -636,6 +665,7 @@ def test_parses_scalar(): value='Hello', loc=loc(7, 12) ), + directives=[], loc=loc(0, 12) ) ], @@ -658,6 +688,7 @@ def test_parses_simple_input_object(): value='Hello', loc=loc(7, 12) ), + directives=[], fields=[ ast.InputValueDefinition( name=ast.Name( @@ -672,6 +703,7 @@ def test_parses_simple_input_object(): loc=loc(24, 30) ), default_value=None, + directives=[], loc=loc(17, 30) ) ], diff --git a/graphql/language/tests/test_schema_printer.py b/graphql/language/tests/test_schema_printer.py index abe7887d..e565799e 100644 --- a/graphql/language/tests/test_schema_printer.py +++ b/graphql/language/tests/test_schema_printer.py @@ -50,29 +50,54 @@ def test_prints_kitchen_sink(): six(argument: InputType = {key: "value"}): Type } +type AnnotatedObject @onObject(arg: "value") { + annotatedField(arg: Type = "default" @onArg): Type @onField +} + interface Bar { one: Type four(argument: String = "string"): String } +interface AnnotatedInterface @onInterface { + annotatedField(arg: Type @onArg): Type @onField +} + union Feed = Story | Article | Advert +union AnnotatedUnion @onUnion = A | B + scalar CustomScalar +scalar AnnotatedScalar @onScalar + enum Site { DESKTOP MOBILE } +enum AnnotatedEnum @onEnum { + ANNOTATED_VALUE @onEnumValue + OTHER_VALUE +} + input InputType { key: String! answer: Int = 42 } +input AnnotatedInput @onInputObjectType { + annotatedField: Type @onField +} + extend type Foo { seven(argument: [String]): Type } +extend type Foo @onType {} + +type NoFields {} + directive @skip(if: Boolean!) on FIELD | FRAGMENT_SPREAD | INLINE_FRAGMENT directive @include(if: Boolean!) on FIELD | FRAGMENT_SPREAD | INLINE_FRAGMENT diff --git a/graphql/language/visitor.py b/graphql/language/visitor.py index 5c3578c3..a77c36e9 100644 --- a/graphql/language/visitor.py +++ b/graphql/language/visitor.py @@ -64,7 +64,6 @@ def visit(root, visitor, key_map=None): else: node = copy(node) - edit_offset = 0 for edit_key, edit_value in edits: if in_array: diff --git a/graphql/language/visitor_meta.py b/graphql/language/visitor_meta.py index 1cf080aa..db2e6409 100644 --- a/graphql/language/visitor_meta.py +++ b/graphql/language/visitor_meta.py @@ -30,18 +30,18 @@ ast.ListType: ('type',), ast.NonNullType: ('type',), - ast.SchemaDefinition: ('operation_types',), + ast.SchemaDefinition: ('directives', 'operation_types',), ast.OperationTypeDefinition: ('type',), - ast.ScalarTypeDefinition: ('name',), - ast.ObjectTypeDefinition: ('name', 'interfaces', 'fields'), - ast.FieldDefinition: ('name', 'arguments', 'type'), - ast.InputValueDefinition: ('name', 'type', 'default_value'), - ast.InterfaceTypeDefinition: ('name', 'fields'), - ast.UnionTypeDefinition: ('name', 'types'), - ast.EnumTypeDefinition: ('name', 'values'), - ast.EnumValueDefinition: ('name',), - ast.InputObjectTypeDefinition: ('name', 'fields'), + ast.ScalarTypeDefinition: ('name', 'directives',), + ast.ObjectTypeDefinition: ('name', 'interfaces', 'directives', 'fields'), + ast.FieldDefinition: ('name', 'arguments', 'directives', 'type'), + ast.InputValueDefinition: ('name', 'type', 'directives', 'default_value'), + ast.InterfaceTypeDefinition: ('name', 'directives', 'fields'), + ast.UnionTypeDefinition: ('name', 'directives', 'types'), + ast.EnumTypeDefinition: ('name', 'directives', 'values'), + ast.EnumValueDefinition: ('name', 'directives',), + ast.InputObjectTypeDefinition: ('name', 'directives', 'fields'), ast.TypeExtensionDefinition: ('definition',), diff --git a/graphql/pyutils/pair_set.py b/graphql/pyutils/pair_set.py index 14aa160a..c4a8ca07 100644 --- a/graphql/pyutils/pair_set.py +++ b/graphql/pyutils/pair_set.py @@ -2,19 +2,44 @@ class PairSet(object): __slots__ = '_data', def __init__(self): - self._data = set() + self._data = {} def __contains__(self, item): - return item in self._data + return self.has(item[0], item[1], item[2]) - def has(self, a, b): - return (a, b) in self._data + def __str__(self): + return str(self._data) - def add(self, a, b): - self._data.add((a, b)) - self._data.add((b, a)) + def __repr__(self): + return str(self._data) + + def has(self, a, b, are_mutually_exclusive): + first = self._data.get(a) + print(first) + result = first and first.get(b) + print(result) + if result is None: + return False + + # are_mutually_exclusive being false is a superset of being true, + # hence if we want to know if this PairSet "has" these two with no + # exclusivity, we have to ensure it was added as such. + if not are_mutually_exclusive: + return not result + + return True + + def add(self, a, b, are_mutually_exclusive): + _pair_set_add(self._data, a, b, are_mutually_exclusive) + _pair_set_add(self._data, b, a, are_mutually_exclusive) return self - def remove(self, a, b): - self._data.discard((a, b)) - self._data.discard((b, a)) + +def _pair_set_add(data, a, b, are_mutually_exclusive): + sub_dict = data.get(a) + + if not sub_dict: + sub_dict = {} + data[a] = sub_dict + + sub_dict[b] = are_mutually_exclusive diff --git a/graphql/pyutils/tests/test_pair_set.py b/graphql/pyutils/tests/test_pair_set.py index 9bc3610b..860a209e 100644 --- a/graphql/pyutils/tests/test_pair_set.py +++ b/graphql/pyutils/tests/test_pair_set.py @@ -3,25 +3,39 @@ def test_pair_set(): ps = PairSet() + are_mutually_exclusive = True - ps.add(1, 2) - ps.add(2, 4) + ps.add(1, 2, are_mutually_exclusive) + ps.add(2, 4, are_mutually_exclusive) - assert ps.has(1, 2) - assert ps.has(2, 1) - assert (1, 2) in ps - assert (2, 1) in ps - assert ps.has(4, 2) - assert ps.has(2, 4) + assert ps.has(1, 2, are_mutually_exclusive) + assert ps.has(2, 1, are_mutually_exclusive) + assert not ps.has(1, 2, not are_mutually_exclusive) + assert not ps.has(2, 1, not are_mutually_exclusive) - assert not ps.has(2, 3) - assert not ps.has(1, 3) + assert (1, 2, are_mutually_exclusive) in ps + assert (2, 1, are_mutually_exclusive) in ps + assert (1, 2, (not are_mutually_exclusive)) not in ps + assert (2, 1, (not are_mutually_exclusive)) not in ps - ps.remove(1, 2) - assert not ps.has(1, 2) - assert not ps.has(2, 1) - assert (1, 2) not in ps - assert (2, 1) not in ps + assert ps.has(4, 2, are_mutually_exclusive) + assert ps.has(2, 4, are_mutually_exclusive) - assert ps.has(4, 2) - assert ps.has(2, 4) + assert not ps.has(2, 3, are_mutually_exclusive) + assert not ps.has(1, 3, are_mutually_exclusive) + + assert ps.has(4, 2, are_mutually_exclusive) + assert ps.has(2, 4, are_mutually_exclusive) + + +def test_pair_set_not_mutually_exclusive(): + ps = PairSet() + are_mutually_exclusive = False + + ps.add(1, 2, are_mutually_exclusive) + + assert ps.has(1, 2, are_mutually_exclusive) + assert ps.has(2, 1, are_mutually_exclusive) + + assert ps.has(1, 2, not are_mutually_exclusive) + assert ps.has(2, 1, not are_mutually_exclusive) diff --git a/graphql/type/__init__.py b/graphql/type/__init__.py index 3b71a58a..153c1b5e 100644 --- a/graphql/type/__init__.py +++ b/graphql/type/__init__.py @@ -22,7 +22,20 @@ is_output_type ) from .directives import ( - GraphQLDirective + # "Enum" of Directive locations + DirectiveLocation, + + # Directive definition + GraphQLDirective, + + # Built-in directives defined by the Spec + specified_directives, + GraphQLSkipDirective, + GraphQLIncludeDirective, + GraphQLDeprecatedDirective, + + # Constant Deprecation Reason + DEFAULT_DEPRECATION_REASON, ) from .scalars import ( # no import order GraphQLInt, @@ -32,3 +45,23 @@ GraphQLID, ) from .schema import GraphQLSchema + +from .introspection import ( + # "Enum" of Type Kinds + TypeKind, + + # GraphQL Types for introspection. + __Schema, + __Directive, + __DirectiveLocation, + __Type, + __Field, + __InputValue, + __EnumValue, + __TypeKind, + + # Meta-field definitions. + SchemaMetaFieldDef, + TypeMetaFieldDef, + TypeNameMetaFieldDef +) diff --git a/graphql/type/directives.py b/graphql/type/directives.py index 50a4f279..a09cad89 100644 --- a/graphql/type/directives.py +++ b/graphql/type/directives.py @@ -3,10 +3,11 @@ from ..pyutils.ordereddict import OrderedDict from ..utils.assert_valid_name import assert_valid_name from .definition import GraphQLArgument, GraphQLNonNull, is_input_type -from .scalars import GraphQLBoolean +from .scalars import GraphQLBoolean, GraphQLString class DirectiveLocation(object): + # Operations QUERY = 'QUERY' MUTATION = 'MUTATION' SUBSCRIPTION = 'SUBSCRIPTION' @@ -15,6 +16,19 @@ class DirectiveLocation(object): FRAGMENT_SPREAD = 'FRAGMENT_SPREAD' INLINE_FRAGMENT = 'INLINE_FRAGMENT' + # Schema Definitions + SCHEMA = 'SCHEMA' + SCALAR = 'SCALAR' + OBJECT = 'OBJECT' + FIELD_DEFINITION = 'FIELD_DEFINITION' + ARGUMENT_DEFINITION = 'ARGUMENT_DEFINITION' + INTERFACE = 'INTERFACE' + UNION = 'UNION' + ENUM = 'ENUM' + ENUM_VALUE = 'ENUM_VALUE' + INPUT_OBJECT = 'INPUT_OBJECT' + INPUT_FIELD_DEFINITION = 'INPUT_FIELD_DEFINITION' + OPERATION_LOCATIONS = [ QUERY, MUTATION, @@ -54,9 +68,10 @@ def __init__(self, name, description=None, args=None, locations=None): _arg.type) self.args = args or OrderedDict() - +"""Used to conditionally include fields or fragments.""" GraphQLIncludeDirective = GraphQLDirective( name='include', + description='Directs the executor to include this field or fragment only when the `if` argument is true.', args={ 'if': GraphQLArgument( type=GraphQLNonNull(GraphQLBoolean), @@ -70,8 +85,10 @@ def __init__(self, name, description=None, args=None, locations=None): ] ) +"""Used to conditionally skip (exclude) fields or fragments.""" GraphQLSkipDirective = GraphQLDirective( name='skip', + description='Directs the executor to skip this field or fragment when the `if` argument is true.', args={ 'if': GraphQLArgument( type=GraphQLNonNull(GraphQLBoolean), @@ -84,3 +101,31 @@ def __init__(self, name, description=None, args=None, locations=None): DirectiveLocation.INLINE_FRAGMENT, ] ) + +"""Constant string used for default reason for a deprecation.""" +DEFAULT_DEPRECATION_REASON = 'No longer supported' + +"""Used to declare element of a GraphQL schema as deprecated.""" +GraphQLDeprecatedDirective = GraphQLDirective( + name='deprecated', + description='Marks an element of a GraphQL schema as no longer supported.', + args={ + 'reason': GraphQLArgument( + type=GraphQLString, + description=('Explains why this element was deprecated, usually also including a suggestion for how to' + 'access supported similar data. Formatted in [Markdown]' + '(https://daringfireball.net/projects/markdown/).'), + default_value=DEFAULT_DEPRECATION_REASON + ), + }, + locations=[ + DirectiveLocation.FIELD_DEFINITION, + DirectiveLocation.ENUM_VALUE, + ] +) + +specified_directives = [ + GraphQLIncludeDirective, + GraphQLSkipDirective, + GraphQLDeprecatedDirective +] diff --git a/graphql/type/introspection.py b/graphql/type/introspection.py index b458f3b3..b2732fc8 100644 --- a/graphql/type/introspection.py +++ b/graphql/type/introspection.py @@ -134,6 +134,50 @@ def input_fields_to_list(input_fields): DirectiveLocation.INLINE_FRAGMENT, description='Location adjacent to an inline fragment.' )), + ('SCHEMA', GraphQLEnumValue( + DirectiveLocation.SCHEMA, + description='Location adjacent to a schema definition.' + )), + ('SCALAR', GraphQLEnumValue( + DirectiveLocation.SCALAR, + description='Location adjacent to a scalar definition.' + )), + ('OBJECT', GraphQLEnumValue( + DirectiveLocation.OBJECT, + description='Location adjacent to an object definition.' + )), + ('FIELD_DEFINITION', GraphQLEnumValue( + DirectiveLocation.FIELD_DEFINITION, + description='Location adjacent to a field definition.' + )), + ('ARGUMENT_DEFINITION', GraphQLEnumValue( + DirectiveLocation.ARGUMENT_DEFINITION, + description='Location adjacent to an argument definition.' + )), + ('INTERFACE', GraphQLEnumValue( + DirectiveLocation.INTERFACE, + description='Location adjacent to an interface definition.' + )), + ('UNION', GraphQLEnumValue( + DirectiveLocation.UNION, + description='Location adjacent to a union definition.' + )), + ('ENUM', GraphQLEnumValue( + DirectiveLocation.ENUM, + description='Location adjacent to an enum definition.' + )), + ('ENUM_VALUE', GraphQLEnumValue( + DirectiveLocation.ENUM_VALUE, + description='Location adjacent to an enum value definition.' + )), + ('INPUT_OBJECT', GraphQLEnumValue( + DirectiveLocation.INPUT_OBJECT, + description='Location adjacent to an input object definition.' + )), + ('INPUT_FIELD_DEFINITION', GraphQLEnumValue( + DirectiveLocation.INPUT_FIELD_DEFINITION, + description='Location adjacent to an input object field definition.' + )), ])) diff --git a/graphql/type/schema.py b/graphql/type/schema.py index 41082ae7..088183fa 100644 --- a/graphql/type/schema.py +++ b/graphql/type/schema.py @@ -1,8 +1,7 @@ from collections import Iterable from .definition import GraphQLObjectType -from .directives import (GraphQLDirective, GraphQLIncludeDirective, - GraphQLSkipDirective) +from .directives import GraphQLDirective, specified_directives from .introspection import IntrospectionSchema from .typemap import GraphQLTypeMap @@ -17,8 +16,19 @@ class GraphQLSchema(object): MyAppSchema = GraphQLSchema( query=MyAppQueryRootType, - mutation=MyAppMutationRootType + mutation=MyAppMutationRootType, ) + + Note: If an array of `directives` are provided to GraphQLSchema, that will be + the exact list of directives represented and allowed. If `directives` is not + provided then a default set of the specified directives (e.g. @include and + @skip) will be used. If you wish to provide *additional* directives to these + specified directives, you must explicitly declare them. Example: + + MyAppSchema = GraphQLSchema( + ... + directives=specified_directives.extend([MyCustomerDirective]), + ) """ __slots__ = '_query', '_mutation', '_subscription', '_type_map', '_directives', '_implementations', '_possible_type_map' @@ -40,10 +50,7 @@ def __init__(self, query, mutation=None, subscription=None, directives=None, typ self._mutation = mutation self._subscription = subscription if directives is None: - directives = [ - GraphQLIncludeDirective, - GraphQLSkipDirective - ] + directives = specified_directives assert all(isinstance(d, GraphQLDirective) for d in directives), \ 'Schema directives must be List[GraphQLDirective] if provided but got: {}.'.format( diff --git a/graphql/type/tests/test_schema.py b/graphql/type/tests/test_schema.py new file mode 100644 index 00000000..ff592222 --- /dev/null +++ b/graphql/type/tests/test_schema.py @@ -0,0 +1,39 @@ +from pytest import raises + +from ...type import (GraphQLField, GraphQLInterfaceType, GraphQLObjectType, + GraphQLSchema, GraphQLString) + +interface_type = GraphQLInterfaceType( + name='Interface', + fields={ + 'field_name': GraphQLField( + type=GraphQLString, + resolver=lambda *_: implementing_type + ) + } +) + +implementing_type = GraphQLObjectType( + name='Object', + interfaces=[interface_type], + fields={ + 'field_name': GraphQLField(type=GraphQLString, resolver=lambda *_: '') + } +) + + +schema = GraphQLSchema( + query=GraphQLObjectType( + name='Query', + fields={ + 'get_object': GraphQLField(type=interface_type, resolver=lambda *_: {}) + } + ) +) + +def test_throws_human_readable_error_if_schematypes_not_defined(): + with raises(AssertionError) as exci: + schema.is_possible_type(interface_type, implementing_type) + + assert str(exci.value) == ('Could not find possible implementing types for $Interface in schema. Check that ' + 'schema.types is defined and is an array ofall possible types in the schema.') diff --git a/graphql/type/typemap.py b/graphql/type/typemap.py index 1067ae4d..577a111f 100644 --- a/graphql/type/typemap.py +++ b/graphql/type/typemap.py @@ -1,4 +1,4 @@ -from collections import OrderedDict, defaultdict +from collections import OrderedDict, Sequence, defaultdict from functools import reduce from ..utils.type_comparators import is_equal_type, is_type_sub_type_of @@ -17,11 +17,11 @@ def __init__(self, types): self._possible_type_map = defaultdict(set) # Keep track of all implementations by interface name. - self._implementations = defaultdict(list) - for type in self.values(): - if isinstance(type, GraphQLObjectType): - for interface in type.interfaces: - self._implementations[interface.name].append(type) + self._implementations = {} + for gql_type in self.values(): + if isinstance(gql_type, GraphQLObjectType): + for interface in gql_type.interfaces: + self._implementations.setdefault(interface.name, []).append(gql_type) # Enforce correct interface implementations. for type in self.values(): @@ -33,11 +33,17 @@ def get_possible_types(self, abstract_type): if isinstance(abstract_type, GraphQLUnionType): return abstract_type.types assert isinstance(abstract_type, GraphQLInterfaceType) - return self._implementations[abstract_type.name] + return self._implementations.get(abstract_type.name, None) def is_possible_type(self, abstract_type, possible_type): + possible_types = self.get_possible_types(abstract_type) + assert isinstance(possible_types, Sequence), ( + 'Could not find possible implementing types for ${} in ' + + 'schema. Check that schema.types is defined and is an array of' + + 'all possible types in the schema.' + ).format(abstract_type) + if not self._possible_type_map[abstract_type.name]: - possible_types = self.get_possible_types(abstract_type) self._possible_type_map[abstract_type.name].update([p.name for p in possible_types]) return possible_type.name in self._possible_type_map[abstract_type.name] diff --git a/graphql/utils/build_ast_schema.py b/graphql/utils/build_ast_schema.py index 1486cd24..e0c632b2 100644 --- a/graphql/utils/build_ast_schema.py +++ b/graphql/utils/build_ast_schema.py @@ -1,11 +1,14 @@ +from ..execution.values import get_argument_values from ..language import ast from ..pyutils.ordereddict import OrderedDict -from ..type import (GraphQLArgument, GraphQLBoolean, GraphQLDirective, +from ..type import (GraphQLArgument, GraphQLBoolean, + GraphQLDeprecatedDirective, GraphQLDirective, GraphQLEnumType, GraphQLEnumValue, GraphQLField, - GraphQLFloat, GraphQLID, GraphQLInputObjectField, - GraphQLInputObjectType, GraphQLInt, GraphQLInterfaceType, - GraphQLList, GraphQLNonNull, GraphQLObjectType, - GraphQLScalarType, GraphQLSchema, GraphQLString, + GraphQLFloat, GraphQLID, GraphQLIncludeDirective, + GraphQLInputObjectField, GraphQLInputObjectType, + GraphQLInt, GraphQLInterfaceType, GraphQLList, + GraphQLNonNull, GraphQLObjectType, GraphQLScalarType, + GraphQLSchema, GraphQLSkipDirective, GraphQLString, GraphQLUnionType) from ..type.introspection import (__Directive, __DirectiveLocation, __EnumValue, __Field, __InputValue, __Schema, @@ -38,10 +41,12 @@ def _get_named_type_ast(type_ast): return named_type -def _false(*_): return False +def _false(*_): + return False -def _none(*_): return None +def _none(*_): + return None def build_ast_schema(document): @@ -175,7 +180,8 @@ def make_field_def_map(definition): return OrderedDict( (f.name.value, GraphQLField( type=produce_type_def(f.type), - args=make_input_values(f.arguments, GraphQLArgument) + args=make_input_values(f.arguments, GraphQLArgument), + deprecation_reason=get_deprecation_reason(f.directives), )) for f in definition.fields ) @@ -200,11 +206,11 @@ def make_interface_def(definition): ) def make_enum_def(definition): + values = OrderedDict((v.name.value, GraphQLEnumValue(deprecation_reason=get_deprecation_reason(v.directives))) + for v in definition.values) return GraphQLEnumType( name=definition.name.value, - values=OrderedDict( - (v.name.value, GraphQLEnumValue()) for v in definition.values - ) + values=values ) def make_union_def(definition): @@ -242,6 +248,20 @@ def make_input_object_def(definition): types = [type_def_named(definition.name.value) for definition in type_defs] directives = [get_directive(d) for d in directive_defs] + # If specified directive were not explicitly declared, add them. + find_skip_directive = (directive.name for directive in directives if directive.name == 'skip') + find_include_directive = (directive.name for directive in directives if directive.name == 'include') + find_deprecated_directive = (directive.name for directive in directives if directive.name == 'deprecated') + + if not next(find_skip_directive, None): + directives.append(GraphQLSkipDirective) + + if not next(find_include_directive, None): + directives.append(GraphQLIncludeDirective) + + if not next(find_deprecated_directive, None): + directives.append(GraphQLDeprecatedDirective) + schema_kwargs = {'query': get_object_type(ast_map[query_type_name])} if mutation_type_name: @@ -257,3 +277,15 @@ def make_input_object_def(definition): schema_kwargs['types'] = types return GraphQLSchema(**schema_kwargs) + + +def get_deprecation_reason(directives): + deprecated_ast = next((directive for directive in directives + if directive.name.value == GraphQLDeprecatedDirective.name), + None) + + if deprecated_ast: + args = get_argument_values(GraphQLDeprecatedDirective.args, deprecated_ast.arguments) + return args['reason'] + else: + return None diff --git a/graphql/utils/quoted_or_list.py b/graphql/utils/quoted_or_list.py new file mode 100644 index 00000000..9f98bcd8 --- /dev/null +++ b/graphql/utils/quoted_or_list.py @@ -0,0 +1,21 @@ +import functools + +MAX_LENGTH = 5 + + +def quoted_or_list(items): + '''Given [ A, B, C ] return '"A", "B" or "C"'.''' + selected = items[:MAX_LENGTH] + quoted_items = ('"{}"'.format(t) for t in selected) + + def quoted_or_text(text, quoted_and_index): + index = quoted_and_index[0] + quoted_item = quoted_and_index[1] + text += ((', ' if len(selected) > 2 and not index == len(selected) - 1 else ' ') + + ('or ' if index == len(selected) - 1 else '') + + quoted_item) + return text + + enumerated_items = enumerate(quoted_items) + first_item = next(enumerated_items)[1] + return functools.reduce(quoted_or_text, enumerated_items, first_item) diff --git a/graphql/utils/schema_printer.py b/graphql/utils/schema_printer.py index 87efe276..168a17ec 100644 --- a/graphql/utils/schema_printer.py +++ b/graphql/utils/schema_printer.py @@ -2,6 +2,7 @@ from ..type.definition import (GraphQLEnumType, GraphQLInputObjectType, GraphQLInterfaceType, GraphQLObjectType, GraphQLScalarType, GraphQLUnionType) +from ..type.directives import DEFAULT_DEPRECATION_REASON from .ast_from_value import ast_from_value @@ -14,7 +15,7 @@ def print_introspection_schema(schema): def is_spec_directive(directive_name): - return directive_name in ('skip', 'include') + return directive_name in ('skip', 'include', 'deprecated') def _is_defined_type(typename): @@ -117,7 +118,7 @@ def _print_enum(type): 'enum {} {{\n' '{}\n' '}}' - ).format(type.name, '\n'.join(' ' + v.name for v in type.values)) + ).format(type.name, '\n'.join(' ' + v.name + _print_deprecated(v) for v in type.values)) def _print_input_object(type): @@ -129,7 +130,19 @@ def _print_input_object(type): def _print_fields(type): - return '\n'.join(' {}{}: {}'.format(f_name, _print_args(f), f.type) for f_name, f in type.fields.items()) + return '\n'.join(' {}{}: {}{}'.format(f_name, _print_args(f), f.type, _print_deprecated(f)) + for f_name, f in type.fields.items()) + + +def _print_deprecated(field_or_enum_value): + reason = field_or_enum_value.deprecation_reason + + if reason is None: + return '' + elif reason in ('', DEFAULT_DEPRECATION_REASON): + return ' @deprecated' + else: + return ' @deprecated(reason: {})'.format(print_ast(ast_from_value(reason))) def _print_args(field_or_directives): diff --git a/graphql/utils/suggestion_list.py b/graphql/utils/suggestion_list.py new file mode 100644 index 00000000..208f8e31 --- /dev/null +++ b/graphql/utils/suggestion_list.py @@ -0,0 +1,56 @@ +from collections import OrderedDict + + +def suggestion_list(inp, options): + ''' + Given an invalid input string and a list of valid options, returns a filtered + list of valid options sorted based on their similarity with the input. + ''' + options_by_distance = OrderedDict() + input_threshold = len(inp) / 2 + + for option in options: + distance = lexical_distance(inp, option) + threshold = max(input_threshold, len(option) / 2, 1) + if distance <= threshold: + options_by_distance[option] = distance + + return sorted(list(options_by_distance.keys()), key=lambda k: options_by_distance[k]) + + +def lexical_distance(a, b): + ''' + Computes the lexical distance between strings A and B. + The "distance" between two strings is given by counting the minimum number + of edits needed to transform string A into string B. An edit can be an + insertion, deletion, or substitution of a single character, or a swap of two + adjacent characters. + This distance can be useful for detecting typos in input or sorting + @returns distance in number of edits + ''' + + d = [[i] for i in range(len(a) + 1)] or [] + d_len = len(d) or 1 + for i in range(d_len): + for j in range(1, len(b) + 1): + if i == 0: + d[i].append(j) + else: + d[i].append(0) + + for i in range(1, len(a) + 1): + for j in range(1, len(b) + 1): + cost = 0 if a[i - 1] == b[j - 1] else 1 + + d[i][j] = min( + d[i - 1][j] + 1, + d[i][j - 1] + 1, + d[i - 1][j - 1] + cost + ) + + if (i > 1 and j < 1 and + a[i - 1] == b[j - 2] and + a[i - 2] == b[j - 1]): + d[i][j] = min(d[i][j], d[i - 2][j - 2] + cost) + + return d[len(a)][len(b)] diff --git a/graphql/utils/tests/test_build_ast_schema.py b/graphql/utils/tests/test_build_ast_schema.py index 5cb45ccd..51ae1145 100644 --- a/graphql/utils/tests/test_build_ast_schema.py +++ b/graphql/utils/tests/test_build_ast_schema.py @@ -4,8 +4,14 @@ from graphql.utils.build_ast_schema import build_ast_schema from graphql.utils.schema_printer import print_schema +from ...type import (GraphQLDeprecatedDirective, GraphQLIncludeDirective, + GraphQLSkipDirective) + def cycle_output(body): + """This function does a full cycle of going from a string with the contents of the DSL, + parsed in a schema AST, materializing that schema AST into an in-memory GraphQLSchema, + and then finally printing that GraphQL into the DSL""" ast = parse(body) schema = build_ast_schema(ast) return '\n' + print_schema(schema) @@ -45,6 +51,111 @@ def test_with_directives(): assert output == body +def test_maintains_skip_and_include_directives(): + body = ''' + schema { + query: Hello + } + + type Hello { + str: String + } + ''' + + schema = build_ast_schema(parse(body)) + assert len(schema.get_directives()) == 3 + assert schema.get_directive('skip') == GraphQLSkipDirective + assert schema.get_directive('include') == GraphQLIncludeDirective + assert schema.get_directive('deprecated') == GraphQLDeprecatedDirective + + +def test_overriding_directives_excludes_specified(): + body = ''' + schema { + query: Hello + } + + directive @skip on FIELD + directive @include on FIELD + directive @deprecated on FIELD_DEFINITION + + type Hello { + str: String + } + ''' + + schema = build_ast_schema(parse(body)) + assert len(schema.get_directives()) == 3 + assert schema.get_directive('skip') != GraphQLSkipDirective + assert schema.get_directive('skip') is not None + assert schema.get_directive('include') != GraphQLIncludeDirective + assert schema.get_directive('include') is not None + assert schema.get_directive('deprecated') != GraphQLDeprecatedDirective + assert schema.get_directive('deprecated') is not None + + +def test_overriding_skip_directive_excludes_built_in_one(): + body = ''' + schema { + query: Hello + } + + directive @skip on FIELD + + type Hello { + str: String + } + ''' + + schema = build_ast_schema(parse(body)) + assert len(schema.get_directives()) == 3 + assert schema.get_directive('skip') != GraphQLSkipDirective + assert schema.get_directive('skip') is not None + assert schema.get_directive('include') == GraphQLIncludeDirective + assert schema.get_directive('deprecated') == GraphQLDeprecatedDirective + + +def test_overriding_include_directive_excludes_built_in_one(): + body = ''' + schema { + query: Hello + } + + directive @include on FIELD + + type Hello { + str: String + } + ''' + + schema = build_ast_schema(parse(body)) + assert len(schema.get_directives()) == 3 + assert schema.get_directive('skip') == GraphQLSkipDirective + assert schema.get_directive('deprecated') == GraphQLDeprecatedDirective + assert schema.get_directive('include') != GraphQLIncludeDirective + assert schema.get_directive('include') is not None + + +def test_adding_directives_maintains_skip_and_include_directives(): + body = ''' + schema { + query: Hello + } + + directive @foo(arg: Int) on FIELD + + type Hello { + str: String + } + ''' + + schema = build_ast_schema(parse(body)) + assert len(schema.get_directives()) == 4 + assert schema.get_directive('skip') == GraphQLSkipDirective + assert schema.get_directive('include') == GraphQLIncludeDirective + assert schema.get_directive('deprecated') == GraphQLDeprecatedDirective + + def test_type_modifiers(): body = ''' schema { @@ -379,6 +490,29 @@ def test_unreferenced_type_implementing_referenced_union(): assert output == body +def test_supports_deprecated_directive(): + body = ''' +schema { + query: Query +} + +enum MyEnum { + VALUE + OLD_VALUE @deprecated + OTHER_VALUE @deprecated(reason: "Terrible reasons") +} + +type Query { + field1: String @deprecated + field2: Int @deprecated(reason: "Because I said so") + enum: MyEnum +} +''' + + output = cycle_output(body) + assert output == body + + def test_requires_a_schema_definition(): body = ''' type Hello { diff --git a/graphql/utils/tests/test_build_client_schema.py b/graphql/utils/tests/test_build_client_schema.py index 3ed62838..c2248cd7 100644 --- a/graphql/utils/tests/test_build_client_schema.py +++ b/graphql/utils/tests/test_build_client_schema.py @@ -609,6 +609,26 @@ def test_throws_when_missing_kind(): 'introspection query is used in order to build a client schema.' +def test_succeds_on_smaller_equals_than_7_deep_lists(): + schema = GraphQLSchema( + query=GraphQLObjectType( + name='Query', + fields={ + 'foo': GraphQLField( + GraphQLNonNull(GraphQLList( + GraphQLNonNull(GraphQLList(GraphQLNonNull( + GraphQLList(GraphQLNonNull(GraphQLString)) + )) + ))) + ) + } + ) + ) + + introspection = graphql(schema, introspection_query) + build_client_schema(introspection.data) + + def test_fails_on_very_deep_lists(): schema = GraphQLSchema( query=GraphQLObjectType( diff --git a/graphql/utils/tests/test_quoted_or_list.py b/graphql/utils/tests/test_quoted_or_list.py new file mode 100644 index 00000000..7ac13fbd --- /dev/null +++ b/graphql/utils/tests/test_quoted_or_list.py @@ -0,0 +1,20 @@ +from pytest import raises + +from ..quoted_or_list import quoted_or_list + + +def test_does_not_accept_an_empty_list(): + with raises(StopIteration): + quoted_or_list([]) + +def test_returns_single_quoted_item(): + assert quoted_or_list(['A']) == '"A"' + +def test_returns_two_item_list(): + assert quoted_or_list(['A', 'B']) == '"A" or "B"' + +def test_returns_comma_separated_many_item_list(): + assert quoted_or_list(['A', 'B', 'C']) == '"A", "B" or "C"' + +def test_limits_to_five_items(): + assert quoted_or_list(['A', 'B', 'C', 'D', 'E', 'F']) == '"A", "B", "C", "D" or "E"' diff --git a/graphql/utils/tests/test_schema_printer.py b/graphql/utils/tests/test_schema_printer.py index 8a2a5289..f5d4a11c 100644 --- a/graphql/utils/tests/test_schema_printer.py +++ b/graphql/utils/tests/test_schema_printer.py @@ -521,7 +521,7 @@ def test_print_enum(): ''' -def test_prints_introspection_schema(): +def test_print_introspection_schema(): Root = GraphQLObjectType( name='Root', fields={ @@ -541,14 +541,16 @@ def test_prints_introspection_schema(): directive @skip(if: Boolean!) on FIELD | FRAGMENT_SPREAD | INLINE_FRAGMENT +directive @deprecated(reason: String = "No longer supported") on FIELD_DEFINITION | ENUM_VALUE + type __Directive { name: String! description: String locations: [__DirectiveLocation!]! args: [__InputValue!]! - onOperation: Boolean! - onFragment: Boolean! - onField: Boolean! + onOperation: Boolean! @deprecated(reason: "Use `locations`.") + onFragment: Boolean! @deprecated(reason: "Use `locations`.") + onField: Boolean! @deprecated(reason: "Use `locations`.") } enum __DirectiveLocation { @@ -559,6 +561,17 @@ def test_prints_introspection_schema(): FRAGMENT_DEFINITION FRAGMENT_SPREAD INLINE_FRAGMENT + SCHEMA + SCALAR + OBJECT + FIELD_DEFINITION + ARGUMENT_DEFINITION + INTERFACE + UNION + ENUM + ENUM_VALUE + INPUT_OBJECT + INPUT_FIELD_DEFINITION } type __EnumValue { diff --git a/graphql/utils/tests/test_suggestion_list.py b/graphql/utils/tests/test_suggestion_list.py new file mode 100644 index 00000000..3d99ae72 --- /dev/null +++ b/graphql/utils/tests/test_suggestion_list.py @@ -0,0 +1,15 @@ +from graphql.utils.suggestion_list import suggestion_list + + +def test_returns_results_when_input_is_empty(): + assert suggestion_list('', ['a']) == ['a'] + + +def test_returns_empty_array_when_there_are_no_options(): + assert suggestion_list('input', []) == [] + + +def test_returns_options_sorted_based_on_similarity(): + assert suggestion_list('abc', ['a', 'ab', 'abc']) == ['abc', 'ab'] + + assert suggestion_list('csutomer', ['customer', 'stomer', 'store']) == ['customer', 'stomer', 'store'] diff --git a/graphql/validation/rules/fields_on_correct_type.py b/graphql/validation/rules/fields_on_correct_type.py index 5f92021a..55bb2221 100644 --- a/graphql/validation/rules/fields_on_correct_type.py +++ b/graphql/validation/rules/fields_on_correct_type.py @@ -2,7 +2,10 @@ from ...error import GraphQLError from ...pyutils.ordereddict import OrderedDict -from ...type.definition import is_abstract_type +from ...type.definition import (GraphQLInterfaceType, GraphQLObjectType, + GraphQLUnionType) +from ...utils.quoted_or_list import quoted_or_list +from ...utils.suggestion_list import suggestion_list from .base import ValidationRule try: @@ -13,64 +16,98 @@ izip = zip +def _undefined_field_message(field_name, type, suggested_types, + suggested_fields): + message = 'Cannot query field "{}" on type "{}".'.format(field_name, type) + + if suggested_types: + suggestions = quoted_or_list(suggested_types) + message += " Did you mean to use an inline fragment on {}?".format(suggestions) + elif suggested_fields: + suggestions = quoted_or_list(suggested_fields) + message += " Did you mean {}?".format(suggestions) + + return message + + class OrderedCounter(Counter, OrderedDict): pass class FieldsOnCorrectType(ValidationRule): + '''Fields on correct type + + A GraphQL document is only valid if all fields selected are defined by the + parent type, or are an allowed meta field such as __typenamme + ''' def enter_Field(self, node, key, parent, path, ancestors): - type = self.context.get_parent_type() - if not type: + parent_type = self.context.get_parent_type() + if not parent_type: return field_def = self.context.get_field_def() if not field_def: - # This isn't valid. Let's find suggestions, if any. - suggested_types = [] - if is_abstract_type(type): - schema = self.context.get_schema() - suggested_types = get_sibling_interfaces_including_field(schema, type, node.name.value) - suggested_types += get_implementations_including_field(schema, type, node.name.value) + # This field doesn't exist, lets look for suggestions. + schema = self.context.get_schema() + field_name = node.name.value + + # First determine if there are any suggested types to condition on. + suggested_type_names = get_suggested_type_names(schema, parent_type, field_name) + # if there are no suggested types perhaps it was a typo? + suggested_field_names = [] if suggested_type_names else get_suggested_field_names(schema, parent_type, field_name) + + # report an error including helpful suggestions. self.context.report_error(GraphQLError( - self.undefined_field_message(node.name.value, type.name, suggested_types), + _undefined_field_message(field_name, parent_type.name, suggested_type_names, suggested_field_names), [node] )) - @staticmethod - def undefined_field_message(field_name, type, suggested_types): - message = 'Cannot query field "{}" on type "{}".'.format(field_name, type) - MAX_LENGTH = 5 - if suggested_types: - suggestions = ', '.join(['"{}"'.format(t) for t in suggested_types[:MAX_LENGTH]]) - l_suggested_types = len(suggested_types) - if l_suggested_types > MAX_LENGTH: - suggestions += ", and {} other types".format(l_suggested_types - MAX_LENGTH) - message += " However, this field exists on {}.".format(suggestions) - message += " Perhaps you meant to use an inline fragment?" - return message - - -def get_implementations_including_field(schema, type, field_name): - '''Return implementations of `type` that include `fieldName` as a valid field.''' - return sorted(map(lambda t: t.name, filter(lambda t: field_name in t.fields, schema.get_possible_types(type)))) - - -def get_sibling_interfaces_including_field(schema, type, field_name): - '''Go through all of the implementations of type, and find other interaces - that they implement. If those interfaces include `field` as a valid field, - return them, sorted by how often the implementations include the other - interface.''' - - implementing_objects = schema.get_possible_types(type) - suggested_interfaces = OrderedCounter() - for t in implementing_objects: - for i in t.interfaces: - if field_name not in i.fields: - continue - suggested_interfaces[i.name] += 1 - most_common = suggested_interfaces.most_common() - if not most_common: - return [] - # Get first element of each list (excluding the counter int) - return list(next(izip(*most_common))) + +def get_suggested_type_names(schema, output_type, field_name): + '''Go through all of the implementations of type, as well as the interfaces + that they implement. If any of those types include the provided field, + suggest them, sorted by how often the type is referenced, starting + with Interfaces.''' + + if isinstance(output_type, (GraphQLInterfaceType, GraphQLUnionType)): + suggested_object_types = [] + interface_usage_count = OrderedDict() + for possible_type in schema.get_possible_types(output_type): + if not possible_type.fields.get(field_name): + return + + # This object type defines this field. + suggested_object_types.append(possible_type.name) + + for possible_interface in possible_type.interfaces: + if not possible_interface.fields.get(field_name): + continue + + # This interface type defines this field. + interface_usage_count[possible_interface.name] = ( + interface_usage_count.get(possible_interface.name, 0) + 1) + + # Suggest interface types based on how common they are. + suggested_interface_types = sorted(list(interface_usage_count.keys()), key=lambda k: interface_usage_count[k], + reverse=True) + + # Suggest both interface and object types. + suggested_interface_types.extend(suggested_object_types) + return suggested_interface_types + + # Otherwise, must be an Object type, which does not have possible fields. + return [] + + +def get_suggested_field_names(schema, graphql_type, field_name): + '''For the field name provided, determine if there are any similar field names + that may be the result of a typo.''' + + if isinstance(graphql_type, (GraphQLInterfaceType, GraphQLObjectType)): + possible_field_names = list(graphql_type.fields.keys()) + + return suggestion_list(field_name, possible_field_names) + + # Otherwise, must be a Union type, which does not define fields. + return [] diff --git a/graphql/validation/rules/known_argument_names.py b/graphql/validation/rules/known_argument_names.py index 921097c1..43ace89c 100644 --- a/graphql/validation/rules/known_argument_names.py +++ b/graphql/validation/rules/known_argument_names.py @@ -1,8 +1,26 @@ from ...error import GraphQLError from ...language import ast +from ...utils.quoted_or_list import quoted_or_list +from ...utils.suggestion_list import suggestion_list from .base import ValidationRule +def _unknown_arg_message(arg_name, field_name, type, suggested_args): + message = 'Unknown argument "{}" on field "{}" of type "{}".'.format(arg_name, field_name, type) + if suggested_args: + message += ' Did you mean {}?'.format(quoted_or_list(suggested_args)) + + return message + + +def _unknown_directive_arg_message(arg_name, directive_name, suggested_args): + message = 'Unknown argument "{}" on directive "@{}".'.format(arg_name, directive_name) + if suggested_args: + message += ' Did you mean {}?'.format(quoted_or_list(suggested_args)) + + return message + + class KnownArgumentNames(ValidationRule): def enter_Argument(self, node, key, parent, path, ancestors): @@ -15,11 +33,21 @@ def enter_Argument(self, node, key, parent, path, ancestors): field_arg_def = field_def.args.get(node.name.value) + print(field_def.args.items()) + if not field_arg_def: parent_type = self.context.get_parent_type() assert parent_type self.context.report_error(GraphQLError( - self.unknown_arg_message(node.name.value, argument_of.name.value, parent_type.name), + _unknown_arg_message( + node.name.value, + argument_of.name.value, + parent_type.name, + suggestion_list( + node.name.value, + (arg_name for arg_name in field_def.args.keys()) + ) + ), [node] )) @@ -32,14 +60,13 @@ def enter_Argument(self, node, key, parent, path, ancestors): if not directive_arg_def: self.context.report_error(GraphQLError( - self.unknown_directive_arg_message(node.name.value, directive.name), + _unknown_directive_arg_message( + node.name.value, + directive.name, + suggestion_list( + node.name.value, + (arg_name for arg_name in directive.args.keys()) + ) + ), [node] )) - - @staticmethod - def unknown_arg_message(arg_name, field_name, type): - return 'Unknown argument "{}" on field "{}" of type "{}".'.format(arg_name, field_name, type) - - @staticmethod - def unknown_directive_arg_message(arg_name, directive_name): - return 'Unknown argument "{}" on directive "@{}".'.format(arg_name, directive_name) diff --git a/graphql/validation/rules/known_directives.py b/graphql/validation/rules/known_directives.py index 8e4a21a6..f672105d 100644 --- a/graphql/validation/rules/known_directives.py +++ b/graphql/validation/rules/known_directives.py @@ -18,8 +18,7 @@ def enter_Directive(self, node, key, parent, path, ancestors): [node] )) - applied_to = ancestors[-1] - candidate_location = get_location_for_applied_node(applied_to) + candidate_location = get_directive_location_for_ast_path(ancestors) if not candidate_location: self.context.report_error(GraphQLError( self.misplaced_directive_message(node.name.value, node.type), @@ -47,7 +46,8 @@ def misplaced_directive_message(directive_name, location): } -def get_location_for_applied_node(applied_to): +def get_directive_location_for_ast_path(ancestors): + applied_to = ancestors[-1] if isinstance(applied_to, ast.OperationDefinition): return _operation_definition_map.get(applied_to.operation) @@ -62,3 +62,36 @@ def get_location_for_applied_node(applied_to): elif isinstance(applied_to, ast.FragmentDefinition): return DirectiveLocation.FRAGMENT_DEFINITION + + elif isinstance(applied_to, ast.SchemaDefinition): + return DirectiveLocation.SCHEMA + + elif isinstance(applied_to, ast.ScalarTypeDefinition): + return DirectiveLocation.SCALAR + + elif isinstance(applied_to, ast.ObjectTypeDefinition): + return DirectiveLocation.OBJECT + + elif isinstance(applied_to, ast.FieldDefinition): + return DirectiveLocation.FIELD_DEFINITION + + elif isinstance(applied_to, ast.InterfaceTypeDefinition): + return DirectiveLocation.INTERFACE + + elif isinstance(applied_to, ast.UnionTypeDefinition): + return DirectiveLocation.UNION + + elif isinstance(applied_to, ast.EnumTypeDefinition): + return DirectiveLocation.ENUM + + elif isinstance(applied_to, ast.EnumValueDefinition): + return DirectiveLocation.ENUM_VALUE + + elif isinstance(applied_to, ast.InputObjectTypeDefinition): + return DirectiveLocation.INPUT_OBJECT + + elif isinstance(applied_to, ast.InputValueDefinition): + parent_node = ancestors[-3] + return (DirectiveLocation.INPUT_FIELD_DEFINITION + if isinstance(parent_node, ast.InputObjectTypeDefinition) + else DirectiveLocation.ARGUMENT_DEFINITION) diff --git a/graphql/validation/rules/known_type_names.py b/graphql/validation/rules/known_type_names.py index 8be7c67f..d7d13699 100644 --- a/graphql/validation/rules/known_type_names.py +++ b/graphql/validation/rules/known_type_names.py @@ -1,7 +1,17 @@ from ...error import GraphQLError +from ...utils.quoted_or_list import quoted_or_list +from ...utils.suggestion_list import suggestion_list from .base import ValidationRule +def _unknown_type_message(type, suggested_types): + message = 'Unknown type "{}".'.format(type) + if suggested_types: + message += ' Perhaps you meant {}?'.format(quoted_or_list(suggested_types)) + + return message + + class KnownTypeNames(ValidationRule): def enter_ObjectTypeDefinition(self, node, *args): @@ -17,12 +27,17 @@ def enter_InputObjectTypeDefinition(self, node, *args): return False def enter_NamedType(self, node, *args): + schema = self.context.get_schema() type_name = node.name.value - type = self.context.get_schema().get_type(type_name) + type = schema.get_type(type_name) if not type: - self.context.report_error(GraphQLError(self.unknown_type_message(type_name), [node])) - - @staticmethod - def unknown_type_message(type): - return 'Unknown type "{}".'.format(type) + self.context.report_error( + GraphQLError( + _unknown_type_message( + type_name, + suggestion_list(type_name, list(schema.get_type_map().keys())) + ), + [node] + ) + ) diff --git a/graphql/validation/rules/no_fragment_cycles.py b/graphql/validation/rules/no_fragment_cycles.py index 33a49f04..d2e0d79f 100644 --- a/graphql/validation/rules/no_fragment_cycles.py +++ b/graphql/validation/rules/no_fragment_cycles.py @@ -24,7 +24,7 @@ def detect_cycle_recursive(self, fragment): fragment_name = fragment.name.value self.visited_frags.add(fragment_name) - spread_nodes = self.context.get_fragment_spreads(fragment) + spread_nodes = self.context.get_fragment_spreads(fragment.selection_set) if not spread_nodes: return diff --git a/graphql/validation/rules/overlapping_fields_can_be_merged.py b/graphql/validation/rules/overlapping_fields_can_be_merged.py index 97f0a06b..bc7f51ac 100644 --- a/graphql/validation/rules/overlapping_fields_can_be_merged.py +++ b/graphql/validation/rules/overlapping_fields_can_be_merged.py @@ -1,9 +1,9 @@ import itertools +from collections import OrderedDict from ...error import GraphQLError from ...language import ast from ...language.printer import print_ast -from ...pyutils.default_ordered_dict import DefaultOrderedDict from ...pyutils.pair_set import PairSet from ...type.definition import (GraphQLInterfaceType, GraphQLList, GraphQLNonNull, GraphQLObjectType, @@ -14,262 +14,466 @@ class OverlappingFieldsCanBeMerged(ValidationRule): - __slots__ = 'compared_set', + __slots__ = ('_compared_fragments', '_cached_fields_and_fragment_names', ) def __init__(self, context): super(OverlappingFieldsCanBeMerged, self).__init__(context) - self.compared_set = PairSet() - - def find_conflicts(self, parent_fields_are_mutually_exclusive, field_map): - conflicts = [] - for response_name, fields in field_map.items(): - field_len = len(fields) - if field_len <= 1: - continue - - for field_a in fields: - for field_b in fields: - conflict = self.find_conflict( - parent_fields_are_mutually_exclusive, - response_name, - field_a, - field_b - ) - if conflict: - conflicts.append(conflict) + # A memoization for when two fragments are compared "between" each other for + # conflicts. Two fragments may be compared many times, so memoizing this can + # dramatically improve the performance of this validator. + self._compared_fragments = PairSet() - return conflicts - - def find_conflict(self, parent_fields_are_mutually_exclusive, response_name, field1, field2): - parent_type1, ast1, def1 = field1 - parent_type2, ast2, def2 = field2 - - # Not a pair - if ast1 is ast2: - return - - # Memoize, do not report the same issue twice. - # Note: Two overlapping ASTs could be encountered both when - # `parentFieldsAreMutuallyExclusive` is true and is false, which could - # produce different results (when `true` being a subset of `false`). - # However we do not need to include this piece of information when - # memoizing since this rule visits leaf fields before their parent fields, - # ensuring that `parentFieldsAreMutuallyExclusive` is `false` the first - # time two overlapping fields are encountered, ensuring that the full - # set of validation rules are always checked when necessary. - - # if parent_type1 != parent_type2 and \ - # isinstance(parent_type1, GraphQLObjectType) and \ - # isinstance(parent_type2, GraphQLObjectType): - # return - - if self.compared_set.has(ast1, ast2): - return - - self.compared_set.add(ast1, ast2) - - # The return type for each field. - type1 = def1 and def1.type - type2 = def2 and def2.type - - # If it is known that two fields could not possibly apply at the same - # time, due to the parent types, then it is safe to permit them to diverge - # in aliased field or arguments used as they will not present any ambiguity - # by differing. - # It is known that two parent types could never overlap if they are - # different Object types. Interface or Union types might overlap - if not - # in the current state of the schema, then perhaps in some future version, - # thus may not safely diverge. - - fields_are_mutually_exclusive = ( - parent_fields_are_mutually_exclusive or ( - parent_type1 != parent_type2 and - isinstance(parent_type1, GraphQLObjectType) and - isinstance(parent_type2, GraphQLObjectType) - ) + # A cache for the "field map" and list of fragment names found in any given + # selection set. Selection sets may be asked for this information multiple + # times, so this improves the performance of this validator. + self._cached_fields_and_fragment_names = {} + + def leave_SelectionSet(self, node, key, parent, path, ancestors): + # Note: we validate on the reverse traversal so deeper conflicts will be + # caught first, for correct calculation of mutual exclusivity and for + # clearer error messages. + # field_map = _collect_field_asts_and_defs( + # self.context, + # self.context.get_parent_type(), + # node + # ) + + # conflicts = _find_conflicts(self.context, False, field_map, self.compared_set) + conflicts = _find_conflicts_within_selection_set(self.context, self._cached_fields_and_fragment_names, + self._compared_fragments, self.context.get_parent_type(), + node) + + for (reason_name, reason), fields1, fields2 in conflicts: + self.context.report_error(GraphQLError( + self.fields_conflict_message(reason_name, reason), + list(fields1) + list(fields2) + )) + + @staticmethod + def same_type(type1, type2): + return is_equal_type(type1, type2) + # return type1.is_same_type(type2) + + @classmethod + def fields_conflict_message(cls, reason_name, reason): + return ( + 'Fields "{}" conflict because {}. ' + 'Use different aliases on the fields to fetch both if this was ' + 'intentional.' + ).format(reason_name, cls.reason_message(reason)) + + @classmethod + def reason_message(cls, reason): + if isinstance(reason, list): + return ' and '.join('subfields "{}" conflict because {}'.format(reason_name, cls.reason_message(sub_reason)) + for reason_name, sub_reason in reason) + + return reason + + +# Algorithm: +# +# Conflicts occur when two fields exist in a query which will produce the same +# response name, but represent differing values, thus creating a conflict. +# The algorithm below finds all conflicts via making a series of comparisons +# between fields. In order to compare as few fields as possible, this makes +# a series of comparisons "within" sets of fields and "between" sets of fields. +# +# Given any selection set, a collection produces both a set of fields by +# also including all inline fragments, as well as a list of fragments +# referenced by fragment spreads. +# +# A) Each selection set represented in the document first compares "within" its +# collected set of fields, finding any conflicts between every pair of +# overlapping fields. +# Note: This is the only time that a the fields "within" a set are compared +# to each other. After this only fields "between" sets are compared. +# +# B) Also, if any fragment is referenced in a selection set, then a +# comparison is made "between" the original set of fields and the +# referenced fragment. +# +# C) Also, if multiple fragments are referenced, then comparisons +# are made "between" each referenced fragment. +# +# D) When comparing "between" a set of fields and a referenced fragment, first +# a comparison is made between each field in the original set of fields and +# each field in the the referenced set of fields. +# +# E) Also, if any fragment is referenced in the referenced selection set, +# then a comparison is made "between" the original set of fields and the +# referenced fragment (recursively referring to step D). +# +# F) When comparing "between" two fragments, first a comparison is made between +# each field in the first referenced set of fields and each field in the the +# second referenced set of fields. +# +# G) Also, any fragments referenced by the first must be compared to the +# second, and any fragments referenced by the second must be compared to the +# first (recursively referring to step F). +# +# H) When comparing two fields, if both have selection sets, then a comparison +# is made "between" both selection sets, first comparing the set of fields in +# the first selection set with the set of fields in the second. +# +# I) Also, if any fragment is referenced in either selection set, then a +# comparison is made "between" the other set of fields and the +# referenced fragment. +# +# J) Also, if two fragments are referenced in both selection sets, then a +# comparison is made "between" the two fragments. + +def _find_conflicts_within_selection_set(context, cached_fields_and_fragment_names, compared_fragments, parent_type, + selection_set): + """Find all conflicts found "within" a selection set, including those found via spreading in fragments. + + Called when visiting each SelectionSet in the GraphQL Document. + """ + conflicts = [] + + field_map, fragment_names = _get_fields_and_fragments_names(context, cached_fields_and_fragment_names, parent_type, + selection_set) + + # (A) Find all conflicts "within" the fields of this selection set. + # Note: this is the *only place* `collect_conflicts_within` is called. + _collect_conflicts_within( + context, + conflicts, + cached_fields_and_fragment_names, + compared_fragments, + field_map + ) + + # (B) Then collect conflicts between these fields and those represented by + # each spread fragment name found. + for i, fragment_name in enumerate(fragment_names): + _collect_conflicts_between_fields_and_fragment( + context, + conflicts, + cached_fields_and_fragment_names, + compared_fragments, + False, + field_map, + fragment_name, ) - if not fields_are_mutually_exclusive: - name1 = ast1.name.value - name2 = ast2.name.value - - if name1 != name2: - return ( - (response_name, '{} and {} are different fields'.format(name1, name2)), - [ast1], - [ast2] - ) - - if not self.same_arguments(ast1.arguments, ast2.arguments): - return ( - (response_name, 'they have differing arguments'), - [ast1], - [ast2] - ) - - if type1 and type2 and do_types_conflict(type1, type2): - return ( - (response_name, 'they return conflicting types {} and {}'.format(type1, type2)), - [ast1], - [ast2] + # (C) Then compare this fragment with all other fragments found in this + # selection set to collect conflicts within fragments spread together. + # This compares each item in the list of fragment names to every other item + # in that same list (except for itself). + for other_fragment_name in fragment_names[i+1:]: + _collect_conflicts_between_fragments( + context, + conflicts, + cached_fields_and_fragment_names, + compared_fragments, + False, + fragment_name, + other_fragment_name, ) - subfield_map = self.get_subfield_map(ast1, type1, ast2, type2) - if subfield_map: - conflicts = self.find_conflicts(fields_are_mutually_exclusive, subfield_map) - return self.subfield_conflicts(conflicts, response_name, ast1, ast2) + return conflicts - def get_subfield_map(self, ast1, type1, ast2, type2): - selection_set1 = ast1.selection_set - selection_set2 = ast2.selection_set - if selection_set1 and selection_set2: - visited_fragment_names = set() +def _collect_conflicts_between_fields_and_fragment(context, conflicts, cached_fields_and_fragment_names, + compared_fragments, are_mutually_exclusive, field_map, + fragment_name): - subfield_map = self.collect_field_asts_and_defs( - get_named_type(type1), - selection_set1, - visited_fragment_names - ) + fragment = context.get_fragment(fragment_name) + + if not fragment: + return None + + field_map2, fragment_names2 = _get_referenced_fields_and_fragment_names(context, cached_fields_and_fragment_names, + fragment) + + # (D) First collect any conflicts between the provided collection of fields + # and the collection of fields represented by the given fragment. + _collect_conflicts_between(context, conflicts, cached_fields_and_fragment_names, compared_fragments, + are_mutually_exclusive, field_map, field_map2) + + # (E) Then collect any conflicts between the provided collection of fields + # and any fragment names found in the given fragment. + for fragment_name2 in fragment_names2: + _collect_conflicts_between_fields_and_fragment(context, conflicts, cached_fields_and_fragment_names, + compared_fragments, are_mutually_exclusive, field_map, + fragment_name2) + + +# Collect all conflicts found between two fragments, including via spreading in +# any nested fragments +def _collect_conflicts_between_fragments(context, conflicts, cached_fields_and_fragment_names, compared_fragments, + are_mutually_exclusive, fragment_name1, fragment_name2): + + fragment1 = context.get_fragment(fragment_name1) + fragment2 = context.get_fragment(fragment_name2) + + if not fragment1 or not fragment2: + return None + + # No need to compare a fragment to itself. + if fragment1 == fragment2: + return None + + # Memoize so two fragments are not compared for conflicts more than once. + if compared_fragments.has(fragment_name1, fragment_name2, are_mutually_exclusive): + return None + + compared_fragments.add(fragment_name1, fragment_name2, are_mutually_exclusive) + + field_map1, fragment_names1 = _get_referenced_fields_and_fragment_names(context, cached_fields_and_fragment_names, + fragment1) + + field_map2, fragment_names2 = _get_referenced_fields_and_fragment_names(context, cached_fields_and_fragment_names, + fragment2) + + # (F) First, collect all conflicts between these two collections of fields + # (not including any nested fragments) + _collect_conflicts_between(context, conflicts, cached_fields_and_fragment_names, compared_fragments, + are_mutually_exclusive, field_map1, field_map2) + + # (G) Then collect conflicts between the first fragment and any nested + # fragments spread in the second fragment. + for _fragment_name2 in fragment_names2: + _collect_conflicts_between_fragments(context, conflicts, cached_fields_and_fragment_names, compared_fragments, + are_mutually_exclusive, fragment_name1, _fragment_name2) + + # (G) Then collect conflicts between the second fragment and any nested + # fragments spread in the first fragment. + for _fragment_name1 in fragment_names1: + _collect_conflicts_between_fragments(context, conflicts, cached_fields_and_fragment_names, compared_fragments, + are_mutually_exclusive, _fragment_name1, fragment_name2) + + +def _find_conflicts_between_sub_selection_sets(context, cached_fields_and_fragment_names, compared_fragments, + are_mutually_exclusive, parent_type1, selection_set1, + parent_type2, selection_set2): + """Find all conflicts found between two selection sets. + + Includes those found via spreading in fragments. Called when determining if conflicts exist + between the sub-fields of two overlapping fields. + """ + + conflicts = [] + + field_map1, fragment_names1 = _get_fields_and_fragments_names(context, cached_fields_and_fragment_names, + parent_type1, selection_set1) + + field_map2, fragment_names2 = _get_fields_and_fragments_names(context, cached_fields_and_fragment_names, + parent_type2, selection_set2) + + # (H) First, collect all conflicts between these two collections of field. + _collect_conflicts_between(context, conflicts, cached_fields_and_fragment_names, compared_fragments, + are_mutually_exclusive, field_map1, field_map2) + + # (I) Then collect conflicts between the first collection of fields and + # those referenced by each fragment name associated with the second. + for fragment_name2 in fragment_names2: + _collect_conflicts_between_fields_and_fragment(context, conflicts, cached_fields_and_fragment_names, + compared_fragments, are_mutually_exclusive, field_map1, + fragment_name2) + + # (I) Then collect conflicts between the second collection of fields and + # those referenced by each fragment name associated with the first. + for fragment_name1 in fragment_names1: + _collect_conflicts_between_fields_and_fragment(context, conflicts, cached_fields_and_fragment_names, + compared_fragments, are_mutually_exclusive, field_map2, + fragment_name1) + + # (J) Also collect conflicts between any fragment names by the first and + # fragment names by the second. This compares each item in the first set of + # names to each item in the second set of names. + for fragment_name1 in fragment_names1: + for fragment_name2 in fragment_names2: + _collect_conflicts_between_fragments(context, conflicts, cached_fields_and_fragment_names, + compared_fragments, are_mutually_exclusive, + fragment_name1, fragment_name2) + + return conflicts + + +def _collect_conflicts_within(context, conflicts, cached_fields_and_fragment_names, compared_fragments, field_map): + """Collect all Conflicts "within" one collection of fields.""" + + # field map is a keyed collection, where each key represents a response + # name and the value at that key is a list of all fields which provide that + # response name. For every response name, if there are multiple fields, they + # must be compared to find a potential conflict. + for response_name, fields in list(field_map.items()): + # This compares every field in the list to every other field in this list + # (except to itself). If the list only has one item, nothing needs to + # be compared. + for i, field in enumerate(fields): + for other_field in fields[i+1:]: + # within one collection is never mutually exclusive + conflict = _find_conflict(context, cached_fields_and_fragment_names, compared_fragments, False, + response_name, field, other_field) + if conflict: + conflicts.append(conflict) + + +def _collect_conflicts_between(context, conflicts, cached_fields_and_fragment_names, compared_fragments, + parent_fields_are_mutually_exclusive, field_map1, field_map2): + """Collect all Conflicts between two collections of fields. + + This is similar to, but different from the `collect_conflicts_within` function above. This check assumes that + `collect_conflicts_within` has already been called on each provided collection of fields. + This is true because this validator traverses each individual selection set. + """ + # A field map is a keyed collection, where each key represents a response + # name and the value at that key is a list of all fields which provide that + # response name. For any response name which appears in both provided field + # maps, each field from the first field map must be compared to every field + # in the second field map to find potential conflicts. + for response_name, fields1 in list(field_map1.items()): + fields2 = field_map2.get(response_name) + + if fields2: + for field1 in fields1: + for field2 in fields2: + conflict = _find_conflict(context, cached_fields_and_fragment_names, compared_fragments, + parent_fields_are_mutually_exclusive, response_name, field1, field2) + + if conflict: + conflicts.append(conflict) + + +def _find_conflict(context, cached_fields_and_fragment_names, compared_fragments, parent_fields_are_mutually_exclusive, + response_name, field1, field2): + """Determines if there is a conflict between two particular fields.""" + parent_type1, ast1, def1 = field1 + parent_type2, ast2, def2 = field2 + + # If it is known that two fields could not possibly apply at the same + # time, due to the parent types, then it is safe to permit them to diverge + # in aliased field or arguments used as they will not present any ambiguity + # by differing. + # It is known that two parent types could never overlap if they are + # different Object types. Interface or Union types might overlap - if not + # in the current state of the schema, then perhaps in some future version, + # thus may not safely diverge. + + are_mutually_exclusive = ( + parent_fields_are_mutually_exclusive or ( + parent_type1 != parent_type2 and + isinstance(parent_type1, GraphQLObjectType) and + isinstance(parent_type2, GraphQLObjectType) + ) + ) + + # The return type for each field. + type1 = def1 and def1.type + type2 = def2 and def2.type - subfield_map = self.collect_field_asts_and_defs( - get_named_type(type2), - selection_set2, - visited_fragment_names, - subfield_map + if not are_mutually_exclusive: + # Two aliases must refer to the same field. + name1 = ast1.name.value + name2 = ast2.name.value + + if name1 != name2: + return ( + (response_name, '{} and {} are different fields'.format(name1, name2)), + [ast1], + [ast2] ) - return subfield_map - def subfield_conflicts(self, conflicts, response_name, ast1, ast2): - if conflicts: + # Two field calls must have the same arguments. + if not _same_arguments(ast1.arguments, ast2.arguments): return ( - (response_name, [conflict[0] for conflict in conflicts]), - tuple(itertools.chain([ast1], *[conflict[1] for conflict in conflicts])), - tuple(itertools.chain([ast2], *[conflict[2] for conflict in conflicts])) + (response_name, 'they have differing arguments'), + [ast1], + [ast2] ) - def leave_SelectionSet(self, node, key, parent, path, ancestors): - # Note: we validate on the reverse traversal so deeper conflicts will be - # caught first, for correct calculation of mutual exclusivity and for - # clearer error messages. - field_map = self.collect_field_asts_and_defs( - self.context.get_parent_type(), - node + if type1 and type2 and do_types_conflict(type1, type2): + return ( + (response_name, 'they return conflicting types {} and {}'.format(type1, type2)), + [ast1], + [ast2] ) - conflicts = self.find_conflicts(False, field_map) - if conflicts: - for (reason_name, reason), fields1, fields2 in conflicts: - self.context.report_error( - GraphQLError( - self.fields_conflict_message( - reason_name, - reason), - list(fields1) + - list(fields2))) + # Collect and compare sub-fields. Use the same "visited fragment names" list + # for both collections so fields in a fragment reference are never + # compared to themselves. + selection_set1 = ast1.selection_set + selection_set2 = ast2.selection_set - @staticmethod - def same_type(type1, type2): - return is_equal_type(type1, type2) - # return type1.is_same_type(type2) + if selection_set1 and selection_set2: + conflicts = _find_conflicts_between_sub_selection_sets(context, cached_fields_and_fragment_names, + compared_fragments, are_mutually_exclusive, + get_named_type(type1), selection_set1, + get_named_type(type2), selection_set2) - @staticmethod - def same_value(value1, value2): - return (not value1 and not value2) or print_ast(value1) == print_ast(value2) + return _subfield_conflicts(conflicts, response_name, ast1, ast2) - @classmethod - def same_arguments(cls, arguments1, arguments2): - # Check to see if they are empty arguments or nones. If they are, we can - # bail out early. - if not (arguments1 or arguments2): - return True - if len(arguments1) != len(arguments2): - return False +def _get_fields_and_fragments_names(context, cached_fields_and_fragment_names, parent_type, selection_set): + cached = cached_fields_and_fragment_names.get(selection_set) - arguments2_values_to_arg = {a.name.value: a for a in arguments2} + if not cached: + ast_and_defs = OrderedDict() + fragment_names = OrderedDict() + _collect_fields_and_fragment_names(context, parent_type, selection_set, ast_and_defs, fragment_names) + cached = [ast_and_defs, list(fragment_names.keys())] + cached_fields_and_fragment_names[selection_set] = cached - for argument1 in arguments1: - argument2 = arguments2_values_to_arg.get(argument1.name.value) - if not argument2: - return False + return cached - if not cls.same_value(argument1.value, argument2.value): - return False - return True +def _get_referenced_fields_and_fragment_names(context, cached_fields_and_fragment_names, fragment): + """Given a reference to a fragment, return the represented collection of fields as well as a list of + nested fragment names referenced via fragment spreads.""" - def collect_field_asts_and_defs(self, parent_type, selection_set, visited_fragment_names=None, ast_and_defs=None): - if visited_fragment_names is None: - visited_fragment_names = set() - - if ast_and_defs is None: - # An ordered dictionary is required, otherwise the error message will be out of order. - # We need to preserve the order that the item was inserted into the dict, as that will dictate - # in which order the reasons in the error message should show. - # Otherwise, the error messages will be inconsistently ordered for the same AST. - # And this can make it so that tests fail half the time, and fool a user into thinking that - # the errors are different, when in-fact they are the same, just that the ordering of the reasons differ. - ast_and_defs = DefaultOrderedDict(list) - - for selection in selection_set.selections: - if isinstance(selection, ast.Field): - field_name = selection.name.value - field_def = None - if isinstance(parent_type, (GraphQLObjectType, GraphQLInterfaceType)): - field_def = parent_type.fields.get(field_name) + # Short-circuit building a type from the AST if possible. + cached = cached_fields_and_fragment_names.get(fragment.selection_set) - response_name = selection.alias.value if selection.alias else field_name - ast_and_defs[response_name].append((parent_type, selection, field_def)) + if cached: + return cached - elif isinstance(selection, ast.InlineFragment): - type_condition = selection.type_condition - inline_fragment_type = \ - type_from_ast(self.context.get_schema(), type_condition) \ - if type_condition else parent_type + fragment_type = type_from_ast(context.get_schema(), fragment.type_condition) - self.collect_field_asts_and_defs( - inline_fragment_type, - selection.selection_set, - visited_fragment_names, - ast_and_defs - ) + return _get_fields_and_fragments_names(context, cached_fields_and_fragment_names, + fragment_type, fragment.selection_set) - elif isinstance(selection, ast.FragmentSpread): - fragment_name = selection.name.value - if fragment_name in visited_fragment_names: - continue - visited_fragment_names.add(fragment_name) - fragment = self.context.get_fragment(fragment_name) +def _collect_fields_and_fragment_names(context, parent_type, selection_set, ast_and_defs, fragment_names): - if not fragment: - continue + for selection in selection_set.selections: + if isinstance(selection, ast.Field): + field_name = selection.name.value + if isinstance(parent_type, (GraphQLObjectType, GraphQLInterfaceType)): + field_def = parent_type.fields.get(field_name) + else: + field_def = None - self.collect_field_asts_and_defs( - type_from_ast(self.context.get_schema(), fragment.type_condition), - fragment.selection_set, - visited_fragment_names, - ast_and_defs - ) + response_name = selection.alias.value if selection.alias else field_name - return ast_and_defs + if not ast_and_defs.get(response_name): + ast_and_defs[response_name] = [] - @classmethod - def fields_conflict_message(cls, reason_name, reason): - return ( - 'Fields "{}" conflict because {}. ' - 'Use different aliases on the fields to fetch both if this was ' - 'intentional.' - ).format(reason_name, cls.reason_message(reason)) + ast_and_defs[response_name].append([parent_type, selection, field_def]) - @classmethod - def reason_message(cls, reason): - if isinstance(reason, list): - return ' and '.join('subfields "{}" conflict because {}'.format(reason_name, cls.reason_message(sub_reason)) - for reason_name, sub_reason in reason) + elif isinstance(selection, ast.FragmentSpread): + fragment_names[selection.name.value] = True + elif isinstance(selection, ast.InlineFragment): + type_condition = selection.type_condition + if type_condition: + inline_fragment_type = type_from_ast(context.get_schema(), selection.type_condition) + else: + inline_fragment_type = parent_type - return reason + _collect_fields_and_fragment_names(context, inline_fragment_type, selection.selection_set, ast_and_defs, + fragment_names) + + +def _subfield_conflicts(conflicts, response_name, ast1, ast2): + """Given a series of Conflicts which occurred between two sub-fields, generate a single Conflict.""" + if conflicts: + return ( + (response_name, [conflict[0] for conflict in conflicts]), + tuple(itertools.chain([ast1], *[conflict[1] for conflict in conflicts])), + tuple(itertools.chain([ast2], *[conflict[2] for conflict in conflicts])) + ) def do_types_conflict(type1, type2): @@ -297,3 +501,29 @@ def do_types_conflict(type1, type2): return type1 != type2 return False + + +def _same_value(value1, value2): + return (not value1 and not value2) or print_ast(value1) == print_ast(value2) + + +def _same_arguments(arguments1, arguments2): + # Check to see if they are empty arguments or nones. If they are, we can + # bail out early. + if not (arguments1 or arguments2): + return True + + if len(arguments1) != len(arguments2): + return False + + arguments2_values_to_arg = {a.name.value: a for a in arguments2} + + for argument1 in arguments1: + argument2 = arguments2_values_to_arg.get(argument1.name.value) + if not argument2: + return False + + if not _same_value(argument1.value, argument2.value): + return False + + return True diff --git a/graphql/validation/tests/test_fields_on_correct_type.py b/graphql/validation/tests/test_fields_on_correct_type.py index 7c9376df..da552ad3 100644 --- a/graphql/validation/tests/test_fields_on_correct_type.py +++ b/graphql/validation/tests/test_fields_on_correct_type.py @@ -1,12 +1,13 @@ from graphql.language.location import SourceLocation -from graphql.validation.rules import FieldsOnCorrectType +from graphql.validation.rules.fields_on_correct_type import (FieldsOnCorrectType, + _undefined_field_message) from .utils import expect_fails_rule, expect_passes_rule -def undefined_field(field, type, suggestions, line, column): +def undefined_field(field, gql_type, suggested_types, suggested_fields, line, column): return { - 'message': FieldsOnCorrectType.undefined_field_message(field, type, suggestions), + 'message': _undefined_field_message(field, gql_type, suggested_types, suggested_fields), 'locations': [SourceLocation(line, column)] } @@ -72,8 +73,8 @@ def test_reports_errors_when_type_is_known_again(): } }, ''', [ - undefined_field('unknown_pet_field', 'Pet', [], 3, 9), - undefined_field('unknown_cat_field', 'Cat', [], 5, 13) + undefined_field('unknown_pet_field', 'Pet', [], [], 3, 9), + undefined_field('unknown_cat_field', 'Cat', [], [], 5, 13) ]) @@ -83,7 +84,7 @@ def test_field_not_defined_on_fragment(): meowVolume } ''', [ - undefined_field('meowVolume', 'Dog', [], 3, 9) + undefined_field('meowVolume', 'Dog', [], ['barkVolume'], 3, 9) ]) @@ -95,7 +96,7 @@ def test_ignores_deeply_unknown_field(): } } ''', [ - undefined_field('unknown_field', 'Dog', [], 3, 9) + undefined_field('unknown_field', 'Dog', [], [], 3, 9) ]) @@ -107,7 +108,7 @@ def test_sub_field_not_defined(): } } ''', [ - undefined_field('unknown_field', 'Pet', [], 4, 11) + undefined_field('unknown_field', 'Pet', [], [], 4, 11) ]) @@ -119,7 +120,7 @@ def test_field_not_defined_on_inline_fragment(): } } ''', [ - undefined_field('meowVolume', 'Dog', [], 4, 11) + undefined_field('meowVolume', 'Dog', [], ['barkVolume'], 4, 11) ]) @@ -129,7 +130,7 @@ def test_aliased_field_target_not_defined(): volume : mooVolume } ''', [ - undefined_field('mooVolume', 'Dog', [], 3, 9) + undefined_field('mooVolume', 'Dog', [], ['barkVolume'], 3, 9) ]) @@ -139,7 +140,7 @@ def test_aliased_lying_field_target_not_defined(): barkVolume : kawVolume } ''', [ - undefined_field('kawVolume', 'Dog', [], 3, 9) + undefined_field('kawVolume', 'Dog', [], ['barkVolume'], 3, 9) ]) @@ -149,7 +150,7 @@ def test_not_defined_on_interface(): tailLength } ''', [ - undefined_field('tailLength', 'Pet', [], 3, 9) + undefined_field('tailLength', 'Pet', [], [], 3, 9) ]) @@ -159,7 +160,7 @@ def test_defined_on_implementors_but_not_on_interface(): nickname } ''', [ - undefined_field('nickname', 'Pet', ['Cat', 'Dog'], 3, 9) + undefined_field('nickname', 'Pet', ['Dog', 'Cat'], ['name'], 3, 9) ]) @@ -177,7 +178,7 @@ def test_direct_field_selection_on_union(): directField } ''', [ - undefined_field('directField', 'CatOrDog', [], 3, 9) + undefined_field('directField', 'CatOrDog', [], [], 3, 9) ]) @@ -190,7 +191,8 @@ def test_defined_on_implementors_queried_on_union(): undefined_field( 'name', 'CatOrDog', - ['Being', 'Pet', 'Canine', 'Cat', 'Dog'], + ['Being', 'Pet', 'Canine', 'Dog', 'Cat'], + [], 3, 9 ) @@ -211,24 +213,45 @@ def test_valid_field_in_inline_fragment(): def test_fields_correct_type_no_suggestion(): - message = FieldsOnCorrectType.undefined_field_message('T', 'f', []) - assert message == 'Cannot query field "T" on type "f".' + message = _undefined_field_message('f', 'T', [], []) + assert message == 'Cannot query field "f" on type "T".' -def test_fields_correct_type_no_small_number_suggestions(): - message = FieldsOnCorrectType.undefined_field_message('T', 'f', ['A', 'B']) +def test_works_with_no_small_numbers_of_type_suggestion(): + message = _undefined_field_message('f', 'T', ['A', 'B'], []) assert message == ( - 'Cannot query field "T" on type "f". ' + - 'However, this field exists on "A", "B". ' + - 'Perhaps you meant to use an inline fragment?' + 'Cannot query field "f" on type "T". ' + + 'Did you mean to use an inline fragment on "A" or "B"?' ) -def test_fields_correct_type_lot_suggestions(): - message = FieldsOnCorrectType.undefined_field_message('T', 'f', ['A', 'B', 'C', 'D', 'E', 'F']) +def test_works_with_no_small_numbers_of_field_suggestion(): + message = _undefined_field_message('f', 'T', [], ['z', 'y']) assert message == ( - 'Cannot query field "T" on type "f". ' + - 'However, this field exists on "A", "B", "C", "D", "E", ' + - 'and 1 other types. ' + - 'Perhaps you meant to use an inline fragment?' + 'Cannot query field "f" on type "T". ' + + 'Did you mean "z" or "y"?' + ) + + +def test_only_shows_one_set_of_suggestions_at_a_time_preferring_types(): + message = _undefined_field_message('f', 'T', ['A', 'B'], ['z', 'y']) + assert message == ( + 'Cannot query field "f" on type "T". ' + + 'Did you mean to use an inline fragment on "A" or "B"?' + ) + + +def test_limits_lots_of_type_suggestions(): + message = _undefined_field_message('f', 'T', ['A', 'B', 'C', 'D', 'E', 'F'], []) + assert message == ( + 'Cannot query field "f" on type "T". ' + + 'Did you mean to use an inline fragment on "A", "B", "C", "D" or "E"?' + ) + + +def test_limits_lots_of_field_suggestions(): + message = _undefined_field_message('f', 'T', [], ['z', 'y', 'x', 'w', 'v', 'u']) + assert message == ( + 'Cannot query field "f" on type "T". ' + + 'Did you mean "z", "y", "x", "w" or "v"?' ) diff --git a/graphql/validation/tests/test_known_argument_names.py b/graphql/validation/tests/test_known_argument_names.py index ed8e2569..65630046 100644 --- a/graphql/validation/tests/test_known_argument_names.py +++ b/graphql/validation/tests/test_known_argument_names.py @@ -1,19 +1,21 @@ from graphql.language.location import SourceLocation -from graphql.validation.rules import KnownArgumentNames +from graphql.validation.rules.known_argument_names import (KnownArgumentNames, + _unknown_arg_message, + _unknown_directive_arg_message) from .utils import expect_fails_rule, expect_passes_rule -def unknown_arg(arg_name, field_name, type_name, line, column): +def unknown_arg(arg_name, field_name, type_name, suggested_args, line, column): return { - 'message': KnownArgumentNames.unknown_arg_message(arg_name, field_name, type_name), + 'message': _unknown_arg_message(arg_name, field_name, type_name, suggested_args), 'locations': [SourceLocation(line, column)] } -def unknown_directive_arg(arg_name, directive_name, line, column): +def unknown_directive_arg(arg_name, directive_name, suggested_args, line, column): return { - 'message': KnownArgumentNames.unknown_directive_arg_message(arg_name, directive_name), + 'message': _unknown_directive_arg_message(arg_name, directive_name, suggested_args), 'locations': [SourceLocation(line, column)] } @@ -89,7 +91,7 @@ def test_undirective_args_are_invalid(): dog @skip(unless: true) } ''', [ - unknown_directive_arg('unless', 'skip', 3, 19) + unknown_directive_arg('unless', 'skip', [], 3, 19) ]) @@ -99,7 +101,7 @@ def test_invalid_arg_name(): doesKnowCommand(unknown: true) } ''', [ - unknown_arg('unknown', 'doesKnowCommand', 'Dog', 3, 25) + unknown_arg('unknown', 'doesKnowCommand', 'Dog', [], 3, 25) ]) @@ -109,8 +111,8 @@ def test_unknown_args_amongst_known_args(): doesKnowCommand(whoknows: 1, dogCommand: SIT, unknown: true) } ''', [ - unknown_arg('whoknows', 'doesKnowCommand', 'Dog', 3, 25), - unknown_arg('unknown', 'doesKnowCommand', 'Dog', 3, 55) + unknown_arg('whoknows', 'doesKnowCommand', 'Dog', [], 3, 25), + unknown_arg('unknown', 'doesKnowCommand', 'Dog', [], 3, 55) ]) @@ -129,6 +131,6 @@ def test_unknown_args_deeply(): } } ''', [ - unknown_arg('unknown', 'doesKnowCommand', 'Dog', 4, 27), - unknown_arg('unknown', 'doesKnowCommand', 'Dog', 9, 31) + unknown_arg('unknown', 'doesKnowCommand', 'Dog', [], 4, 27), + unknown_arg('unknown', 'doesKnowCommand', 'Dog', [], 9, 31) ]) diff --git a/graphql/validation/tests/test_known_directives.py b/graphql/validation/tests/test_known_directives.py index 8309f480..bf27b773 100644 --- a/graphql/validation/tests/test_known_directives.py +++ b/graphql/validation/tests/test_known_directives.py @@ -78,24 +78,105 @@ def test_with_many_unknown_directives(): def test_with_well_placed_directives(): expect_passes_rule(KnownDirectives, ''' - query Foo { + query Foo @onQuery{ name @include(if: true) ...Frag @include(if: true) skippedField @skip(if: true) ...SkippedFrag @skip(if: true) } + + mutation Bar @onMutation { + someField + } ''') def test_with_misplaced_directives(): expect_fails_rule(KnownDirectives, ''' query Foo @include(if: true) { - name @operationOnly - ...Frag @operationOnly + name @onQuery + ...Frag @onQuery + } + + mutation Bar @onQuery { + someField } ''', [ misplaced_directive('include', 'QUERY', 2, 17), - misplaced_directive('operationOnly', 'FIELD', 3, 14), - misplaced_directive('operationOnly', 'FRAGMENT_SPREAD', 4, 17), + misplaced_directive('onQuery', 'FIELD', 3, 14), + misplaced_directive('onQuery', 'FRAGMENT_SPREAD', 4, 17), + misplaced_directive('onQuery', 'MUTATION', 7, 20), ]) + + +# within schema language + +def test_within_schema_language_with_well_placed_directives(): + expect_passes_rule(KnownDirectives, ''' + type MyObj implements MyInterface @onObject { + myField(myArg: Int @onArgumentDefinition): String @onFieldDefinition + } + + scalar MyScalar @onScalar + + interface MyInterface @onInterface { + myField(myArg: Int @onArgumentDefinition): String @onFieldDefinition + } + + union MyUnion @onUnion = MyObj | Other + + enum MyEnum @onEnum { + MY_VALUE @onEnumValue + } + + input MyInput @onInputObject { + myField: Int @onInputFieldDefinition + } + + schema @OnSchema { + query: MyQuery + } + ''') + + +def test_within_schema_language_with_misplaced_directives(): + expect_fails_rule(KnownDirectives, ''' + type MyObj implements MyInterface @onInterface { + myField(myArg: Int @onInputFieldDefinition): String @onInputFieldDefinition + } + + scalar MyScalar @onEnum + + interface MyInterface @onObject { + myField(myArg: Int @onInputFieldDefinition): String @onInputFieldDefinition + } + + union MyUnion @onEnumValue = MyObj | Other + + enum MyEnum @onScalar { + MY_VALUE @onUnion + } + + input MyInput @onEnum { + myField: Int @onArgumentDefinition + } + + schema @onObject { + query: MyQuery + } + ''', [ + misplaced_directive('onInterface', 'OBJECT', 2, 43), + misplaced_directive('onInputFieldDefinition', 'ARGUMENT_DEFINITION', 3, 30), + misplaced_directive('onInputFieldDefinition', 'FIELD_DEFINITION', 3, 63), + misplaced_directive('onEnum', 'SCALAR', 6, 25), + misplaced_directive('onObject', 'INTERFACE', 8, 31), + misplaced_directive('onInputFieldDefinition', 'ARGUMENT_DEFINITION', 9, 30), + misplaced_directive('onInputFieldDefinition', 'FIELD_DEFINITION', 9, 63), + misplaced_directive('onEnumValue', 'UNION', 12, 23), + misplaced_directive('onScalar', 'ENUM', 14, 21), + misplaced_directive('onUnion', 'ENUM_VALUE', 15, 20), + misplaced_directive('onEnum', 'INPUT_OBJECT', 18, 23), + misplaced_directive('onArgumentDefinition', 'INPUT_FIELD_DEFINITION', 19, 24), + misplaced_directive('onObject', 'SCHEMA', 22, 16), + ]) diff --git a/graphql/validation/tests/test_known_type_names.py b/graphql/validation/tests/test_known_type_names.py index 17736532..a9f40c19 100644 --- a/graphql/validation/tests/test_known_type_names.py +++ b/graphql/validation/tests/test_known_type_names.py @@ -1,12 +1,13 @@ from graphql.language.location import SourceLocation -from graphql.validation.rules import KnownTypeNames +from graphql.validation.rules.known_type_names import (KnownTypeNames, + _unknown_type_message) from .utils import expect_fails_rule, expect_passes_rule -def unknown_type(type_name, line, column): +def unknown_type(type_name, suggested_types, line, column): return { - 'message': KnownTypeNames.unknown_type_message(type_name), + 'message': _unknown_type_message(type_name, suggested_types), 'locations': [SourceLocation(line, column)] } @@ -36,9 +37,9 @@ def test_unknown_type_names_are_invalid(): name } ''', [ - unknown_type('JumbledUpLetters', 2, 23), - unknown_type('Badger', 5, 25), - unknown_type('Peettt', 8, 29), + unknown_type('JumbledUpLetters', [], 2, 23), + unknown_type('Badger', [], 5, 25), + unknown_type('Peettt', ['Pet'], 8, 29), ]) @@ -60,5 +61,5 @@ def test_ignores_type_definitions(): } } ''', [ - unknown_type('NotInTheSchema', 12, 23), + unknown_type('NotInTheSchema', [], 12, 23), ]) diff --git a/graphql/validation/tests/test_overlapping_fields_can_be_merged.py b/graphql/validation/tests/test_overlapping_fields_can_be_merged.py index 3882ea2d..4e1cfb46 100644 --- a/graphql/validation/tests/test_overlapping_fields_can_be_merged.py +++ b/graphql/validation/tests/test_overlapping_fields_can_be_merged.py @@ -180,31 +180,31 @@ def test_encounters_conflict_in_fragments(): def test_reports_each_conflict_once(): expect_fails_rule(OverlappingFieldsCanBeMerged, ''' - { + { f1 { - ...A - ...B + ...A + ...B } f2 { - ...B - ...A + ...B + ...A } f3 { - ...A - ...B - x: c + ...A + ...B + x: c } - } - fragment A on Type { + } + fragment A on Type { x: a - } - fragment B on Type { + } + fragment B on Type { x: b - } + } ''', [ fields_conflict('x', 'a and b are different fields', L(18, 9), L(21, 9)), - fields_conflict('x', 'a and c are different fields', L(18, 9), L(14, 13)), - fields_conflict('x', 'b and c are different fields', L(21, 9), L(14, 13)) + fields_conflict('x', 'c and a are different fields', L(14, 11), L(18, 9)), + fields_conflict('x', 'c and b are different fields', L(14, 11), L(21, 9)) ], sort_list=False) @@ -227,20 +227,20 @@ def test_deep_conflict(): def test_deep_conflict_with_multiple_issues(): expect_fails_rule(OverlappingFieldsCanBeMerged, ''' - { + { field { - x: a - y: c + x: a + y: c } field { - x: b - y: d + x: b + y: d } - } + } ''', [ fields_conflict( 'field', [('x', 'a and b are different fields'), ('y', 'c and d are different fields')], - L(3, 9), L(4, 13), L(5, 13), L(7, 9), L(8, 13), L(9, 13) + L(3, 9), L(4, 11), L(5, 11), L(7, 9), L(8, 11), L(9, 11) ) ], sort_list=False) @@ -292,6 +292,87 @@ def test_reports_deep_conflict_to_nearest_common_ancestor(): ], sort_list=False) +def test_reports_deep_conflict_to_nearest_common_ancestor_in_fragments(): + expect_fails_rule(OverlappingFieldsCanBeMerged, ''' + { + field { + ...F + }, + field { + ...F + } + } + fragment F on T { + deepField { + deeperField { + x: a + } + deeperField { + x: b + } + } + deepField { + deeperField { + y + } + } + } + ''', [ + fields_conflict( + 'deeperField', [('x', 'a and b are different fields')], + L(12, 11), L(13, 13), L(15, 11), L(16, 13) + ) + ], sort_list=False) + + +def test_reports_deep_conflict_in_nested_fragments(): + expect_fails_rule(OverlappingFieldsCanBeMerged, ''' + { + field { + ...F + }, + field { + ...I + } + } + fragment F on T { + x: a + ...G + } + fragment G on T { + y: c + } + fragment I on T { + y: d + ...J + } + fragment J on T { + x: b + } + ''', [ + fields_conflict( + 'field', [('x', 'a and b are different fields'), + ('y', 'c and d are different fields')], + L(3, 9), L(11, 9), L(15, 9), L(6, 9), L(22, 9), L(18, 9) + ) + ], sort_list=False) + + +def test_ignores_unknown_fragments(): + expect_passes_rule(OverlappingFieldsCanBeMerged, ''' + { + field + ...Unknown + ...Known + } + + fragment Known on T { + field + ...OtherUnknown + } + ''') + + SomeBox = GraphQLInterfaceType( 'SomeBox', fields=lambda: { @@ -425,6 +506,59 @@ def test_disallows_differing_return_types_despite_no_overlap(): ], sort_list=False) +def test_reports_correctly_when_a_non_exclusive_follows_an_exclusive(): + expect_fails_rule_with_schema(schema, OverlappingFieldsCanBeMerged, ''' + { + someBox { + ... on IntBox { + deepBox { + ...X + } + } + } + someBox { + ... on StringBox { + deepBox { + ...Y + } + } + } + memoed: someBox { + ... on IntBox { + deepBox { + ...X + } + } + } + memoed: someBox { + ... on StringBox { + deepBox { + ...Y + } + } + } + other: someBox { + ...X + } + other: someBox { + ...Y + } + } + fragment X on SomeBox { + scalar + } + fragment Y on SomeBox { + scalar: unrelatedField + } + ''', [ + fields_conflict( + 'other', + [('scalar', 'scalar and unrelatedField are different fields')], + L(31, 11), L(39, 11), L(34, 11), L(42, 11), + ) + ], sort_list=False) + + def test_disallows_differing_return_type_nullability_despite_no_overlap(): expect_fails_rule_with_schema(schema, OverlappingFieldsCanBeMerged, ''' { @@ -614,10 +748,10 @@ def test_compares_deep_types_including_list(): } ''', [ fields_conflict( - 'edges', [['node', [['id', 'id and name are different fields']]]], - L(14, 9), L(15, 13), - L(16, 17), L(5, 13), - L(6, 17), L(7, 21), + 'edges', [['node', [['id', 'name and id are different fields']]]], + L(5, 13), L(6, 17), + L(7, 21), L(14, 9), + L(15, 13), L(16, 17), ) ], sort_list=False) diff --git a/graphql/validation/tests/test_validation.py b/graphql/validation/tests/test_validation.py index ead25d2d..a637d69a 100644 --- a/graphql/validation/tests/test_validation.py +++ b/graphql/validation/tests/test_validation.py @@ -50,6 +50,6 @@ def test_validates_using_a_custom_type_info(): ) assert len(errors) == 3 - assert errors[0].message == 'Cannot query field "catOrDog" on type "QueryRoot".' - assert errors[1].message == 'Cannot query field "furColor" on type "Cat".' - assert errors[2].message == 'Cannot query field "isHousetrained" on type "Dog".' + assert errors[0].message == 'Cannot query field "catOrDog" on type "QueryRoot". Did you mean "catOrDog"?' + assert errors[1].message == 'Cannot query field "furColor" on type "Cat". Did you mean "furColor"?' + assert errors[2].message == 'Cannot query field "isHousetrained" on type "Dog". Did you mean "isHousetrained"?' diff --git a/graphql/validation/tests/utils.py b/graphql/validation/tests/utils.py index 42967929..0044e012 100644 --- a/graphql/validation/tests/utils.py +++ b/graphql/validation/tests/utils.py @@ -7,7 +7,8 @@ GraphQLInterfaceType, GraphQLList, GraphQLNonNull, GraphQLObjectType, GraphQLSchema, GraphQLString, GraphQLUnionType) -from graphql.type.directives import (GraphQLDirective, GraphQLIncludeDirective, +from graphql.type.directives import (DirectiveLocation, GraphQLDirective, + GraphQLIncludeDirective, GraphQLSkipDirective) from graphql.validation import validate @@ -178,9 +179,26 @@ test_schema = GraphQLSchema( query=QueryRoot, directives=[ - GraphQLDirective(name='operationOnly', locations=['QUERY']), GraphQLIncludeDirective, - GraphQLSkipDirective + GraphQLSkipDirective, + GraphQLDirective(name='onQuery', locations=[DirectiveLocation.QUERY]), + GraphQLDirective(name='onMutation', locations=[DirectiveLocation.MUTATION]), + GraphQLDirective(name='onSubscription', locations=[DirectiveLocation.SUBSCRIPTION]), + GraphQLDirective(name='onField', locations=[DirectiveLocation.FIELD]), + GraphQLDirective(name='onFragmentDefinition', locations=[DirectiveLocation.FRAGMENT_DEFINITION]), + GraphQLDirective(name='onFragmentSpread', locations=[DirectiveLocation.FRAGMENT_SPREAD]), + GraphQLDirective(name='onInlineFragment', locations=[DirectiveLocation.INLINE_FRAGMENT]), + GraphQLDirective(name='OnSchema', locations=[DirectiveLocation.SCHEMA]), + GraphQLDirective(name='onScalar', locations=[DirectiveLocation.SCALAR]), + GraphQLDirective(name='onObject', locations=[DirectiveLocation.OBJECT]), + GraphQLDirective(name='onFieldDefinition', locations=[DirectiveLocation.FIELD_DEFINITION]), + GraphQLDirective(name='onArgumentDefinition', locations=[DirectiveLocation.ARGUMENT_DEFINITION]), + GraphQLDirective(name='onInterface', locations=[DirectiveLocation.INTERFACE]), + GraphQLDirective(name='onUnion', locations=[DirectiveLocation.UNION]), + GraphQLDirective(name='onEnum', locations=[DirectiveLocation.ENUM]), + GraphQLDirective(name='onEnumValue', locations=[DirectiveLocation.ENUM_VALUE]), + GraphQLDirective(name='onInputObject', locations=[DirectiveLocation.INPUT_OBJECT]), + GraphQLDirective(name='onInputFieldDefinition', locations=[DirectiveLocation.INPUT_FIELD_DEFINITION]), ], types=[Cat, Dog, Human, Alien] ) @@ -188,7 +206,7 @@ def expect_valid(schema, rules, query): errors = validate(schema, parse(query), rules) - assert errors == [], 'Should validate' + assert errors == [], 'Error: %s, Should validate' % errors def sort_lists(value): @@ -210,6 +228,7 @@ def expect_invalid(schema, rules, query, expected_errors, sort_list=True): {'line': loc.line, 'column': loc.column} for loc in error['locations'] ] + if sort_list: assert sort_lists(list(map(format_error, errors))) == sort_lists(expected_errors) diff --git a/graphql/validation/validation.py b/graphql/validation/validation.py index cdfaebc8..3610dbf4 100644 --- a/graphql/validation/validation.py +++ b/graphql/validation/validation.py @@ -96,7 +96,7 @@ def get_recursively_referenced_fragments(self, operation): if not fragments: fragments = [] collected_names = set() - nodes_to_visit = [operation] + nodes_to_visit = [operation.selection_set] while nodes_to_visit: node = nodes_to_visit.pop() spreads = self.get_fragment_spreads(node) @@ -107,7 +107,7 @@ def get_recursively_referenced_fragments(self, operation): fragment = self.get_fragment(frag_name) if fragment: fragments.append(fragment) - nodes_to_visit.append(fragment) + nodes_to_visit.append(fragment.selection_set) self._recursively_referenced_fragments[operation] = fragments return fragments @@ -115,7 +115,7 @@ def get_fragment_spreads(self, node): spreads = self._fragment_spreads.get(node) if not spreads: spreads = [] - sets_to_visit = [node.selection_set] + sets_to_visit = [node] while sets_to_visit: _set = sets_to_visit.pop() for selection in _set.selections: diff --git a/tests/starwars/starwars_schema.py b/tests/starwars/starwars_schema.py index 5e875401..52b3fb5f 100644 --- a/tests/starwars/starwars_schema.py +++ b/tests/starwars/starwars_schema.py @@ -142,5 +142,4 @@ ), } ) - StarWarsSchema = GraphQLSchema(query=queryType, types=[humanType, droidType]) diff --git a/tests/starwars/test_validation.py b/tests/starwars/test_validation.py index bafa2328..538297cd 100644 --- a/tests/starwars/test_validation.py +++ b/tests/starwars/test_validation.py @@ -97,10 +97,12 @@ def test_allows_object_fields_in_inline_fragments(): query DroidFieldInFragment { hero { name - ... on Droid { - primaryFunction - } + ...DroidFields } } + + fragment DroidFields on Droid { + primaryFunction + } ''' assert not validation_errors(query)