From 8de60872f639cef3133aba49bae145b9acf4beb7 Mon Sep 17 00:00:00 2001 From: Syrus Akbary Date: Sat, 9 Apr 2016 02:07:42 -0700 Subject: [PATCH 01/16] Implement extendSchema. Fixed #40. Related GraphQL-js implementation: https://github.com/graphql/graphql-js/commit/9ea8196c2af97551bab5cfd57201c29e3085ee4e --- graphql/core/utils/extend_schema.py | 345 ++++++++++++++ tests/core_utils/test_extend_schema.py | 601 +++++++++++++++++++++++++ 2 files changed, 946 insertions(+) create mode 100644 graphql/core/utils/extend_schema.py create mode 100644 tests/core_utils/test_extend_schema.py diff --git a/graphql/core/utils/extend_schema.py b/graphql/core/utils/extend_schema.py new file mode 100644 index 00000000..1a444e29 --- /dev/null +++ b/graphql/core/utils/extend_schema.py @@ -0,0 +1,345 @@ +from collections import OrderedDict, defaultdict + +from graphql.core.language import ast + +from ..error import GraphQLError +from ..type.definition import (GraphQLArgument, GraphQLEnumType, + GraphQLEnumValue, GraphQLField, + GraphQLInputObjectField, GraphQLInputObjectType, + GraphQLInterfaceType, GraphQLList, + GraphQLNonNull, GraphQLObjectType, + GraphQLScalarType, GraphQLUnionType) +from ..type.scalars import (GraphQLBoolean, GraphQLFloat, GraphQLID, + GraphQLInt, GraphQLString) +from ..type.schema import GraphQLSchema +from .value_from_ast import value_from_ast + + +def extend_schema(schema, documentAST=None): + """Produces a new schema given an existing schema and a document which may + contain GraphQL type extensions and definitions. The original schema will + remain unaltered. + + Because a schema represents a graph of references, a schema cannot be + extended without effectively making an entire copy. We do not know until it's + too late if subgraphs remain unchanged. + + This algorithm copies the provided schema, applying extensions while + producing the copy. The original schema remains unaltered.""" + + assert isinstance( + schema, GraphQLSchema), 'Must provide valid GraphQLSchema' + assert documentAST and isinstance( + documentAST, ast.Document), 'Must provide valid Document AST' + + # Collect the type definitions and extensions found in the document. + type_definition_map = {} + type_extensions_map = defaultdict(list) + + for _def in documentAST.definitions: + if isinstance(_def, ( + ast.ObjectTypeDefinition, + ast.InterfaceTypeDefinition, + ast.EnumTypeDefinition, + ast.UnionTypeDefinition, + ast.ScalarTypeDefinition, + ast.InputObjectTypeDefinition, + )): + # Sanity check that none of the defined types conflict with the + # schema's existing types. + type_name = _def.name.value + if schema.get_type(type_name): + raise GraphQLError( + ('Type "{}" already exists in the schema. It cannot also ' + + 'be defined in this type definition.').format(type_name), + [_def] + ) + + type_definition_map[type_name] = _def + elif isinstance(_def, ast.TypeExtensionDefinition): + # Sanity check that this type extension exists within the + # schema's existing types. + extended_type_name = _def.definition.name.value + existing_type = schema.get_type(extended_type_name) + if not existing_type: + raise GraphQLError( + ('Cannot extend type "{}" because it does not ' + + 'exist in the existing schema.').format(extended_type_name), + [_def.definition] + ) + if not isinstance(existing_type, GraphQLObjectType): + raise GraphQLError( + 'Cannot extend non-object type "{}".'.format( + extended_type_name), + [_def.definition] + ) + + type_extensions_map[extended_type_name].append(_def) + + # Below are functions used for producing this schema that have closed over + # this scope and have access to the schema, cache, and newly defined types. + + def get_type_from_def(type_def): + type = _get_named_type(type_def.name) + assert type, 'Invalid schema' + return type + + def get_type_from_AST(astNode): + type = _get_named_type(astNode.name.value) + if not type: + raise GraphQLError( + ('Unknown type: "{}". Ensure that this type exists ' + + 'either in the original schema, or is added in a type definition.').format( + astNode.name.value), + [astNode] + ) + return type + + # Given a name, returns a type from either the existing schema or an + # added type. + def _get_named_type(typeName): + cached_type_def = type_def_cache.get(typeName) + if cached_type_def: + return cached_type_def + + existing_type = schema.get_type(typeName) + if existing_type: + type_def = extend_type(existing_type) + type_def_cache[typeName] = type_def + return type_def + + type_ast = type_definition_map.get(typeName) + if type_ast: + type_def = build_type(type_ast) + type_def_cache[typeName] = type_def + return type_def + + # Given a type's introspection result, construct the correct + # GraphQLType instance. + def extend_type(type): + if isinstance(type, GraphQLObjectType): + return extend_object_type(type) + if isinstance(type, GraphQLInterfaceType): + return extend_interface_type(type) + if isinstance(type, GraphQLUnionType): + return extend_union_type(type) + return type + + def extend_object_type(type): + return GraphQLObjectType( + name=type.name, + description=type.description, + interfaces=lambda: extend_implemented_interfaces(type), + fields=lambda: extend_field_map(type), + ) + + def extend_interface_type(type): + return GraphQLInterfaceType( + name=type.name, + description=type.description, + fields=lambda: extend_field_map(type), + resolve_type=raise_client_schema_execution_error, + ) + + def extend_union_type(type): + return GraphQLUnionType( + name=type.name, + description=type.description, + types=map(get_type_from_def, type.get_possible_types()), + resolve_type=raise_client_schema_execution_error, + ) + + def extend_implemented_interfaces(type): + interfaces = map(get_type_from_def, type.get_interfaces()) + + # If there are any extensions to the interfaces, apply those here. + extensions = type_extensions_map[type.name] + for extension in extensions: + for namedType in extension.definition.interfaces: + interface_name = namedType.name.value + if any([_def.name == interface_name for _def in interfaces]): + raise GraphQLError( + ('Type "{}" already implements "{}". ' + + 'It cannot also be implemented in this type extension.').format( + type.name, interface_name), + [namedType] + ) + interfaces.append(get_type_from_AST(namedType)) + + return interfaces + + def extend_field_map(type): + new_field_map = OrderedDict() + old_field_map = type.get_fields() + for field_name, field in old_field_map.iteritems(): + new_field_map[field_name] = GraphQLField( + extend_field_type(field.type), + description=field.description, + deprecation_reason=field.deprecation_reason, + args={arg.name: arg for arg in field.args}, + resolver=raise_client_schema_execution_error, + ) + + # If there are any extensions to the fields, apply those here. + extensions = type_extensions_map[type.name] + for extension in extensions: + for field in extension.definition.fields: + field_name = field.name.value + if field_name in old_field_map: + raise GraphQLError( + ('Field "{}.{}" already exists in the ' + + 'schema. It cannot also be defined in this type extension.').format( + type.name, field_name), + [field] + ) + new_field_map[field_name] = GraphQLField( + build_field_type(field.type), + args=build_input_values(field.arguments), + resolver=raise_client_schema_execution_error, + ) + + return new_field_map + + def extend_field_type(type): + if isinstance(type, GraphQLList): + return GraphQLList(extend_field_type(type.of_type)) + if isinstance(type, GraphQLNonNull): + return GraphQLNonNull(extend_field_type(type.of_type)) + return get_type_from_def(type) + + def build_type(type_ast): + _type_build = { + ast.ObjectTypeDefinition: build_object_type, + ast.InterfaceTypeDefinition: build_interface_type, + ast.UnionTypeDefinition: build_union_type, + ast.ScalarTypeDefinition: build_scalar_type, + ast.EnumTypeDefinition: build_enum_type, + ast.InputObjectTypeDefinition: build_input_object_type + } + func = _type_build.get(type(type_ast)) + if func: + return func(type_ast) + + def build_object_type(type_ast): + return GraphQLObjectType( + type_ast.name.value, + interfaces=lambda: build_implemented_interfaces(type_ast), + fields=lambda: build_field_map(type_ast), + ) + + def build_interface_type(type_ast): + return GraphQLInterfaceType( + type_ast.name.value, + fields=lambda: build_field_map(type_ast), + resolve_type=raise_client_schema_execution_error, + ) + + def build_union_type(type_ast): + return GraphQLUnionType( + type_ast.name.value, + types=map(get_type_from_AST, type_ast.types), + resolve_type=raise_client_schema_execution_error, + ) + + def build_scalar_type(type_ast): + return GraphQLScalarType( + type_ast.name.value, + serialize=lambda *args, **kwargs: None, + # Note: validation calls the parse functions to determine if a + # literal value is correct. Returning null would cause use of custom + # scalars to always fail validation. Returning false causes them to + # always pass validation. + parse_value=lambda *args, **kwargs: False, + parse_literal=lambda *args, **kwargs: False, + ) + + def build_enum_type(type_ast): + return GraphQLEnumType( + type_ast.name.value, + values={v.name.value: GraphQLEnumValue() for v in type_ast.values}, + ) + + def build_input_object_type(type_ast): + return GraphQLInputObjectType( + type_ast.name.value, + fields=lambda: build_input_values( + type_ast.fields, GraphQLInputObjectField), + ) + + def build_implemented_interfaces(type_ast): + return map(get_type_from_AST, type_ast.interfaces) + + def build_field_map(type_ast): + return { + field.name.value: GraphQLField( + build_field_type(field.type), + args=build_input_values(field.arguments), + resolver=raise_client_schema_execution_error, + ) for field in type_ast.fields + } + + def build_input_values(values, input_type=GraphQLArgument): + input_values = OrderedDict() + for value in values: + type = build_field_type(value.type) + input_values[value.name.value] = input_type( + type, + default_value=value_from_ast(value.default_value, type) + ) + return input_values + + def build_field_type(type_ast): + if isinstance(type_ast, ast.ListType): + return GraphQLList(build_field_type(type_ast.type)) + if isinstance(type_ast, ast.NonNullType): + return GraphQLNonNull(build_field_type(type_ast.type)) + return get_type_from_AST(type_ast) + + # If this document contains no new types, then return the same unmodified + # GraphQLSchema instance. + if not type_extensions_map and not type_definition_map: + return schema + + # A cache to use to store the actual GraphQLType definition objects by name. + # Initialize to the GraphQL built in scalars. All functions below are inline + # so that this type def cache is within the scope of the closure. + type_def_cache = { + 'String': GraphQLString, + 'Int': GraphQLInt, + 'Float': GraphQLFloat, + 'Boolean': GraphQLBoolean, + 'ID': GraphQLID, + } + + # Get the root Query, Mutation, and Subscription types. + query_type = get_type_from_def(schema.get_query_type()) + + existing_mutation_type = schema.get_mutation_type() + mutationType = existing_mutation_type and get_type_from_def( + existing_mutation_type) or None + + existing_subscription_type = schema.get_subscription_type() + subscription_type = existing_subscription_type and get_type_from_def( + existing_subscription_type) or None + + # Iterate through all types, getting the type definition for each, ensuring + # that any type not directly referenced by a field will get created. + for typeName, _def in schema.get_type_map().iteritems(): + get_type_from_def(_def) + + # Do the same with new types. + for typeName, _def in type_definition_map.iteritems(): + get_type_from_AST(_def) + + # Then produce and return a Schema with these types. + return GraphQLSchema( + query=query_type, + mutation=mutationType, + subscription=subscription_type, + # Copy directives. + directives=schema.get_directives(), + ) + + +def raise_client_schema_execution_error(*args, **kwargs): + raise Exception('Client Schema cannot be used for execution.') diff --git a/tests/core_utils/test_extend_schema.py b/tests/core_utils/test_extend_schema.py new file mode 100644 index 00000000..cf7bc7c3 --- /dev/null +++ b/tests/core_utils/test_extend_schema.py @@ -0,0 +1,601 @@ +from collections import OrderedDict + +from pytest import raises + +from graphql.core import parse +from graphql.core.execution import execute +from graphql.core.type import (GraphQLArgument, GraphQLField, GraphQLID, + GraphQLInterfaceType, GraphQLList, + GraphQLNonNull, GraphQLObjectType, + GraphQLSchema, GraphQLString, GraphQLUnionType) +from graphql.core.utils.extend_schema import extend_schema +from graphql.core.utils.schema_printer import print_schema + +# Test schema. +SomeInterfaceType = GraphQLInterfaceType( + name='SomeInterface', + resolve_type=lambda: FooType, + fields=lambda: OrderedDict([ + ('name', GraphQLField(GraphQLString)), + ('some', GraphQLField(SomeInterfaceType)), + ]) +) + + +FooType = GraphQLObjectType( + name='Foo', + interfaces=[SomeInterfaceType], + fields=lambda: OrderedDict([ + ('name', GraphQLField(GraphQLString)), + ('some', GraphQLField(SomeInterfaceType)), + ('tree', GraphQLField(GraphQLNonNull(GraphQLList(FooType)))), + ]) +) + +BarType = GraphQLObjectType( + name='Bar', + interfaces=[SomeInterfaceType], + fields=lambda: OrderedDict([ + ('name', GraphQLField(GraphQLString)), + ('some', GraphQLField(SomeInterfaceType)), + ('foo', GraphQLField(FooType)), + ]) +) + +BizType = GraphQLObjectType( + name='Biz', + fields=lambda: OrderedDict([ + ('fizz', GraphQLField(GraphQLString)), + ]) +) + +SomeUnionType = GraphQLUnionType( + name='SomeUnion', + resolve_type=lambda: FooType, + types=[FooType, BizType], +) + +test_schema = GraphQLSchema( + query=GraphQLObjectType( + name='Query', + fields=lambda: OrderedDict([ + ('foo', GraphQLField(FooType)), + ('someUnion', GraphQLField(SomeUnionType)), + ('someInterface', GraphQLField( + SomeInterfaceType, + args={ + 'id': GraphQLArgument(GraphQLNonNull(GraphQLID)) + }, + )), + ]) + ) +) + + +def test_returns_original_schema_if_no_type_definitions(): + ast = parse('{ field }') + extended_schema = extend_schema(test_schema, ast) + assert extended_schema == test_schema + + +def test_extends_without_altering_original_schema(): + ast = parse(''' + extend type Query { + newField: String + } + ''') + original_print = print_schema(test_schema) + extended_schema = extend_schema(test_schema, ast) + assert extend_schema != test_schema + assert print_schema(test_schema) == original_print + assert 'newField' in print_schema(extended_schema) + assert 'newField' not in print_schema(test_schema) + + +def test_cannot_be_used_for_execution(): + ast = parse(''' + extend type Query { + newField: String + } + ''') + extended_schema = extend_schema(test_schema, ast) + clientQuery = parse('{ newField }') + + result = execute(extended_schema, object(), clientQuery) + assert result.data['newField'] is None + assert str(result.errors[0] + ) == 'Client Schema cannot be used for execution.' + + +def test_extends_objects_by_adding_new_fields(): + ast = parse(''' + extend type Foo { + newField: String + } + ''') + original_print = print_schema(test_schema) + extended_schema = extend_schema(test_schema, ast) + assert extend_schema != test_schema + assert print_schema(test_schema) == original_print + # print original_print + assert print_schema(extended_schema) == \ + '''type Bar implements SomeInterface { + name: String + some: SomeInterface + foo: Foo +} + +type Biz { + fizz: String +} + +type Foo implements SomeInterface { + name: String + some: SomeInterface + tree: [Foo]! + newField: String +} + +type Query { + foo: Foo + someUnion: SomeUnion + someInterface(id: ID!): SomeInterface +} + +interface SomeInterface { + name: String + some: SomeInterface +} + +union SomeUnion = Foo | Biz +''' + + +def test_extends_objects_by_adding_new_fields_with_arguments(): + ast = parse(''' + extend type Foo { + newField(arg1: String, arg2: NewInputObj!): String + } + input NewInputObj { + field1: Int + field2: [Float] + field3: String! + } + ''') + original_print = print_schema(test_schema) + extended_schema = extend_schema(test_schema, ast) + assert extend_schema != test_schema + assert print_schema(test_schema) == original_print + assert print_schema(extended_schema) == \ + '''type Bar implements SomeInterface { + name: String + some: SomeInterface + foo: Foo +} + +type Biz { + fizz: String +} + +type Foo implements SomeInterface { + name: String + some: SomeInterface + tree: [Foo]! + newField(arg1: String, arg2: NewInputObj!): String +} + +input NewInputObj { + field1: Int + field2: [Float] + field3: String! +} + +type Query { + foo: Foo + someUnion: SomeUnion + someInterface(id: ID!): SomeInterface +} + +interface SomeInterface { + name: String + some: SomeInterface +} + +union SomeUnion = Foo | Biz +''' + + +def test_extends_objects_by_adding_implemented_interfaces(): + ast = parse(''' + extend type Biz implements SomeInterface { + name: String + some: SomeInterface + } + ''') + original_print = print_schema(test_schema) + extended_schema = extend_schema(test_schema, ast) + assert extend_schema != test_schema + assert print_schema(test_schema) == original_print + assert print_schema(extended_schema) == \ + '''type Bar implements SomeInterface { + name: String + some: SomeInterface + foo: Foo +} + +type Biz implements SomeInterface { + fizz: String + name: String + some: SomeInterface +} + +type Foo implements SomeInterface { + name: String + some: SomeInterface + tree: [Foo]! +} + +type Query { + foo: Foo + someUnion: SomeUnion + someInterface(id: ID!): SomeInterface +} + +interface SomeInterface { + name: String + some: SomeInterface +} + +union SomeUnion = Foo | Biz +''' + + +def test_extends_objects_by_adding_implemented_interfaces(): + ast = parse(''' + extend type Foo { + newObject: NewObject + newInterface: NewInterface + newUnion: NewUnion + newScalar: NewScalar + newEnum: NewEnum + newTree: [Foo]! + } + type NewObject implements NewInterface { + baz: String + } + type NewOtherObject { + fizz: Int + } + interface NewInterface { + baz: String + } + union NewUnion = NewObject | NewOtherObject + scalar NewScalar + enum NewEnum { + OPTION_A + OPTION_B + } + ''') + original_print = print_schema(test_schema) + extended_schema = extend_schema(test_schema, ast) + assert extend_schema != test_schema + assert print_schema(test_schema) == original_print + assert print_schema(extended_schema) == \ + '''type Bar implements SomeInterface { + name: String + some: SomeInterface + foo: Foo +} + +type Biz { + fizz: String +} + +type Foo implements SomeInterface { + name: String + some: SomeInterface + tree: [Foo]! + newObject: NewObject + newInterface: NewInterface + newUnion: NewUnion + newScalar: NewScalar + newEnum: NewEnum + newTree: [Foo]! +} + +enum NewEnum { + OPTION_A + OPTION_B +} + +interface NewInterface { + baz: String +} + +type NewObject implements NewInterface { + baz: String +} + +type NewOtherObject { + fizz: Int +} + +scalar NewScalar + +union NewUnion = NewObject | NewOtherObject + +type Query { + foo: Foo + someUnion: SomeUnion + someInterface(id: ID!): SomeInterface +} + +interface SomeInterface { + name: String + some: SomeInterface +} + +union SomeUnion = Foo | Biz +''' + + +def test_extends_objects_by_adding_implemented_new_interfaces(): + ast = parse(''' + extend type Foo implements NewInterface { + baz: String + } + interface NewInterface { + baz: String + } + ''') + original_print = print_schema(test_schema) + extended_schema = extend_schema(test_schema, ast) + assert extend_schema != test_schema + assert print_schema(test_schema) == original_print + assert print_schema(extended_schema) == \ + '''type Bar implements SomeInterface { + name: String + some: SomeInterface + foo: Foo +} + +type Biz { + fizz: String +} + +type Foo implements SomeInterface, NewInterface { + name: String + some: SomeInterface + tree: [Foo]! + baz: String +} + +interface NewInterface { + baz: String +} + +type Query { + foo: Foo + someUnion: SomeUnion + someInterface(id: ID!): SomeInterface +} + +interface SomeInterface { + name: String + some: SomeInterface +} + +union SomeUnion = Foo | Biz +''' + + +def test_extends_objects_multiple_times(): + ast = parse(''' + extend type Biz implements NewInterface { + buzz: String + } + extend type Biz implements SomeInterface { + name: String + some: SomeInterface + newFieldA: Int + } + extend type Biz { + newFieldA: Int + newFieldB: Float + } + interface NewInterface { + buzz: String + } + ''') + original_print = print_schema(test_schema) + extended_schema = extend_schema(test_schema, ast) + assert extend_schema != test_schema + assert print_schema(test_schema) == original_print + assert print_schema(extended_schema) == \ + '''type Bar implements SomeInterface { + name: String + some: SomeInterface + foo: Foo +} + +type Biz implements NewInterface, SomeInterface { + fizz: String + buzz: String + name: String + some: SomeInterface + newFieldA: Int + newFieldB: Float +} + +type Foo implements SomeInterface { + name: String + some: SomeInterface + tree: [Foo]! +} + +interface NewInterface { + buzz: String +} + +type Query { + foo: Foo + someUnion: SomeUnion + someInterface(id: ID!): SomeInterface +} + +interface SomeInterface { + name: String + some: SomeInterface +} + +union SomeUnion = Foo | Biz +''' + + +def test_may_extend_mutations_and_subscriptions(): + mutationSchema = GraphQLSchema( + query=GraphQLObjectType( + 'Query', + fields=lambda: { + 'queryField': GraphQLField(GraphQLString), + } + ), + mutation=GraphQLObjectType( + 'Mutation', + fields={ + 'mutationField': GraphQLField(GraphQLString), + } + ), + subscription=GraphQLObjectType( + 'Subscription', + fields={ + 'subscriptionField': GraphQLField(GraphQLString), + } + ), + ) + + ast = parse(''' + extend type Query { + newQueryField: Int + } + extend type Mutation { + newMutationField: Int + } + extend type Subscription { + newSubscriptionField: Int + } + ''') + original_print = print_schema(mutationSchema) + extended_schema = extend_schema(mutationSchema, ast) + assert extend_schema != mutationSchema + assert print_schema(mutationSchema) == original_print + assert print_schema(extended_schema) == \ + '''type Mutation { + mutationField: String + newMutationField: Int +} + +type Query { + queryField: String + newQueryField: Int +} + +type Subscription { + subscriptionField: String + newSubscriptionField: Int +} +''' + + +def test_does_not_allow_replacing_an_existing_type(): + ast = parse(''' + type Bar { + baz: String + } + ''') + with raises(Exception) as exc_info: + extend_schema(test_schema, ast) + + assert str(exc_info.value) == \ + ('Type "Bar" already exists in the schema. It cannot also be defined ' + + 'in this type definition.') + + +def test_does_not_allow_replacing_an_existing_field(): + ast = parse(''' + extend type Bar { + foo: Foo + } + ''') + with raises(Exception) as exc_info: + extend_schema(test_schema, ast) + + assert str(exc_info.value) == \ + ('Field "Bar.foo" already exists in the schema. It cannot also be ' + + 'defined in this type extension.') + + +def test_does_not_allow_replacing_an_existing_interface(): + ast = parse(''' + extend type Foo implements SomeInterface { + otherField: String + } + ''') + with raises(Exception) as exc_info: + extend_schema(test_schema, ast) + + assert str(exc_info.value) == \ + ('Type "Foo" already implements "SomeInterface". It cannot also be ' + + 'implemented in this type extension.') + + +def test_does_not_allow_referencing_an_unknown_type(): + ast = parse(''' + extend type Bar { + quix: Quix + } + ''') + with raises(Exception) as exc_info: + extend_schema(test_schema, ast) + + assert str(exc_info.value) == \ + ('Unknown type: "Quix". Ensure that this type exists either in the ' + + 'original schema, or is added in a type definition.') + + +def test_does_not_allow_extending_an_unknown_type(): + ast = parse(''' + extend type UnknownType { + baz: String + } + ''') + with raises(Exception) as exc_info: + extend_schema(test_schema, ast) + + assert str(exc_info.value) == \ + ('Cannot extend type "UnknownType" because it does not exist in the ' + + 'existing schema.') + + +def test_does_not_allow_extending_an_interface(): + ast = parse(''' + extend type SomeInterface { + baz: String + } + ''') + with raises(Exception) as exc_info: + extend_schema(test_schema, ast) + + assert str(exc_info.value) == 'Cannot extend non-object type "SomeInterface".' + + +def test_does_not_allow_extending_a_scalar(): + ast = parse(''' + extend type String { + baz: String + } + ''') + with raises(Exception) as exc_info: + extend_schema(test_schema, ast) + + assert str(exc_info.value) == 'Cannot extend non-object type "String".' From 062a6af33ad91da03641d4b34a740446c691b7fd Mon Sep 17 00:00:00 2001 From: Syrus Akbary Date: Sat, 9 Apr 2016 02:57:16 -0700 Subject: [PATCH 02/16] [Validation] Un-interleave overlapping field messages Related GraphQL-js commit https://github.com/graphql/graphql-js/tree/228215a704e9d7f67078fc2652eafa6b6e22026f --- .../rules/overlapping_fields_can_be_merged.py | 19 ++++++++++++------- .../test_overlapping_fields_can_be_merged.py | 14 +++++++------- 2 files changed, 19 insertions(+), 14 deletions(-) diff --git a/graphql/core/validation/rules/overlapping_fields_can_be_merged.py b/graphql/core/validation/rules/overlapping_fields_can_be_merged.py index 05f66dc6..acd560d2 100644 --- a/graphql/core/validation/rules/overlapping_fields_can_be_merged.py +++ b/graphql/core/validation/rules/overlapping_fields_can_be_merged.py @@ -51,7 +51,8 @@ def find_conflict(self, response_name, pair1, pair2): if name1 != name2: return ( (response_name, '{} and {} are different fields'.format(name1, name2)), - (ast1, ast2) + [ast1], + [ast2] ) type1 = def1 and def1.type @@ -60,19 +61,22 @@ def find_conflict(self, response_name, pair1, pair2): if type1 and type2 and not self.same_type(type1, type2): return ( (response_name, 'they return differing types {} and {}'.format(type1, type2)), - (ast1, ast2) + [ast1], + [ast2] ) if not self.same_arguments(ast1.arguments, ast2.arguments): return ( (response_name, 'they have differing arguments'), - (ast1, ast2) + [ast1], + [ast2] ) if not self.same_directives(ast1.directives, ast2.directives): return ( (response_name, 'they have differing directives'), - (ast1, ast2) + [ast1], + [ast2] ) selection_set1 = ast1.selection_set @@ -98,7 +102,8 @@ def find_conflict(self, response_name, pair1, pair2): if conflicts: return ( (response_name, [conflict[0] for conflict in conflicts]), - tuple(itertools.chain((ast1, ast2), *[conflict[1] for conflict in conflicts])) + tuple(itertools.chain([ast1], *[conflict[1] for conflict in conflicts])), + tuple(itertools.chain([ast2], *[conflict[2] for conflict in conflicts])) ) def leave_SelectionSet(self, node, key, parent, path, ancestors): @@ -110,8 +115,8 @@ def leave_SelectionSet(self, node, key, parent, path, ancestors): conflicts = self.find_conflicts(field_map) if conflicts: return [ - GraphQLError(self.fields_conflict_message(reason_name, reason), list(fields)) for - (reason_name, reason), fields in conflicts + GraphQLError(self.fields_conflict_message(reason_name, reason), list(fields1)+list(fields2)) for + (reason_name, reason), fields1, fields2 in conflicts ] @staticmethod diff --git a/tests/core_validation/test_overlapping_fields_can_be_merged.py b/tests/core_validation/test_overlapping_fields_can_be_merged.py index e7f2764c..ec353830 100644 --- a/tests/core_validation/test_overlapping_fields_can_be_merged.py +++ b/tests/core_validation/test_overlapping_fields_can_be_merged.py @@ -205,7 +205,7 @@ def test_deep_conflict(): ''', [ fields_conflict( 'field', [('x', 'a and b are different fields')], - L(3, 9), L(6, 9), L(4, 13), L(7, 13)) + L(3, 9), L(4, 13), L(6, 9), L(7, 13)) ], sort_list=False) @@ -224,7 +224,7 @@ def test_deep_conflict_with_multiple_issues(): ''', [ fields_conflict( 'field', [('x', 'a and b are different fields'), ('y', 'c and d are different fields')], - L(3, 9), L(7, 9), L(4, 13), L(8, 13), L(5, 13), L(9, 13) + L(3, 9), L(4, 13), L(5, 13), L(7, 9), L(8, 13), L(9, 13) ) ], sort_list=False) @@ -246,7 +246,7 @@ def test_very_deep_conflict(): ''', [ fields_conflict( 'field', [['deepField', [['x', 'a and b are different fields']]]], - L(3, 9), L(8, 9), L(4, 13), L(9, 13), L(5, 17), L(10, 17) + L(3, 9), L(4, 13), L(5, 17), L(8, 9), L(9, 13), L(10, 17) ) ], sort_list=False) @@ -271,7 +271,7 @@ def test_reports_deep_conflict_to_nearest_common_ancestor(): ''', [ fields_conflict( 'deepField', [('x', 'a and b are different fields')], - L(4, 13), L(7, 13), L(5, 17), L(8, 17) + L(4, 13), L(5, 17), L(7, 13), L(8, 17) ) ], sort_list=False) @@ -378,9 +378,9 @@ def test_compares_deep_types_including_list(): ''', [ fields_conflict( 'edges', [['node', [['id', 'id and name are different fields']]]], - L(14, 9), L(5, 13), - L(15, 13), L(6, 17), - L(16, 17), L(7, 21), + L(14, 9), L(15, 13), + L(16, 17), L(5, 13), + L(6, 17), L(7, 21), ) ], sort_list=False) From a6424ab0d91e372f702057c300288c571f71a335 Mon Sep 17 00:00:00 2001 From: Syrus Akbary Date: Sat, 9 Apr 2016 03:22:39 -0700 Subject: [PATCH 03/16] [Validation] Allow safe divergence Related GraphQL commit https://github.com/graphql/graphql-js/tree/d71e063fdd1d4c376b4948147e54438b6f1e13de --- .../rules/overlapping_fields_can_be_merged.py | 26 ++++- .../test_overlapping_fields_can_be_merged.py | 110 +++++++++++++++--- 2 files changed, 113 insertions(+), 23 deletions(-) diff --git a/graphql/core/validation/rules/overlapping_fields_can_be_merged.py b/graphql/core/validation/rules/overlapping_fields_can_be_merged.py index acd560d2..d8dcc186 100644 --- a/graphql/core/validation/rules/overlapping_fields_can_be_merged.py +++ b/graphql/core/validation/rules/overlapping_fields_can_be_merged.py @@ -36,11 +36,27 @@ def find_conflicts(self, field_map): return conflicts - def find_conflict(self, response_name, pair1, pair2): - ast1, def1 = pair1 - ast2, def2 = pair2 + def find_conflict(self, response_name, field1, field2): + parent_type1, ast1, def1 = field1 + parent_type2, ast2, def2 = field2 - if ast1 is ast2 or self.compared_set.has(ast1, ast2): + # Not a pair + if ast1 is ast2: + return + + # If the statically known parent types could not possibly apply at the same + # time, then it is safe to permit them to diverge as they will not present + # any ambiguity by differing. + # It is known that two parent types could never overlap if they are + # different Object types. Interface or Union types might overlap - if not + # in the current state of the schema, then perhaps in some future version, + # thus may not safely diverge. + if parent_type1 != parent_type2 and \ + isinstance(parent_type1, GraphQLObjectType) and \ + isinstance(parent_type2, GraphQLObjectType): + return + + if self.compared_set.has(ast1, ast2): return self.compared_set.add(ast1, ast2) @@ -192,7 +208,7 @@ def collect_field_asts_and_defs(self, parent_type, selection_set, visited_fragme field_def = parent_type.get_fields().get(field_name) response_name = selection.alias.value if selection.alias else field_name - ast_and_defs[response_name].append((selection, field_def)) + ast_and_defs[response_name].append((parent_type, selection, field_def)) elif isinstance(selection, ast.InlineFragment): type_condition = selection.type_condition diff --git a/tests/core_validation/test_overlapping_fields_can_be_merged.py b/tests/core_validation/test_overlapping_fields_can_be_merged.py index ec353830..68bc42a9 100644 --- a/tests/core_validation/test_overlapping_fields_can_be_merged.py +++ b/tests/core_validation/test_overlapping_fields_can_be_merged.py @@ -1,5 +1,5 @@ from graphql.core.language.location import SourceLocation as L -from graphql.core.type.definition import GraphQLObjectType, GraphQLArgument, GraphQLNonNull, GraphQLUnionType, \ +from graphql.core.type.definition import GraphQLObjectType, GraphQLArgument, GraphQLNonNull, GraphQLInterfaceType, \ GraphQLList, GraphQLField from graphql.core.type.scalars import GraphQLString, GraphQLInt, GraphQLID from graphql.core.type.schema import GraphQLSchema @@ -79,6 +79,19 @@ def test_same_aliases_with_different_field_targets(): ], sort_list=False) +def test_same_aliases_allowed_on_nonoverlapping_fields(): + expect_passes_rule(OverlappingFieldsCanBeMerged, ''' + fragment sameAliasesWithDifferentFieldTargets on Pet { + ... on Dog { + name + } + ... on Cat { + name: nickname + } + } + ''') + + def test_alias_masking_direct_field_access(): expect_fails_rule(OverlappingFieldsCanBeMerged, ''' fragment aliasMaskingDirectFieldAccess on Dog { @@ -90,6 +103,28 @@ def test_alias_masking_direct_field_access(): ], sort_list=False) +def test_diferent_args_second_adds_an_argument(): + expect_fails_rule(OverlappingFieldsCanBeMerged, ''' + fragment conflictingArgs on Dog { + doesKnowCommand + doesKnowCommand(dogCommand: HEEL) + } + ''', [ + fields_conflict('doesKnowCommand', 'they have differing arguments', L(3, 9), L(4, 9)) + ], sort_list=False) + + +def test_diferent_args_second_missing_an_argument(): + expect_fails_rule(OverlappingFieldsCanBeMerged, ''' + fragment conflictingArgs on Dog { + doesKnowCommand(dogCommand: SIT) + doesKnowCommand + } + ''', [ + fields_conflict('doesKnowCommand', 'they have differing arguments', L(3, 9), L(4, 9)) + ], sort_list=False) + + def test_conflicting_args(): expect_fails_rule(OverlappingFieldsCanBeMerged, ''' fragment conflictingArgs on Dog { @@ -101,6 +136,18 @@ def test_conflicting_args(): ], sort_list=False) +def test_allows_different_args_where_no_conflict_is_possible(): + expect_passes_rule(OverlappingFieldsCanBeMerged, ''' + fragment conflictingArgs on Pet { + ... on Dog { + name(surname: true) + } + ... on Cat { + name + } + } + ''') + def test_conflicting_directives(): expect_fails_rule(OverlappingFieldsCanBeMerged, ''' fragment conflictingDirectiveArgs on Dog { @@ -276,25 +323,37 @@ def test_reports_deep_conflict_to_nearest_common_ancestor(): ], sort_list=False) +SomeBox = GraphQLInterfaceType('SomeBox', { + 'unrelatedField': GraphQLField(GraphQLString) +}, resolve_type=lambda *_: StringBox) + StringBox = GraphQLObjectType('StringBox', { - 'scalar': GraphQLField(GraphQLString) -}) + 'scalar': GraphQLField(GraphQLString), + 'unrelatedField': GraphQLField(GraphQLString) +}, interfaces=[SomeBox]) IntBox = GraphQLObjectType('IntBox', { - 'scalar': GraphQLField(GraphQLInt) -}) + 'scalar': GraphQLField(GraphQLInt), + 'unrelatedField': GraphQLField(GraphQLString) +}, interfaces=[SomeBox]) -NonNullStringBox1 = GraphQLObjectType('NonNullStringBox1', { +NonNullStringBox1 = GraphQLInterfaceType('NonNullStringBox1', { 'scalar': GraphQLField(GraphQLNonNull(GraphQLString)) -}) +}, resolve_type=lambda *_: StringBox) + +NonNullStringBox1Impl = GraphQLObjectType('NonNullStringBox1Impl', { + 'scalar': GraphQLField(GraphQLNonNull(GraphQLString)), + 'unrelatedField': GraphQLField(GraphQLString) +}, interfaces=[ SomeBox, NonNullStringBox1 ]) -NonNullStringBox2 = GraphQLObjectType('NonNullStringBox2', { +NonNullStringBox2 = GraphQLInterfaceType('NonNullStringBox2', { 'scalar': GraphQLField(GraphQLNonNull(GraphQLString)) -}) +}, resolve_type=lambda *_: StringBox) -BoxUnion = GraphQLUnionType('BoxUnion', [ - StringBox, IntBox, NonNullStringBox1, NonNullStringBox2 -], resolve_type=lambda *_: StringBox) +NonNullStringBox2Impl = GraphQLObjectType('NonNullStringBox2Impl', { + 'scalar': GraphQLField(GraphQLNonNull(GraphQLString)), + 'unrelatedField': GraphQLField(GraphQLString) +}, interfaces=[ SomeBox, NonNullStringBox2 ]) Connection = GraphQLObjectType('Connection', { 'edges': GraphQLField(GraphQLList(GraphQLObjectType('Edge', { @@ -306,33 +365,48 @@ def test_reports_deep_conflict_to_nearest_common_ancestor(): }) schema = GraphQLSchema(GraphQLObjectType('QueryRoot', { - 'boxUnion': GraphQLField(BoxUnion), + 'someBox': GraphQLField(SomeBox), 'connection': GraphQLField(Connection) })) -def test_conflicting_scalar_return_types(): +def test_conflicting_return_types_which_potentially_overlap(): expect_fails_rule_with_schema(schema, OverlappingFieldsCanBeMerged, ''' { - boxUnion { + someBox { ...on IntBox { scalar } - ...on StringBox { + ...on NonNullStringBox1 { scalar } } } ''', [ - fields_conflict('scalar', 'they return differing types Int and String', L(5, 17), L(8, 17)) + fields_conflict('scalar', 'they return differing types Int and String!', L(5, 17), L(8, 17)) ], sort_list=False) +def test_allows_differing_return_types_which_cannot_overlap(): + expect_passes_rule_with_schema(schema, OverlappingFieldsCanBeMerged, ''' + { + someBox { + ...on IntBox { + scalar + } + ...on StringBox { + scalar + } + } + } + ''') + + def test_same_wrapped_scalar_return_types(): expect_passes_rule_with_schema(schema, OverlappingFieldsCanBeMerged, ''' { - boxUnion { + someBox { ...on NonNullStringBox1 { scalar } From 5a972032ea9551409e8fb8d98683933f8cf4d966 Mon Sep 17 00:00:00 2001 From: Syrus Akbary Date: Sat, 9 Apr 2016 03:28:50 -0700 Subject: [PATCH 04/16] [Validation] Allow differing directives. Related GraphQL-js commit https://github.com/graphql/graphql-js/commit/9b11df2efc66ad3c07e0d373a7e04a3ba5ee581a --- .../rules/overlapping_fields_can_be_merged.py | 29 ----------- .../test_overlapping_fields_can_be_merged.py | 52 ++++--------------- 2 files changed, 9 insertions(+), 72 deletions(-) diff --git a/graphql/core/validation/rules/overlapping_fields_can_be_merged.py b/graphql/core/validation/rules/overlapping_fields_can_be_merged.py index d8dcc186..86b46340 100644 --- a/graphql/core/validation/rules/overlapping_fields_can_be_merged.py +++ b/graphql/core/validation/rules/overlapping_fields_can_be_merged.py @@ -88,13 +88,6 @@ def find_conflict(self, response_name, field1, field2): [ast2] ) - if not self.same_directives(ast1.directives, ast2.directives): - return ( - (response_name, 'they have differing directives'), - [ast1], - [ast2] - ) - selection_set1 = ast1.selection_set selection_set2 = ast2.selection_set @@ -165,28 +158,6 @@ def same_arguments(cls, arguments1, arguments2): return True - @classmethod - def same_directives(cls, directives1, directives2): - # Check to see if they are empty directives or nones. If they are, we can - # bail out early. - if not (directives1 or directives2): - return True - - if len(directives1) != len(directives2): - return False - - directives2_values_to_arg = {a.name.value: a for a in directives2} - - for directive1 in directives1: - directive2 = directives2_values_to_arg.get(directive1.name.value) - if not directive2: - return False - - if not cls.same_arguments(directive1.arguments, directive2.arguments): - return False - - return True - def collect_field_asts_and_defs(self, parent_type, selection_set, visited_fragment_names=None, ast_and_defs=None): if visited_fragment_names is None: visited_fragment_names = set() diff --git a/tests/core_validation/test_overlapping_fields_can_be_merged.py b/tests/core_validation/test_overlapping_fields_can_be_merged.py index 68bc42a9..3fb42e8f 100644 --- a/tests/core_validation/test_overlapping_fields_can_be_merged.py +++ b/tests/core_validation/test_overlapping_fields_can_be_merged.py @@ -68,6 +68,15 @@ def test_different_directives_with_different_aliases(): ''') +def test_different_skip_or_include_directives_accepted(): + expect_passes_rule(OverlappingFieldsCanBeMerged, ''' + fragment differentDirectivesWithDifferentAliases on Dog { + name @include(if: true) + name @include(if: false) + } + ''') + + def test_same_aliases_with_different_field_targets(): expect_fails_rule(OverlappingFieldsCanBeMerged, ''' fragment sameAliasesWithDifferentFieldTargets on Dog { @@ -148,49 +157,6 @@ def test_allows_different_args_where_no_conflict_is_possible(): } ''') -def test_conflicting_directives(): - expect_fails_rule(OverlappingFieldsCanBeMerged, ''' - fragment conflictingDirectiveArgs on Dog { - name @include(if: true) - name @skip(if: false) - } - ''', [ - fields_conflict('name', 'they have differing directives', L(3, 9), L(4, 9)) - ], sort_list=False) - - -def test_conflicting_directive_args(): - expect_fails_rule(OverlappingFieldsCanBeMerged, ''' - fragment conflictingDirectiveArgs on Dog { - name @include(if: true) - name @include(if: false) - } - ''', [ - fields_conflict('name', 'they have differing directives', L(3, 9), L(4, 9)) - ], sort_list=False) - - -def test_conflicting_args_with_matching_directives(): - expect_fails_rule(OverlappingFieldsCanBeMerged, ''' - fragment conflictingArgsWithMatchingDirectiveArgs on Dog { - doesKnowCommand(dogCommand: SIT) @include(if: true) - doesKnowCommand(dogCommand: HEEL) @include(if: true) - } - ''', [ - fields_conflict('doesKnowCommand', 'they have differing arguments', L(3, 9), L(4, 9)) - ], sort_list=False) - - -def test_conflicting_directives_with_matching_args(): - expect_fails_rule(OverlappingFieldsCanBeMerged, ''' - fragment conflictingDirectiveArgsWithMatchingArgs on Dog { - doesKnowCommand(dogCommand: SIT) @include(if: true) - doesKnowCommand(dogCommand: SIT) @skip(if: false) - } - ''', [ - fields_conflict('doesKnowCommand', 'they have differing directives', L(3, 9), L(4, 9)) - ], sort_list=False) - def test_encounters_conflict_in_fragments(): expect_fails_rule(OverlappingFieldsCanBeMerged, ''' From 5dc00fb3205ec98d149ac642d22c653280f9e9c1 Mon Sep 17 00:00:00 2001 From: Syrus Akbary Date: Mon, 11 Apr 2016 22:50:59 -0700 Subject: [PATCH 05/16] [Validation] Perf improvements for fragment cycle detection Related GraphQL-js commit https://github.com/graphql/graphql-js/commit/4cf2190b54fefc34a7b37d22a489933fc15e14ce --- .../validation/rules/no_fragment_cycles.py | 109 +++++++++--------- .../test_no_fragment_cycles.py | 32 ++++- 2 files changed, 82 insertions(+), 59 deletions(-) diff --git a/graphql/core/validation/rules/no_fragment_cycles.py b/graphql/core/validation/rules/no_fragment_cycles.py index 9d7b3900..b17a9a52 100644 --- a/graphql/core/validation/rules/no_fragment_cycles.py +++ b/graphql/core/validation/rules/no_fragment_cycles.py @@ -5,51 +5,60 @@ class NoFragmentCycles(ValidationRule): - __slots__ = 'spreads_in_fragment', 'known_to_lead_to_cycle' + __slots__ = 'errors', 'visited_frags', 'spread_path', 'spread_path_index_by_name' def __init__(self, context): super(NoFragmentCycles, self).__init__(context) - self.spreads_in_fragment = { - node.name.value: self.gather_spreads(node) - for node in context.get_ast().definitions - if isinstance(node, ast.FragmentDefinition) - } - self.known_to_lead_to_cycle = set() + self.errors = [] + self.visited_frags = set() + self.spread_path = [] + self.spread_path_index_by_name = {} - def enter_FragmentDefinition(self, node, key, parent, path, ancestors): - errors = [] - initial_name = node.name.value - spread_path = [] - - def detect_cycle_recursive(fragment_name): - spread_nodes = self.spreads_in_fragment.get(fragment_name) - if not spread_nodes: - return - - for spread_node in spread_nodes: - if spread_node in self.known_to_lead_to_cycle: - continue - - if spread_node.name.value == initial_name: - cycle_path = spread_path + [spread_node] - self.known_to_lead_to_cycle |= set(cycle_path) - - errors.append(GraphQLError( - self.cycle_error_message(initial_name, [s.name.value for s in spread_path]), - cycle_path - )) - continue - - if any(spread is spread_node for spread in spread_path): - continue + def leave_Document(self, node, key, parent, path, ancestors): + if self.errors: + return self.errors - spread_path.append(spread_node) - detect_cycle_recursive(spread_node.name.value) - spread_path.pop() + def enter_OperationDefinition(self, node, key, parent, path, ancestors): + return False - detect_cycle_recursive(initial_name) - if errors: - return errors + def enter_FragmentDefinition(self, node, key, parent, path, ancestors): + if node.name.value not in self.visited_frags: + self.detect_cycle_recursive(node) + return False + + def detect_cycle_recursive(self, fragment): + fragment_name = fragment.name.value + self.visited_frags.add(fragment_name) + + spread_nodes = [] + self.gather_spreads(spread_nodes, fragment.selection_set) + if not spread_nodes: + return + + self.spread_path_index_by_name[fragment_name] = len(self.spread_path) + + for spread_node in spread_nodes: + spread_name = spread_node.name.value + cycle_index = self.spread_path_index_by_name.get(spread_name) + + if cycle_index is None: + self.spread_path.append(spread_node) + if spread_name not in self.visited_frags: + spread_fragment = self.context.get_fragment(spread_name) + if spread_fragment: + self.detect_cycle_recursive(spread_fragment) + self.spread_path.pop() + else: + cycle_path = self.spread_path[cycle_index:] + self.errors.append(GraphQLError( + self.cycle_error_message( + spread_name, + [s.name.value for s in cycle_path] + ), + cycle_path+[spread_node] + )) + + self.spread_path_index_by_name[fragment_name] = None @staticmethod def cycle_error_message(fragment_name, spread_names): @@ -57,19 +66,9 @@ def cycle_error_message(fragment_name, spread_names): return 'Cannot spread fragment "{}" within itself{}.'.format(fragment_name, via) @classmethod - def gather_spreads(cls, node): - visitor = cls.CollectFragmentSpreadNodesVisitor() - visit(node, visitor) - return visitor.collect_fragment_spread_nodes() - - class CollectFragmentSpreadNodesVisitor(Visitor): - __slots__ = 'spread_nodes', - - def __init__(self): - self.spread_nodes = [] - - def enter_FragmentSpread(self, node, key, parent, path, ancestors): - self.spread_nodes.append(node) - - def collect_fragment_spread_nodes(self): - return self.spread_nodes + def gather_spreads(cls, spreads, node): + for selection in node.selections: + if isinstance(selection, ast.FragmentSpread): + spreads.append(selection) + elif selection.selection_set: + cls.gather_spreads(spreads, selection.selection_set) diff --git a/tests/core_validation/test_no_fragment_cycles.py b/tests/core_validation/test_no_fragment_cycles.py index dcfa59ff..d7269060 100644 --- a/tests/core_validation/test_no_fragment_cycles.py +++ b/tests/core_validation/test_no_fragment_cycles.py @@ -124,14 +124,15 @@ def test_no_spreading_itself_deeply(): fragment fragX on Dog { ...fragY } fragment fragY on Dog { ...fragZ } fragment fragZ on Dog { ...fragO } - fragment fragO on Dog { ...fragA, ...fragX } + fragment fragO on Dog { ...fragP } + fragment fragP on Dog { ...fragA, ...fragX } ''', [ - cycle_error_message('fragA', ['fragB', 'fragC', 'fragO'], L(2, 29), L(3, 29), L(4, 29), L(8, 29)), - cycle_error_message('fragX', ['fragY', 'fragZ', 'fragO'], L(5, 29), L(6, 29), L(7, 29), L(8, 39)) + cycle_error_message('fragA', ['fragB', 'fragC', 'fragO', 'fragP'], L(2, 29), L(3, 29), L(4, 29), L(8, 29), L(9, 29)), + cycle_error_message('fragO', ['fragP', 'fragX', 'fragY', 'fragZ'], L(8, 29), L(9, 39), L(5, 29), L(6, 29), L(7, 29)) ]) -def test_no_spreading_itself_deeply_two_paths(): # -- new rule +def test_no_spreading_itself_deeply_two_paths(): expect_fails_rule(NoFragmentCycles, ''' fragment fragA on Dog { ...fragB, ...fragC } fragment fragB on Dog { ...fragA } @@ -140,3 +141,26 @@ def test_no_spreading_itself_deeply_two_paths(): # -- new rule cycle_error_message('fragA', ['fragB'], L(2, 29), L(3, 29)), cycle_error_message('fragA', ['fragC'], L(2, 39), L(4, 29)) ]) + + +def test_no_spreading_itself_deeply_two_paths_alt_reverse_order(): + expect_fails_rule(NoFragmentCycles, ''' + fragment fragA on Dog { ...fragC } + fragment fragB on Dog { ...fragC } + fragment fragC on Dog { ...fragA, ...fragB } + ''', [ + cycle_error_message('fragA', ['fragC'], L(2, 29), L(4, 29)), + cycle_error_message('fragC', ['fragB'], L(4, 39), L(3, 29)) + ]) + + +def test_no_spreading_itself_deeply_and_immediately(): + expect_fails_rule(NoFragmentCycles, ''' + fragment fragA on Dog { ...fragB } + fragment fragB on Dog { ...fragB, ...fragC } + fragment fragC on Dog { ...fragA, ...fragB } + ''', [ + cycle_error_message('fragB', [], L(3, 29)), + cycle_error_message('fragA', ['fragB', 'fragC'], L(2, 29), L(3, 39), L(4, 29)), + cycle_error_message('fragB', ['fragC'], L(3, 39), L(4, 39)) + ]) From 7c3e769c1fe95553cd661a4bc8ea755ffd44fb1a Mon Sep 17 00:00:00 2001 From: Syrus Akbary Date: Mon, 11 Apr 2016 23:10:20 -0700 Subject: [PATCH 06/16] [Validation] Performance improvements Related GraphQL-js commit https://github.com/graphql/graphql-js/commit/0bc9088187b9902ab19c0ec34e0e9f036dc9d9ea --- graphql/core/validation/context.py | 21 +++++++++++++-- .../rules/arguments_of_correct_type.py | 1 + .../rules/default_values_of_correct_type.py | 7 +++++ .../validation/rules/no_fragment_cycles.py | 11 +------- .../validation/rules/no_unused_fragments.py | 26 ++++++++----------- .../validation/rules/unique_argument_names.py | 1 + .../validation/rules/unique_fragment_names.py | 4 +++ .../rules/unique_input_field_names.py | 1 + .../rules/unique_operation_names.py | 4 +++ 9 files changed, 49 insertions(+), 27 deletions(-) diff --git a/graphql/core/validation/context.py b/graphql/core/validation/context.py index c20fa52d..ec881af7 100644 --- a/graphql/core/validation/context.py +++ b/graphql/core/validation/context.py @@ -1,18 +1,27 @@ -from ..language.ast import FragmentDefinition +from ..language.ast import FragmentDefinition, FragmentSpread class ValidationContext(object): - __slots__ = '_schema', '_ast', '_type_info', '_fragments' + __slots__ = '_schema', '_ast', '_type_info', '_fragments', '_fragment_spreads' def __init__(self, schema, ast, type_info): self._schema = schema self._ast = ast self._type_info = type_info self._fragments = None + self._fragment_spreads = {} def get_schema(self): return self._schema + def get_fragment_spreads(self, node): + spreads = self._fragment_spreads.get(node) + if not spreads: + spreads = [] + self.gather_spreads(spreads, node.selection_set) + self._fragment_spreads[node] = spreads + return spreads + def get_ast(self): return self._ast @@ -42,3 +51,11 @@ def get_directive(self): def get_argument(self): return self._type_info.get_argument() + + @classmethod + def gather_spreads(cls, spreads, node): + for selection in node.selections: + if isinstance(selection, FragmentSpread): + spreads.append(selection) + elif selection.selection_set: + cls.gather_spreads(spreads, selection.selection_set) diff --git a/graphql/core/validation/rules/arguments_of_correct_type.py b/graphql/core/validation/rules/arguments_of_correct_type.py index b6b50403..dbfd1aaf 100644 --- a/graphql/core/validation/rules/arguments_of_correct_type.py +++ b/graphql/core/validation/rules/arguments_of_correct_type.py @@ -15,6 +15,7 @@ def enter_Argument(self, node, key, parent, path, ancestors): print_ast(node.value), errors), [node.value] ) + return False @staticmethod def bad_value_message(arg_name, type, value, verbose_errors): diff --git a/graphql/core/validation/rules/default_values_of_correct_type.py b/graphql/core/validation/rules/default_values_of_correct_type.py index b401afc2..15764423 100644 --- a/graphql/core/validation/rules/default_values_of_correct_type.py +++ b/graphql/core/validation/rules/default_values_of_correct_type.py @@ -24,6 +24,13 @@ def enter_VariableDefinition(self, node, key, parent, path, ancestors): self.bad_value_for_default_arg_message(name, type, print_ast(default_value), errors), [default_value] ) + return False + + def enter_SelectionSet(self, node, key, parent, path, ancestors): + return False + + def enter_FragmentDefinition(self, node, key, parent, path, ancestors): + return False @staticmethod def default_for_non_null_arg_message(var_name, type, guess_type): diff --git a/graphql/core/validation/rules/no_fragment_cycles.py b/graphql/core/validation/rules/no_fragment_cycles.py index b17a9a52..069f0556 100644 --- a/graphql/core/validation/rules/no_fragment_cycles.py +++ b/graphql/core/validation/rules/no_fragment_cycles.py @@ -30,8 +30,7 @@ def detect_cycle_recursive(self, fragment): fragment_name = fragment.name.value self.visited_frags.add(fragment_name) - spread_nodes = [] - self.gather_spreads(spread_nodes, fragment.selection_set) + spread_nodes = self.context.get_fragment_spreads(fragment) if not spread_nodes: return @@ -64,11 +63,3 @@ def detect_cycle_recursive(self, fragment): def cycle_error_message(fragment_name, spread_names): via = ' via {}'.format(', '.join(spread_names)) if spread_names else '' return 'Cannot spread fragment "{}" within itself{}.'.format(fragment_name, via) - - @classmethod - def gather_spreads(cls, spreads, node): - for selection in node.selections: - if isinstance(selection, ast.FragmentSpread): - spreads.append(selection) - elif selection.selection_set: - cls.gather_spreads(spreads, selection.selection_set) diff --git a/graphql/core/validation/rules/no_unused_fragments.py b/graphql/core/validation/rules/no_unused_fragments.py index f3050efb..7ef7c861 100644 --- a/graphql/core/validation/rules/no_unused_fragments.py +++ b/graphql/core/validation/rules/no_unused_fragments.py @@ -7,34 +7,30 @@ class NoUnusedFragments(ValidationRule): def __init__(self, context): super(NoUnusedFragments, self).__init__(context) - self.fragment_definitions = [] self.spreads_within_operation = [] - self.fragment_adjacencies = {} - self.spread_names = set() + self.fragment_definitions = [] def enter_OperationDefinition(self, node, key, parent, path, ancestors): - self.spread_names = set() - self.spreads_within_operation.append(self.spread_names) + self.spreads_within_operation.append(self.context.get_fragment_spreads(node)) + return False def enter_FragmentDefinition(self, node, key, parent, path, ancestors): self.fragment_definitions.append(node) - self.spread_names = set() - self.fragment_adjacencies[node.name.value] = self.spread_names - - def enter_FragmentSpread(self, node, key, parent, path, ancestors): - self.spread_names.add(node.name.value) + return False def leave_Document(self, node, key, parent, path, ancestors): fragment_names_used = set() def reduce_spread_fragments(spreads): - for fragment_name in spreads: - if fragment_name in fragment_names_used: + for spread in spreads: + frag_name = spread.name.value + if frag_name in fragment_names_used: continue - fragment_names_used.add(fragment_name) - if fragment_name in self.fragment_adjacencies: - reduce_spread_fragments(self.fragment_adjacencies[fragment_name]) + fragment_names_used.add(frag_name) + fragment = self.context.get_fragment(frag_name) + if fragment: + reduce_spread_fragments(self.context.get_fragment_spreads(fragment)) for spreads in self.spreads_within_operation: reduce_spread_fragments(spreads) diff --git a/graphql/core/validation/rules/unique_argument_names.py b/graphql/core/validation/rules/unique_argument_names.py index 71c0101d..90c79ebf 100644 --- a/graphql/core/validation/rules/unique_argument_names.py +++ b/graphql/core/validation/rules/unique_argument_names.py @@ -25,6 +25,7 @@ def enter_Argument(self, node, key, parent, path, ancestors): ) self.known_arg_names[arg_name] = node.name + return False @staticmethod def duplicate_arg_message(field): diff --git a/graphql/core/validation/rules/unique_fragment_names.py b/graphql/core/validation/rules/unique_fragment_names.py index 3b2ecea3..d36e9254 100644 --- a/graphql/core/validation/rules/unique_fragment_names.py +++ b/graphql/core/validation/rules/unique_fragment_names.py @@ -9,6 +9,9 @@ def __init__(self, context): super(UniqueFragmentNames, self).__init__(context) self.known_fragment_names = {} + def enter_OperationDefinition(self, node, key, parent, path, ancestors): + return False + def enter_FragmentDefinition(self, node, key, parent, path, ancestors): fragment_name = node.name.value if fragment_name in self.known_fragment_names: @@ -18,6 +21,7 @@ def enter_FragmentDefinition(self, node, key, parent, path, ancestors): ) self.known_fragment_names[fragment_name] = node.name + return False @staticmethod def duplicate_fragment_name_message(field): diff --git a/graphql/core/validation/rules/unique_input_field_names.py b/graphql/core/validation/rules/unique_input_field_names.py index 594c70a3..534fe688 100644 --- a/graphql/core/validation/rules/unique_input_field_names.py +++ b/graphql/core/validation/rules/unique_input_field_names.py @@ -26,6 +26,7 @@ def enter_ObjectField(self, node, key, parent, path, ancestors): ) self.known_names[field_name] = node.name + return False @staticmethod def duplicate_input_field_message(field_name): diff --git a/graphql/core/validation/rules/unique_operation_names.py b/graphql/core/validation/rules/unique_operation_names.py index 52565d2d..07893395 100644 --- a/graphql/core/validation/rules/unique_operation_names.py +++ b/graphql/core/validation/rules/unique_operation_names.py @@ -21,6 +21,10 @@ def enter_OperationDefinition(self, node, key, parent, path, ancestors): ) self.known_operation_names[operation_name.value] = operation_name + return False + + def enter_FragmentDefinition(self, node, key, parent, path, ancestors): + return False @staticmethod def duplicate_operation_name_message(operation_name): From 665150725232d913c2cf7332198c2987bc9ba237 Mon Sep 17 00:00:00 2001 From: Syrus Akbary Date: Mon, 11 Apr 2016 23:25:12 -0700 Subject: [PATCH 07/16] [Validation] Factor out and memoize recursively referenced fragments. Related GraphQL commit: https://github.com/graphql/graphql-js/commit/eef8d97f64b5b5fa0df79435c4fe237976867573 --- graphql/core/validation/context.py | 23 ++++++++++++++++++- .../validation/rules/no_unused_fragments.py | 23 ++++++------------- 2 files changed, 29 insertions(+), 17 deletions(-) diff --git a/graphql/core/validation/context.py b/graphql/core/validation/context.py index ec881af7..11ac7cf5 100644 --- a/graphql/core/validation/context.py +++ b/graphql/core/validation/context.py @@ -2,7 +2,7 @@ class ValidationContext(object): - __slots__ = '_schema', '_ast', '_type_info', '_fragments', '_fragment_spreads' + __slots__ = '_schema', '_ast', '_type_info', '_fragments', '_fragment_spreads', '_recursively_referenced_fragments' def __init__(self, schema, ast, type_info): self._schema = schema @@ -10,10 +10,31 @@ def __init__(self, schema, ast, type_info): self._type_info = type_info self._fragments = None self._fragment_spreads = {} + self._recursively_referenced_fragments = {} def get_schema(self): return self._schema + def get_recursively_referenced_fragments(self, operation): + fragments = self._recursively_referenced_fragments.get(operation) + if not fragments: + fragments = [] + collected_names = set() + nodes_to_visit = [operation] + while nodes_to_visit: + node = nodes_to_visit.pop() + spreads = self.get_fragment_spreads(node) + for spread in spreads: + frag_name = spread.name.value + if frag_name not in collected_names: + collected_names.add(frag_name) + fragment = self.get_fragment(frag_name) + if fragment: + fragments.append(fragment) + nodes_to_visit.append(fragment) + self._recursively_referenced_fragments[operation] = fragments + return fragments + def get_fragment_spreads(self, node): spreads = self._fragment_spreads.get(node) if not spreads: diff --git a/graphql/core/validation/rules/no_unused_fragments.py b/graphql/core/validation/rules/no_unused_fragments.py index 7ef7c861..3b3c5041 100644 --- a/graphql/core/validation/rules/no_unused_fragments.py +++ b/graphql/core/validation/rules/no_unused_fragments.py @@ -3,15 +3,15 @@ class NoUnusedFragments(ValidationRule): - __slots__ = 'fragment_definitions', 'spreads_within_operation', 'fragment_adjacencies', 'spread_names' + __slots__ = 'fragment_definitions', 'operation_definitions', 'fragment_adjacencies', 'spread_names' def __init__(self, context): super(NoUnusedFragments, self).__init__(context) - self.spreads_within_operation = [] + self.operation_definitions = [] self.fragment_definitions = [] def enter_OperationDefinition(self, node, key, parent, path, ancestors): - self.spreads_within_operation.append(self.context.get_fragment_spreads(node)) + self.operation_definitions.append(node) return False def enter_FragmentDefinition(self, node, key, parent, path, ancestors): @@ -21,19 +21,10 @@ def enter_FragmentDefinition(self, node, key, parent, path, ancestors): def leave_Document(self, node, key, parent, path, ancestors): fragment_names_used = set() - def reduce_spread_fragments(spreads): - for spread in spreads: - frag_name = spread.name.value - if frag_name in fragment_names_used: - continue - - fragment_names_used.add(frag_name) - fragment = self.context.get_fragment(frag_name) - if fragment: - reduce_spread_fragments(self.context.get_fragment_spreads(fragment)) - - for spreads in self.spreads_within_operation: - reduce_spread_fragments(spreads) + for operation in self.operation_definitions: + fragments = self.context.get_recursively_referenced_fragments(operation) + for fragment in fragments: + fragment_names_used.add(fragment.name.value) errors = [ GraphQLError( From ccf167cf50c6b490354f74041ecd62e875699fef Mon Sep 17 00:00:00 2001 From: Syrus Akbary Date: Tue, 12 Apr 2016 22:38:57 -0700 Subject: [PATCH 08/16] [Validation] Memoize collecting variable usage. Related GraphQL-js commit: https://github.com/graphql/graphql-js/commit/2afbff79bfd2b89f03ca7913577556b73980f974 --- graphql/core/validation/context.py | 58 ++++++++++++++++++- .../rules/no_undefined_variables.py | 57 +++++++----------- .../validation/rules/no_unused_variables.py | 36 ++++-------- .../rules/variables_in_allowed_position.py | 39 +++++++------ .../test_no_undefined_variables.py | 52 +++++++---------- 5 files changed, 131 insertions(+), 111 deletions(-) diff --git a/graphql/core/validation/context.py b/graphql/core/validation/context.py index 11ac7cf5..607a8dfd 100644 --- a/graphql/core/validation/context.py +++ b/graphql/core/validation/context.py @@ -1,8 +1,37 @@ -from ..language.ast import FragmentDefinition, FragmentSpread +from ..language.ast import FragmentDefinition, FragmentSpread, VariableDefinition, Variable, OperationDefinition +from ..utils.type_info import TypeInfo +from ..language.visitor import Visitor, visit + + +class VariableUsage(object): + __slots__ = 'node', 'type' + + def __init__(self, node, type): + self.node = node + self.type = type + + +class UsageVisitor(Visitor): + __slots__ = 'context', 'usages', 'type_info' + + def __init__(self, usages, type_info): + self.usages = usages + self.type_info = type_info + + def enter(self, node, key, parent, path, ancestors): + self.type_info.enter(node) + if isinstance(node, VariableDefinition): + return False + elif isinstance(node, Variable): + usage = VariableUsage(node, type=self.type_info.get_input_type()) + self.usages.append(usage) + + def leave(self, node, key, parent, path, ancestors): + self.type_info.leave(node) class ValidationContext(object): - __slots__ = '_schema', '_ast', '_type_info', '_fragments', '_fragment_spreads', '_recursively_referenced_fragments' + __slots__ = '_schema', '_ast', '_type_info', '_fragments', '_fragment_spreads', '_recursively_referenced_fragments', '_variable_usages', '_recursive_variable_usages' def __init__(self, schema, ast, type_info): self._schema = schema @@ -11,11 +40,36 @@ def __init__(self, schema, ast, type_info): self._fragments = None self._fragment_spreads = {} self._recursively_referenced_fragments = {} + self._variable_usages = {} + self._recursive_variable_usages = {} def get_schema(self): return self._schema + def get_variable_usages(self, node): + usages = self._variable_usages.get(node) + if usages is None: + usages = [] + sub_visitor = UsageVisitor(usages, self._type_info) + visit(node, sub_visitor) + self._variable_usages[node] = usages + + return usages + + def get_recursive_variable_usages(self, operation): + assert isinstance(operation, OperationDefinition) + usages = self._recursive_variable_usages.get(operation) + if usages is None: + usages = self.get_variable_usages(operation) + fragments = self.get_recursively_referenced_fragments(operation) + for fragment in fragments: + usages.extend(self.get_variable_usages(fragment)) + self._recursive_variable_usages[operation] = usages + + return usages + def get_recursively_referenced_fragments(self, operation): + assert isinstance(operation, OperationDefinition) fragments = self._recursively_referenced_fragments.get(operation) if not fragments: fragments = [] diff --git a/graphql/core/validation/rules/no_undefined_variables.py b/graphql/core/validation/rules/no_undefined_variables.py index 6882bce0..bde1e572 100644 --- a/graphql/core/validation/rules/no_undefined_variables.py +++ b/graphql/core/validation/rules/no_undefined_variables.py @@ -4,51 +4,38 @@ class NoUndefinedVariables(ValidationRule): - __slots__ = 'visited_fragment_names', 'defined_variable_names', 'operation', - visit_spread_fragments = True + __slots__ = 'defined_variable_names', def __init__(self, context): - self.visited_fragment_names = set() self.defined_variable_names = set() - self.operation = None - super(NoUndefinedVariables, self).__init__(context) @staticmethod - def undefined_var_message(var_name): + def undefined_var_message(var_name, op_name=None): + if op_name: + return 'Variable "${}" is not defined by operation "{}".'.format( + var_name, op_name + ) return 'Variable "${}" is not defined.'.format(var_name) - @staticmethod - def undefined_var_by_op_message(var_name, op_name): - return 'Variable "${}" is not defined by operation "{}".'.format( - var_name, op_name - ) - - def enter_OperationDefinition(self, node, key, parent, path, ancestors): - self.operation = node - self.visited_fragment_names = set() + def enter_OperationDefinition(self, operation, key, parent, path, ancestors): self.defined_variable_names = set() - def enter_VariableDefinition(self, node, key, parent, path, ancestors): - self.defined_variable_names.add(node.variable.name.value) + def leave_OperationDefinition(self, operation, key, parent, path, ancestors): + usages = self.context.get_recursive_variable_usages(operation) + errors = [] - def enter_Variable(self, variable, key, parent, path, ancestors): - var_name = variable.name.value - if var_name not in self.defined_variable_names: - within_fragment = any(isinstance(node, ast.FragmentDefinition) for node in ancestors) - if within_fragment and self.operation and self.operation.name: - return GraphQLError( - self.undefined_var_by_op_message(var_name, self.operation.name.value), - [variable, self.operation] - ) - - return GraphQLError( - self.undefined_var_message(var_name), - [variable] - ) + for variable_usage in usages: + node = variable_usage.node + var_name = node.name.value + if var_name not in self.defined_variable_names: + errors.append(GraphQLError( + self.undefined_var_message(var_name, operation.name and operation.name.value), + [node, operation] + )) - def enter_FragmentSpread(self, spread_ast, *args): - if spread_ast.name.value in self.visited_fragment_names: - return False + if errors: + return errors - self.visited_fragment_names.add(spread_ast.name.value) + def enter_VariableDefinition(self, node, key, parent, path, ancestors): + self.defined_variable_names.add(node.variable.name.value) diff --git a/graphql/core/validation/rules/no_unused_variables.py b/graphql/core/validation/rules/no_unused_variables.py index 1c87bf72..c529837d 100644 --- a/graphql/core/validation/rules/no_unused_variables.py +++ b/graphql/core/validation/rules/no_unused_variables.py @@ -3,50 +3,36 @@ class NoUnusedVariables(ValidationRule): - __slots__ = 'visited_fragment_names', 'variable_definitions', 'variable_name_used', 'visit_spread_fragments' + __slots__ = 'variable_definitions' def __init__(self, context): - self.visited_fragment_names = None - self.variable_definitions = None - self.variable_name_used = None - self.visit_spread_fragments = True + self.variable_definitions = [] super(NoUnusedVariables, self).__init__(context) def enter_OperationDefinition(self, node, key, parent, path, ancestors): - self.visited_fragment_names = set() self.variable_definitions = [] - self.variable_name_used = set() - def leave_OperationDefinition(self, node, key, parent, path, ancestors): + def leave_OperationDefinition(self, operation, key, parent, path, ancestors): + variable_name_used = set() + usages = self.context.get_recursive_variable_usages(operation) + + for variable_usage in usages: + variable_name_used.add(variable_usage.node.name.value) + errors = [ GraphQLError( self.unused_variable_message(variable_definition.variable.name.value), [variable_definition] ) for variable_definition in self.variable_definitions - if variable_definition.variable.name.value not in self.variable_name_used + if variable_definition.variable.name.value not in variable_name_used ] if errors: return errors def enter_VariableDefinition(self, node, key, parent, path, ancestors): - if self.variable_definitions is not None: - self.variable_definitions.append(node) - - return False - - def enter_Variable(self, node, key, parent, path, ancestors): - if self.variable_name_used is not None: - self.variable_name_used.add(node.name.value) - - def enter_FragmentSpread(self, node, key, parent, path, ancestors): - if self.visited_fragment_names is not None: - spread_name = node.name.value - if spread_name in self.visited_fragment_names: - return False - - self.visited_fragment_names.add(spread_name) + self.variable_definitions.append(node) @staticmethod def unused_variable_message(variable_name): diff --git a/graphql/core/validation/rules/variables_in_allowed_position.py b/graphql/core/validation/rules/variables_in_allowed_position.py index fb537ccc..b4ecd884 100644 --- a/graphql/core/validation/rules/variables_in_allowed_position.py +++ b/graphql/core/validation/rules/variables_in_allowed_position.py @@ -8,36 +8,37 @@ class VariablesInAllowedPosition(ValidationRule): - visit_spread_fragments = True - __slots__ = 'var_def_map', 'visited_fragment_names' + __slots__ = 'var_def_map' def __init__(self, context): super(VariablesInAllowedPosition, self).__init__(context) self.var_def_map = {} - self.visited_fragment_names = set() def enter_OperationDefinition(self, node, key, parent, path, ancestors): self.var_def_map = {} - self.visited_fragment_names = set() + + def leave_OperationDefinition(self, operation, key, parent, path, ancestors): + usages = self.context.get_recursive_variable_usages(operation) + errors = [] + + for usage in usages: + node = usage.node + type = usage.type + var_name = node.name.value + var_def = self.var_def_map.get(var_name) + var_type = var_def and type_from_ast(self.context.get_schema(), var_def.type) + if var_type and type and not self.var_type_allowed_for_type(self.effective_type(var_type, var_def), type): + errors.append(GraphQLError( + self.bad_var_pos_message(var_name, var_type, type), + [node] + )) + + if errors: + return errors def enter_VariableDefinition(self, node, key, parent, path, ancestors): self.var_def_map[node.variable.name.value] = node - def enter_Variable(self, node, key, parent, path, ancestors): - var_name = node.name.value - var_def = self.var_def_map.get(var_name) - var_type = var_def and type_from_ast(self.context.get_schema(), var_def.type) - input_type = self.context.get_input_type() - if var_type and input_type and not self.var_type_allowed_for_type(self.effective_type(var_type, var_def), - input_type): - return GraphQLError(self.bad_var_pos_message(var_name, var_type, input_type), - [node]) - - def enter_FragmentSpread(self, node, key, parent, path, ancestors): - if node.name.value in self.visited_fragment_names: - return False - self.visited_fragment_names.add(node.name.value) - @staticmethod def effective_type(var_type, var_def): if not var_def.default_value or isinstance(var_type, GraphQLNonNull): diff --git a/tests/core_validation/test_no_undefined_variables.py b/tests/core_validation/test_no_undefined_variables.py index 4c9227ca..39392b4b 100644 --- a/tests/core_validation/test_no_undefined_variables.py +++ b/tests/core_validation/test_no_undefined_variables.py @@ -3,17 +3,9 @@ from utils import expect_passes_rule, expect_fails_rule -def undefined_var(var_name, line, column): +def undefined_var(var_name, l1, c1, op_name, l2, c2): return { - 'message': NoUndefinedVariables.undefined_var_message(var_name), - 'locations': [SourceLocation(line, column)] - } - - -def undefined_var_by_op(var_name, l1, c1, op_name, l2, c2): - return { - 'message': NoUndefinedVariables.undefined_var_by_op_message( - var_name, op_name), + 'message': NoUndefinedVariables.undefined_var_message(var_name, op_name), 'locations': [ SourceLocation(l1, c1), SourceLocation(l2, c2), @@ -128,7 +120,7 @@ def test_variable_not_defined(): field(a: $a, b: $b, c: $c, d: $d) } ''', [ - undefined_var('d', 3, 39) + undefined_var('d', 3, 39, 'Foo', 2, 7) ]) @@ -138,7 +130,7 @@ def variable_not_defined_by_unnamed_query(): field(a: $a) } ''', [ - undefined_var('a', 3, 18) + undefined_var('a', 3, 18, '', 2, 7) ]) @@ -148,8 +140,8 @@ def test_multiple_variables_not_defined(): field(a: $a, b: $b, c: $c) } ''', [ - undefined_var('a', 3, 18), - undefined_var('c', 3, 32) + undefined_var('a', 3, 18, 'Foo', 2, 7), + undefined_var('c', 3, 32, 'Foo', 2, 7) ]) @@ -162,7 +154,7 @@ def test_variable_in_fragment_not_defined_by_unnamed_query(): field(a: $a) } ''', [ - undefined_var('a', 6, 18) + undefined_var('a', 6, 18, '', 2, 7) ]) @@ -185,7 +177,7 @@ def test_variable_in_fragment_not_defined_by_operation(): field(c: $c) } ''', [ - undefined_var_by_op('c', 16, 18, 'Foo', 2, 7) + undefined_var('c', 16, 18, 'Foo', 2, 7) ]) @@ -208,8 +200,8 @@ def test_multiple_variables_in_fragments_not_defined(): field(c: $c) } ''', [ - undefined_var_by_op('a', 6, 18, 'Foo', 2, 7), - undefined_var_by_op('c', 16, 18, 'Foo', 2, 7) + undefined_var('a', 6, 18, 'Foo', 2, 7), + undefined_var('c', 16, 18, 'Foo', 2, 7) ]) @@ -225,8 +217,8 @@ def test_single_variable_in_fragment_not_defined_by_multiple_operations(): field(a: $a, b: $b) } ''', [ - undefined_var_by_op('b', 9, 25, 'Foo', 2, 7), - undefined_var_by_op('b', 9, 25, 'Bar', 5, 7) + undefined_var('b', 9, 25, 'Foo', 2, 7), + undefined_var('b', 9, 25, 'Bar', 5, 7) ]) @@ -242,8 +234,8 @@ def test_variables_in_fragment_not_defined_by_multiple_operations(): field(a: $a, b: $b) } ''', [ - undefined_var_by_op('a', 9, 18, 'Foo', 2, 7), - undefined_var_by_op('b', 9, 25, 'Bar', 5, 7) + undefined_var('a', 9, 18, 'Foo', 2, 7), + undefined_var('b', 9, 25, 'Bar', 5, 7) ]) @@ -262,8 +254,8 @@ def test_variable_in_fragment_used_by_other_operation(): field(b: $b) } ''', [ - undefined_var_by_op('a', 9, 18, 'Foo', 2, 7), - undefined_var_by_op('b', 12, 18, 'Bar', 5, 7) + undefined_var('a', 9, 18, 'Foo', 2, 7), + undefined_var('b', 12, 18, 'Bar', 5, 7) ]) @@ -284,10 +276,10 @@ def test_multiple_undefined_variables_produce_multiple_errors(): field2(c: $c) } ''', [ - undefined_var_by_op('a', 9, 19, 'Foo', 2, 7), - undefined_var_by_op('c', 14, 19, 'Foo', 2, 7), - undefined_var_by_op('a', 11, 19, 'Foo', 2, 7), - undefined_var_by_op('b', 9, 26, 'Bar', 5, 7), - undefined_var_by_op('c', 14, 19, 'Bar', 5, 7), - undefined_var_by_op('b', 11, 26, 'Bar', 5, 7), + undefined_var('a', 9, 19, 'Foo', 2, 7), + undefined_var('a', 11, 19, 'Foo', 2, 7), + undefined_var('c', 14, 19, 'Foo', 2, 7), + undefined_var('b', 9, 26, 'Bar', 5, 7), + undefined_var('b', 11, 26, 'Bar', 5, 7), + undefined_var('c', 14, 19, 'Bar', 5, 7), ]) From 891d4a66c3281fb985155d9cfef0698342cd1364 Mon Sep 17 00:00:00 2001 From: Syrus Akbary Date: Tue, 12 Apr 2016 23:05:32 -0700 Subject: [PATCH 09/16] [Validation] Report errors rather than return them Related GraphQL-js commit https://github.com/graphql/graphql-js/commit/5e545cce0104708a4ac6e994dd5f837d1d30a09b --- graphql/core/validation/__init__.py | 2 +- graphql/core/validation/context.py | 9 ++++++++- .../rules/arguments_of_correct_type.py | 4 ++-- .../rules/default_values_of_correct_type.py | 8 ++++---- .../rules/fields_on_correct_type.py | 4 ++-- .../rules/fragments_on_composite_types.py | 8 ++++---- .../validation/rules/known_argument_names.py | 8 ++++---- .../core/validation/rules/known_directives.py | 20 +++++++++---------- .../validation/rules/known_fragment_names.py | 4 ++-- .../core/validation/rules/known_type_names.py | 2 +- .../rules/lone_anonymous_operation.py | 2 +- .../validation/rules/no_fragment_cycles.py | 6 +----- .../rules/no_undefined_variables.py | 6 +----- .../validation/rules/no_unused_fragments.py | 17 ++++++---------- .../validation/rules/no_unused_variables.py | 17 ++++++---------- .../rules/overlapping_fields_can_be_merged.py | 6 ++---- .../rules/possible_fragment_spreads.py | 8 ++++---- .../rules/provided_non_null_arguments.py | 12 ++--------- graphql/core/validation/rules/scalar_leafs.py | 8 ++++---- .../validation/rules/unique_argument_names.py | 8 ++++---- .../validation/rules/unique_fragment_names.py | 8 ++++---- .../rules/unique_input_field_names.py | 8 ++++---- .../rules/unique_operation_names.py | 8 ++++---- .../rules/variables_are_input_types.py | 4 ++-- .../rules/variables_in_allowed_position.py | 6 +----- .../test_fields_on_correct_type.py | 17 +++++++++++++++- tests/core_validation/test_validation.py | 6 ++++-- 27 files changed, 104 insertions(+), 112 deletions(-) diff --git a/graphql/core/validation/__init__.py b/graphql/core/validation/__init__.py index 183c2925..bac1016d 100644 --- a/graphql/core/validation/__init__.py +++ b/graphql/core/validation/__init__.py @@ -19,4 +19,4 @@ def visit_using_rules(schema, type_info, ast, rules): errors = [] rules = [rule(context) for rule in rules] visit(ast, ValidationVisitor(rules, context, type_info, errors)) - return errors + return context.get_errors() diff --git a/graphql/core/validation/context.py b/graphql/core/validation/context.py index 607a8dfd..20021cc7 100644 --- a/graphql/core/validation/context.py +++ b/graphql/core/validation/context.py @@ -31,18 +31,25 @@ def leave(self, node, key, parent, path, ancestors): class ValidationContext(object): - __slots__ = '_schema', '_ast', '_type_info', '_fragments', '_fragment_spreads', '_recursively_referenced_fragments', '_variable_usages', '_recursive_variable_usages' + __slots__ = '_schema', '_ast', '_type_info', '_errors', '_fragments', '_fragment_spreads', '_recursively_referenced_fragments', '_variable_usages', '_recursive_variable_usages' def __init__(self, schema, ast, type_info): self._schema = schema self._ast = ast self._type_info = type_info + self._errors = [] self._fragments = None self._fragment_spreads = {} self._recursively_referenced_fragments = {} self._variable_usages = {} self._recursive_variable_usages = {} + def report_error(self, error): + self._errors.append(error) + + def get_errors(self): + return self._errors + def get_schema(self): return self._schema diff --git a/graphql/core/validation/rules/arguments_of_correct_type.py b/graphql/core/validation/rules/arguments_of_correct_type.py index dbfd1aaf..bc0f2dfe 100644 --- a/graphql/core/validation/rules/arguments_of_correct_type.py +++ b/graphql/core/validation/rules/arguments_of_correct_type.py @@ -10,11 +10,11 @@ def enter_Argument(self, node, key, parent, path, ancestors): if arg_def: errors = is_valid_literal_value(arg_def.type, node.value) if errors: - return GraphQLError( + self.context.report_error(GraphQLError( self.bad_value_message(node.name.value, arg_def.type, print_ast(node.value), errors), [node.value] - ) + )) return False @staticmethod diff --git a/graphql/core/validation/rules/default_values_of_correct_type.py b/graphql/core/validation/rules/default_values_of_correct_type.py index 15764423..98568362 100644 --- a/graphql/core/validation/rules/default_values_of_correct_type.py +++ b/graphql/core/validation/rules/default_values_of_correct_type.py @@ -12,18 +12,18 @@ def enter_VariableDefinition(self, node, key, parent, path, ancestors): type = self.context.get_input_type() if isinstance(type, GraphQLNonNull) and default_value: - return GraphQLError( + self.context.report_error(GraphQLError( self.default_for_non_null_arg_message(name, type, type.of_type), [default_value] - ) + )) if type and default_value: errors = is_valid_literal_value(type, default_value) if errors: - return GraphQLError( + self.context.report_error(GraphQLError( self.bad_value_for_default_arg_message(name, type, print_ast(default_value), errors), [default_value] - ) + )) return False def enter_SelectionSet(self, node, key, parent, path, ancestors): diff --git a/graphql/core/validation/rules/fields_on_correct_type.py b/graphql/core/validation/rules/fields_on_correct_type.py index 79a1dc73..5d81f45a 100644 --- a/graphql/core/validation/rules/fields_on_correct_type.py +++ b/graphql/core/validation/rules/fields_on_correct_type.py @@ -10,10 +10,10 @@ def enter_Field(self, node, key, parent, path, ancestors): field_def = self.context.get_field_def() if not field_def: - return GraphQLError( + self.context.report_error(GraphQLError( self.undefined_field_message(node.name.value, type.name), [node] - ) + )) @staticmethod def undefined_field_message(field_name, type): diff --git a/graphql/core/validation/rules/fragments_on_composite_types.py b/graphql/core/validation/rules/fragments_on_composite_types.py index 84d74e25..fa819c3b 100644 --- a/graphql/core/validation/rules/fragments_on_composite_types.py +++ b/graphql/core/validation/rules/fragments_on_composite_types.py @@ -9,19 +9,19 @@ def enter_InlineFragment(self, node, key, parent, path, ancestors): type = self.context.get_type() if node.type_condition and type and not is_composite_type(type): - return GraphQLError( + self.context.report_error(GraphQLError( self.inline_fragment_on_non_composite_error_message(print_ast(node.type_condition)), [node.type_condition] - ) + )) def enter_FragmentDefinition(self, node, key, parent, path, ancestors): type = self.context.get_type() if type and not is_composite_type(type): - return GraphQLError( + self.context.report_error(GraphQLError( self.fragment_on_non_composite_error_message(node.name.value, print_ast(node.type_condition)), [node.type_condition] - ) + )) @staticmethod def inline_fragment_on_non_composite_error_message(type): diff --git a/graphql/core/validation/rules/known_argument_names.py b/graphql/core/validation/rules/known_argument_names.py index 86b3de3e..c899fe12 100644 --- a/graphql/core/validation/rules/known_argument_names.py +++ b/graphql/core/validation/rules/known_argument_names.py @@ -17,10 +17,10 @@ def enter_Argument(self, node, key, parent, path, ancestors): if not field_arg_def: parent_type = self.context.get_parent_type() assert parent_type - return GraphQLError( + self.context.report_error(GraphQLError( self.unknown_arg_message(node.name.value, field_def.name, parent_type.name), [node] - ) + )) elif isinstance(argument_of, ast.Directive): directive = self.context.get_directive() @@ -30,10 +30,10 @@ def enter_Argument(self, node, key, parent, path, ancestors): directive_arg_def = next((arg for arg in directive.args if arg.name == node.name.value), None) if not directive_arg_def: - return GraphQLError( + self.context.report_error(GraphQLError( self.unknown_directive_arg_message(node.name.value, directive.name), [node] - ) + )) @staticmethod def unknown_arg_message(arg_name, field_name, type): diff --git a/graphql/core/validation/rules/known_directives.py b/graphql/core/validation/rules/known_directives.py index 651024cc..4af444ed 100644 --- a/graphql/core/validation/rules/known_directives.py +++ b/graphql/core/validation/rules/known_directives.py @@ -11,31 +11,31 @@ def enter_Directive(self, node, key, parent, path, ancestors): ), None) if not directive_def: - return GraphQLError( + return self.context.report_error(GraphQLError( self.unknown_directive_message(node.name.value), [node] - ) + )) applied_to = ancestors[-1] if isinstance(applied_to, ast.OperationDefinition) and not directive_def.on_operation: - return GraphQLError( + self.context.report_error(GraphQLError( self.misplaced_directive_message(node.name.value, 'operation'), [node] - ) + )) - if isinstance(applied_to, ast.Field) and not directive_def.on_field: - return GraphQLError( + elif isinstance(applied_to, ast.Field) and not directive_def.on_field: + self.context.report_error(GraphQLError( self.misplaced_directive_message(node.name.value, 'field'), [node] - ) + )) - if (isinstance(applied_to, (ast.FragmentSpread, ast.InlineFragment, ast.FragmentDefinition)) and + elif (isinstance(applied_to, (ast.FragmentSpread, ast.InlineFragment, ast.FragmentDefinition)) and not directive_def.on_fragment): - return GraphQLError( + self.context.report_error(GraphQLError( self.misplaced_directive_message(node.name.value, 'fragment'), [node] - ) + )) @staticmethod def unknown_directive_message(directive_name): diff --git a/graphql/core/validation/rules/known_fragment_names.py b/graphql/core/validation/rules/known_fragment_names.py index c4a5b3b5..5e7d35d8 100644 --- a/graphql/core/validation/rules/known_fragment_names.py +++ b/graphql/core/validation/rules/known_fragment_names.py @@ -8,10 +8,10 @@ def enter_FragmentSpread(self, node, key, parent, path, ancestors): fragment = self.context.get_fragment(fragment_name) if not fragment: - return GraphQLError( + self.context.report_error(GraphQLError( self.unknown_fragment_message(fragment_name), [node.name] - ) + )) @staticmethod def unknown_fragment_message(fragment_name): diff --git a/graphql/core/validation/rules/known_type_names.py b/graphql/core/validation/rules/known_type_names.py index 1112a4e8..66138dbe 100644 --- a/graphql/core/validation/rules/known_type_names.py +++ b/graphql/core/validation/rules/known_type_names.py @@ -8,7 +8,7 @@ def enter_NamedType(self, node, *args): type = self.context.get_schema().get_type(type_name) if not type: - return GraphQLError(self.unknown_type_message(type_name), [node]) + self.context.report_error(GraphQLError(self.unknown_type_message(type_name), [node])) @staticmethod def unknown_type_message(type): diff --git a/graphql/core/validation/rules/lone_anonymous_operation.py b/graphql/core/validation/rules/lone_anonymous_operation.py index cc402b77..462a67f3 100644 --- a/graphql/core/validation/rules/lone_anonymous_operation.py +++ b/graphql/core/validation/rules/lone_anonymous_operation.py @@ -16,7 +16,7 @@ def enter_Document(self, node, key, parent, path, ancestors): def enter_OperationDefinition(self, node, key, parent, path, ancestors): if not node.name and self.operation_count > 1: - return GraphQLError(self.anonymous_operation_not_alone_message(), [node]) + self.context.report_error(GraphQLError(self.anonymous_operation_not_alone_message(), [node])) @staticmethod def anonymous_operation_not_alone_message(): diff --git a/graphql/core/validation/rules/no_fragment_cycles.py b/graphql/core/validation/rules/no_fragment_cycles.py index 069f0556..94e752c2 100644 --- a/graphql/core/validation/rules/no_fragment_cycles.py +++ b/graphql/core/validation/rules/no_fragment_cycles.py @@ -14,10 +14,6 @@ def __init__(self, context): self.spread_path = [] self.spread_path_index_by_name = {} - def leave_Document(self, node, key, parent, path, ancestors): - if self.errors: - return self.errors - def enter_OperationDefinition(self, node, key, parent, path, ancestors): return False @@ -49,7 +45,7 @@ def detect_cycle_recursive(self, fragment): self.spread_path.pop() else: cycle_path = self.spread_path[cycle_index:] - self.errors.append(GraphQLError( + self.context.report_error(GraphQLError( self.cycle_error_message( spread_name, [s.name.value for s in cycle_path] diff --git a/graphql/core/validation/rules/no_undefined_variables.py b/graphql/core/validation/rules/no_undefined_variables.py index bde1e572..75493a0f 100644 --- a/graphql/core/validation/rules/no_undefined_variables.py +++ b/graphql/core/validation/rules/no_undefined_variables.py @@ -23,19 +23,15 @@ def enter_OperationDefinition(self, operation, key, parent, path, ancestors): def leave_OperationDefinition(self, operation, key, parent, path, ancestors): usages = self.context.get_recursive_variable_usages(operation) - errors = [] for variable_usage in usages: node = variable_usage.node var_name = node.name.value if var_name not in self.defined_variable_names: - errors.append(GraphQLError( + self.context.report_error(GraphQLError( self.undefined_var_message(var_name, operation.name and operation.name.value), [node, operation] )) - if errors: - return errors - def enter_VariableDefinition(self, node, key, parent, path, ancestors): self.defined_variable_names.add(node.variable.name.value) diff --git a/graphql/core/validation/rules/no_unused_fragments.py b/graphql/core/validation/rules/no_unused_fragments.py index 3b3c5041..8d35f483 100644 --- a/graphql/core/validation/rules/no_unused_fragments.py +++ b/graphql/core/validation/rules/no_unused_fragments.py @@ -26,17 +26,12 @@ def leave_Document(self, node, key, parent, path, ancestors): for fragment in fragments: fragment_names_used.add(fragment.name.value) - errors = [ - GraphQLError( - self.unused_fragment_message(fragment_definition.name.value), - [fragment_definition] - ) - for fragment_definition in self.fragment_definitions - if fragment_definition.name.value not in fragment_names_used - ] - - if errors: - return errors + for fragment_definition in self.fragment_definitions: + if fragment_definition.name.value not in fragment_names_used: + self.context.report_error(GraphQLError( + self.unused_fragment_message(fragment_definition.name.value), + [fragment_definition] + )) @staticmethod def unused_fragment_message(fragment_name): diff --git a/graphql/core/validation/rules/no_unused_variables.py b/graphql/core/validation/rules/no_unused_variables.py index c529837d..cd0fa9a7 100644 --- a/graphql/core/validation/rules/no_unused_variables.py +++ b/graphql/core/validation/rules/no_unused_variables.py @@ -19,17 +19,12 @@ def leave_OperationDefinition(self, operation, key, parent, path, ancestors): for variable_usage in usages: variable_name_used.add(variable_usage.node.name.value) - errors = [ - GraphQLError( - self.unused_variable_message(variable_definition.variable.name.value), - [variable_definition] - ) - for variable_definition in self.variable_definitions - if variable_definition.variable.name.value not in variable_name_used - ] - - if errors: - return errors + for variable_definition in self.variable_definitions: + if variable_definition.variable.name.value not in variable_name_used: + self.context.report_error(GraphQLError( + self.unused_variable_message(variable_definition.variable.name.value), + [variable_definition] + )) def enter_VariableDefinition(self, node, key, parent, path, ancestors): self.variable_definitions.append(node) diff --git a/graphql/core/validation/rules/overlapping_fields_can_be_merged.py b/graphql/core/validation/rules/overlapping_fields_can_be_merged.py index 86b46340..f6500ee2 100644 --- a/graphql/core/validation/rules/overlapping_fields_can_be_merged.py +++ b/graphql/core/validation/rules/overlapping_fields_can_be_merged.py @@ -123,10 +123,8 @@ def leave_SelectionSet(self, node, key, parent, path, ancestors): conflicts = self.find_conflicts(field_map) if conflicts: - return [ - GraphQLError(self.fields_conflict_message(reason_name, reason), list(fields1)+list(fields2)) for - (reason_name, reason), fields1, fields2 in conflicts - ] + for (reason_name, reason), fields1, fields2 in conflicts: + self.context.report_error(GraphQLError(self.fields_conflict_message(reason_name, reason), list(fields1)+list(fields2))) @staticmethod def same_type(type1, type2): diff --git a/graphql/core/validation/rules/possible_fragment_spreads.py b/graphql/core/validation/rules/possible_fragment_spreads.py index fd41e132..49fd899f 100644 --- a/graphql/core/validation/rules/possible_fragment_spreads.py +++ b/graphql/core/validation/rules/possible_fragment_spreads.py @@ -13,20 +13,20 @@ def enter_InlineFragment(self, node, key, parent, path, ancestors): frag_type = self.context.get_type() parent_type = self.context.get_parent_type() if frag_type and parent_type and not self.do_types_overlap(frag_type, parent_type): - return GraphQLError( + self.context.report_error(GraphQLError( self.type_incompatible_anon_spread_message(parent_type, frag_type), [node] - ) + )) def enter_FragmentSpread(self, node, key, parent, path, ancestors): frag_name = node.name.value frag_type = self.get_fragment_type(self.context, frag_name) parent_type = self.context.get_parent_type() if frag_type and parent_type and not self.do_types_overlap(frag_type, parent_type): - return GraphQLError( + self.context.report_error(GraphQLError( self.type_incompatible_spread_message(frag_name, parent_type, frag_type), [node] - ) + )) @staticmethod def get_fragment_type(context, name): diff --git a/graphql/core/validation/rules/provided_non_null_arguments.py b/graphql/core/validation/rules/provided_non_null_arguments.py index 7d5f794d..d67d75e5 100644 --- a/graphql/core/validation/rules/provided_non_null_arguments.py +++ b/graphql/core/validation/rules/provided_non_null_arguments.py @@ -9,41 +9,33 @@ def leave_Field(self, node, key, parent, path, ancestors): if not field_def: return False - errors = [] arg_asts = node.arguments or [] arg_ast_map = {arg.name.value: arg for arg in arg_asts} for arg_def in field_def.args: arg_ast = arg_ast_map.get(arg_def.name, None) if not arg_ast and isinstance(arg_def.type, GraphQLNonNull): - errors.append(GraphQLError( + self.context.report_error(GraphQLError( self.missing_field_arg_message(node.name.value, arg_def.name, arg_def.type), [node] )) - if errors: - return errors - def leave_Directive(self, node, key, parent, path, ancestors): directive_def = self.context.get_directive() if not directive_def: return False - errors = [] arg_asts = node.arguments or [] arg_ast_map = {arg.name.value: arg for arg in arg_asts} for arg_def in directive_def.args: arg_ast = arg_ast_map.get(arg_def.name, None) if not arg_ast and isinstance(arg_def.type, GraphQLNonNull): - errors.append(GraphQLError( + self.context.report_error(GraphQLError( self.missing_directive_arg_message(node.name.value, arg_def.name, arg_def.type), [node] )) - if errors: - return errors - @staticmethod def missing_field_arg_message(name, arg_name, type): return 'Field "{}" argument "{}" of type "{}" is required but not provided.'.format(name, arg_name, type) diff --git a/graphql/core/validation/rules/scalar_leafs.py b/graphql/core/validation/rules/scalar_leafs.py index 64e0249a..d4794c2e 100644 --- a/graphql/core/validation/rules/scalar_leafs.py +++ b/graphql/core/validation/rules/scalar_leafs.py @@ -12,16 +12,16 @@ def enter_Field(self, node, key, parent, path, ancestors): if is_leaf_type(type): if node.selection_set: - return GraphQLError( + self.context.report_error(GraphQLError( self.no_subselection_allowed_message(node.name.value, type), [node.selection_set] - ) + )) elif not node.selection_set: - return GraphQLError( + self.context.report_error(GraphQLError( self.required_subselection_message(node.name.value, type), [node] - ) + )) @staticmethod def no_subselection_allowed_message(field, type): diff --git a/graphql/core/validation/rules/unique_argument_names.py b/graphql/core/validation/rules/unique_argument_names.py index 90c79ebf..9e25c75c 100644 --- a/graphql/core/validation/rules/unique_argument_names.py +++ b/graphql/core/validation/rules/unique_argument_names.py @@ -19,12 +19,12 @@ def enter_Argument(self, node, key, parent, path, ancestors): arg_name = node.name.value if arg_name in self.known_arg_names: - return GraphQLError( + self.context.report_error(GraphQLError( self.duplicate_arg_message(arg_name), [self.known_arg_names[arg_name], node.name] - ) - - self.known_arg_names[arg_name] = node.name + )) + else: + self.known_arg_names[arg_name] = node.name return False @staticmethod diff --git a/graphql/core/validation/rules/unique_fragment_names.py b/graphql/core/validation/rules/unique_fragment_names.py index d36e9254..91de3271 100644 --- a/graphql/core/validation/rules/unique_fragment_names.py +++ b/graphql/core/validation/rules/unique_fragment_names.py @@ -15,12 +15,12 @@ def enter_OperationDefinition(self, node, key, parent, path, ancestors): def enter_FragmentDefinition(self, node, key, parent, path, ancestors): fragment_name = node.name.value if fragment_name in self.known_fragment_names: - return GraphQLError( + self.context.report_error(GraphQLError( self.duplicate_fragment_name_message(fragment_name), [self.known_fragment_names[fragment_name], node.name] - ) - - self.known_fragment_names[fragment_name] = node.name + )) + else: + self.known_fragment_names[fragment_name] = node.name return False @staticmethod diff --git a/graphql/core/validation/rules/unique_input_field_names.py b/graphql/core/validation/rules/unique_input_field_names.py index 534fe688..6fe18bab 100644 --- a/graphql/core/validation/rules/unique_input_field_names.py +++ b/graphql/core/validation/rules/unique_input_field_names.py @@ -20,12 +20,12 @@ def leave_ObjectValue(self, node, key, parent, path, ancestors): def enter_ObjectField(self, node, key, parent, path, ancestors): field_name = node.name.value if field_name in self.known_names: - return GraphQLError( + self.context.report_error(GraphQLError( self.duplicate_input_field_message(field_name), [self.known_names[field_name], node.name] - ) - - self.known_names[field_name] = node.name + )) + else: + self.known_names[field_name] = node.name return False @staticmethod diff --git a/graphql/core/validation/rules/unique_operation_names.py b/graphql/core/validation/rules/unique_operation_names.py index 07893395..1fccbb9b 100644 --- a/graphql/core/validation/rules/unique_operation_names.py +++ b/graphql/core/validation/rules/unique_operation_names.py @@ -15,12 +15,12 @@ def enter_OperationDefinition(self, node, key, parent, path, ancestors): return if operation_name.value in self.known_operation_names: - return GraphQLError( + self.context.report_error(GraphQLError( self.duplicate_operation_name_message(operation_name.value), [self.known_operation_names[operation_name.value], operation_name] - ) - - self.known_operation_names[operation_name.value] = operation_name + )) + else: + self.known_operation_names[operation_name.value] = operation_name return False def enter_FragmentDefinition(self, node, key, parent, path, ancestors): diff --git a/graphql/core/validation/rules/variables_are_input_types.py b/graphql/core/validation/rules/variables_are_input_types.py index b6805618..4ea29022 100644 --- a/graphql/core/validation/rules/variables_are_input_types.py +++ b/graphql/core/validation/rules/variables_are_input_types.py @@ -10,10 +10,10 @@ def enter_VariableDefinition(self, node, key, parent, path, ancestors): type = type_from_ast(self.context.get_schema(), node.type) if type and not is_input_type(type): - return GraphQLError( + self.context.report_error(GraphQLError( self.non_input_type_on_variable_message(node.variable.name.value, print_ast(node.type)), [node.type] - ) + )) @staticmethod def non_input_type_on_variable_message(variable_name, type_name): diff --git a/graphql/core/validation/rules/variables_in_allowed_position.py b/graphql/core/validation/rules/variables_in_allowed_position.py index b4ecd884..2303025c 100644 --- a/graphql/core/validation/rules/variables_in_allowed_position.py +++ b/graphql/core/validation/rules/variables_in_allowed_position.py @@ -19,7 +19,6 @@ def enter_OperationDefinition(self, node, key, parent, path, ancestors): def leave_OperationDefinition(self, operation, key, parent, path, ancestors): usages = self.context.get_recursive_variable_usages(operation) - errors = [] for usage in usages: node = usage.node @@ -28,14 +27,11 @@ def leave_OperationDefinition(self, operation, key, parent, path, ancestors): var_def = self.var_def_map.get(var_name) var_type = var_def and type_from_ast(self.context.get_schema(), var_def.type) if var_type and type and not self.var_type_allowed_for_type(self.effective_type(var_type, var_def), type): - errors.append(GraphQLError( + self.context.report_error(GraphQLError( self.bad_var_pos_message(var_name, var_type, type), [node] )) - if errors: - return errors - def enter_VariableDefinition(self, node, key, parent, path, ancestors): self.var_def_map[node.variable.name.value] = node diff --git a/tests/core_validation/test_fields_on_correct_type.py b/tests/core_validation/test_fields_on_correct_type.py index 421a5a3d..1a731a17 100644 --- a/tests/core_validation/test_fields_on_correct_type.py +++ b/tests/core_validation/test_fields_on_correct_type.py @@ -61,6 +61,21 @@ def test_ignores_fields_on_unknown_type(): ''') +def test_reports_errors_when_type_is_known_again(): + expect_fails_rule(FieldsOnCorrectType, ''' + fragment typeKnownAgain on Pet { + unknown_pet_field { + ... on Cat { + unknown_cat_field + } + } + }, + ''', [ + undefined_field('unknown_pet_field', 'Pet', 3, 9), + undefined_field('unknown_cat_field', 'Cat', 5, 13) + ]) + + def test_field_not_defined_on_fragment(): expect_fails_rule(FieldsOnCorrectType, ''' fragment fieldNotDefined on Dog { @@ -71,7 +86,7 @@ def test_field_not_defined_on_fragment(): ]) -def test_field_not_defined_deeply_only_reports_first(): +def test_ignores_deeply_unknown_field(): expect_fails_rule(FieldsOnCorrectType, ''' fragment deepFieldNotDefined on Dog { unknown_field { diff --git a/tests/core_validation/test_validation.py b/tests/core_validation/test_validation.py index 4552da6e..5479a094 100644 --- a/tests/core_validation/test_validation.py +++ b/tests/core_validation/test_validation.py @@ -49,5 +49,7 @@ def test_validates_using_a_custom_type_info(): specified_rules ) - assert len(errors) == 1 - assert errors[0].message == 'Cannot query field "catOrDog" on "QueryRoot".' \ No newline at end of file + assert len(errors) == 3 + assert errors[0].message == 'Cannot query field "catOrDog" on "QueryRoot".' + assert errors[1].message == 'Cannot query field "furColor" on "Cat".' + assert errors[2].message == 'Cannot query field "isHousetrained" on "Dog".' From 39b5fb33340c10ee2a2f7e0bafe7ccdec4831f6f Mon Sep 17 00:00:00 2001 From: Syrus Akbary Date: Tue, 12 Apr 2016 23:08:35 -0700 Subject: [PATCH 10/16] [Validation] Unwrap recursion for collecting fragment spreads Related GraphQL-js commit: https://github.com/graphql/graphql-js/commit/4e610b72626c27a2fa12784cbfadd1ff4b7e1660 --- graphql/core/validation/context.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/graphql/core/validation/context.py b/graphql/core/validation/context.py index 20021cc7..3345ed89 100644 --- a/graphql/core/validation/context.py +++ b/graphql/core/validation/context.py @@ -100,7 +100,15 @@ def get_fragment_spreads(self, node): spreads = self._fragment_spreads.get(node) if not spreads: spreads = [] - self.gather_spreads(spreads, node.selection_set) + sets_to_visit = [node.selection_set] + while sets_to_visit: + _set = sets_to_visit.pop() + for selection in _set.selections: + if isinstance(selection, FragmentSpread): + spreads.append(selection) + elif selection.selection_set: + sets_to_visit.append(selection.selection_set) + self._fragment_spreads[node] = spreads return spreads @@ -133,11 +141,3 @@ def get_directive(self): def get_argument(self): return self._type_info.get_argument() - - @classmethod - def gather_spreads(cls, spreads, node): - for selection in node.selections: - if isinstance(selection, FragmentSpread): - spreads.append(selection) - elif selection.selection_set: - cls.gather_spreads(spreads, selection.selection_set) From 6cbf0b2bc11fc7961d8b59f213d17cf541ef0b89 Mon Sep 17 00:00:00 2001 From: Syrus Akbary Date: Wed, 13 Apr 2016 01:19:43 -0700 Subject: [PATCH 11/16] Simplify validator + [Validation] Remove visitFragmentSpreads Related GraphQL-js commits: https://github.com/graphql/graphql-js/commit/7d6e981889e727087f1d47144c364404154bc468#diff-5cc3eacf5497787132ea1c3fc2d6931b + https://github.com/graphql/graphql-js/blob/c73cc727e06a29140615d3b8beb452b3e2d09e58/src/validation/validate.js --- graphql/core/validation/__init__.py | 4 +- graphql/core/validation/visitor.py | 95 +++-------------------------- 2 files changed, 11 insertions(+), 88 deletions(-) diff --git a/graphql/core/validation/__init__.py b/graphql/core/validation/__init__.py index bac1016d..97cc69a6 100644 --- a/graphql/core/validation/__init__.py +++ b/graphql/core/validation/__init__.py @@ -16,7 +16,7 @@ def validate(schema, ast, rules=specified_rules): def visit_using_rules(schema, type_info, ast, rules): context = ValidationContext(schema, ast, type_info) - errors = [] rules = [rule(context) for rule in rules] - visit(ast, ValidationVisitor(rules, context, type_info, errors)) + for instance in rules: + visit(ast, ValidationVisitor(instance, context, type_info)) return context.get_errors() diff --git a/graphql/core/validation/visitor.py b/graphql/core/validation/visitor.py index 5af79178..4d9ddb25 100644 --- a/graphql/core/validation/visitor.py +++ b/graphql/core/validation/visitor.py @@ -1,101 +1,24 @@ -from ..error import GraphQLError -from ..language.ast import FragmentDefinition, FragmentSpread -from ..language.visitor import Visitor, visit +from ..language.visitor import Visitor class ValidationVisitor(Visitor): - __slots__ = 'context', 'rules', 'total_rules', 'type_info', 'errors', 'ignore_children' + __slots__ = 'context', 'instance', 'type_info' - def __init__(self, rules, context, type_info, errors): + def __init__(self, instance, context, type_info): self.context = context - self.rules = rules - self.total_rules = len(rules) + self.instance = instance self.type_info = type_info - self.errors = errors - self.ignore_children = {} def enter(self, node, key, parent, path, ancestors): self.type_info.enter(node) - to_ignore = None - rules_wanting_to_visit_fragment = None - - skipped = 0 - for rule in self.rules: - if rule in self.ignore_children: - skipped += 1 - continue - - visit_spread_fragments = getattr(rule, 'visit_spread_fragments', False) - - if isinstance(node, FragmentDefinition) and key and visit_spread_fragments: - if to_ignore is None: - to_ignore = [] - - to_ignore.append(rule) - continue - - result = rule.enter(node, key, parent, path, ancestors) - - if result and is_error(result): - append(self.errors, result) - result = False - - if result is None and visit_spread_fragments and isinstance(node, FragmentSpread): - if rules_wanting_to_visit_fragment is None: - rules_wanting_to_visit_fragment = [] - - rules_wanting_to_visit_fragment.append(rule) - - if result is False: - if to_ignore is None: - to_ignore = [] - - to_ignore.append(rule) - - if rules_wanting_to_visit_fragment: - fragment = self.context.get_fragment(node.name.value) - - if fragment: - sub_visitor = ValidationVisitor(rules_wanting_to_visit_fragment, self.context, self.type_info, - self.errors) - visit(fragment, sub_visitor) - - should_skip = (len(to_ignore) if to_ignore else 0 + skipped) == self.total_rules - - if should_skip: + result = self.instance.enter(node, key, parent, path, ancestors) + if result is False: self.type_info.leave(node) - elif to_ignore: - for rule in to_ignore: - self.ignore_children[rule] = node - - if should_skip: - return False + return result def leave(self, node, key, parent, path, ancestors): - for rule in self.rules: - if rule in self.ignore_children: - if self.ignore_children[rule] is node: - del self.ignore_children[rule] - - continue - - result = rule.leave(node, key, parent, path, ancestors) - - if result and is_error(result): - append(self.errors, result) + result = self.instance.leave(node, key, parent, path, ancestors) self.type_info.leave(node) - - -def is_error(value): - if isinstance(value, list): - return all(isinstance(item, GraphQLError) for item in value) - return isinstance(value, GraphQLError) - - -def append(arr, items): - if isinstance(items, list): - arr.extend(items) - else: - arr.append(items) + return result From 68633ffeec83d633f34783d112beba7dce4a7cfa Mon Sep 17 00:00:00 2001 From: Syrus Akbary Date: Wed, 13 Apr 2016 23:29:30 -0700 Subject: [PATCH 12/16] [Validation] Parallelize validation rules. Related GraphQL-js commit: https://github.com/graphql/graphql-js/commit/957704188b0a103c5f2fe0ab99479267d5d1ae43 --- graphql/core/validation/__init__.py | 7 +++-- graphql/core/validation/context.py | 19 ++++++-------- graphql/core/validation/visitor.py | 40 +++++++++++++++++++++-------- 3 files changed, 40 insertions(+), 26 deletions(-) diff --git a/graphql/core/validation/__init__.py b/graphql/core/validation/__init__.py index 97cc69a6..201221e5 100644 --- a/graphql/core/validation/__init__.py +++ b/graphql/core/validation/__init__.py @@ -3,7 +3,7 @@ from ..utils.type_info import TypeInfo from .context import ValidationContext from .rules import specified_rules -from .visitor import ValidationVisitor +from .visitor import ParallelVisitor, TypeInfoVisitor def validate(schema, ast, rules=specified_rules): @@ -16,7 +16,6 @@ def validate(schema, ast, rules=specified_rules): def visit_using_rules(schema, type_info, ast, rules): context = ValidationContext(schema, ast, type_info) - rules = [rule(context) for rule in rules] - for instance in rules: - visit(ast, ValidationVisitor(instance, context, type_info)) + visitors = [rule(context) for rule in rules] + visit(ast, TypeInfoVisitor(type_info, ParallelVisitor(visitors))) return context.get_errors() diff --git a/graphql/core/validation/context.py b/graphql/core/validation/context.py index 3345ed89..eb507b90 100644 --- a/graphql/core/validation/context.py +++ b/graphql/core/validation/context.py @@ -1,6 +1,7 @@ from ..language.ast import FragmentDefinition, FragmentSpread, VariableDefinition, Variable, OperationDefinition from ..utils.type_info import TypeInfo from ..language.visitor import Visitor, visit +from .visitor import TypeInfoVisitor class VariableUsage(object): @@ -12,22 +13,18 @@ def __init__(self, node, type): class UsageVisitor(Visitor): - __slots__ = 'context', 'usages', 'type_info' + __slots__ = 'usages', 'type_info' def __init__(self, usages, type_info): self.usages = usages self.type_info = type_info - def enter(self, node, key, parent, path, ancestors): - self.type_info.enter(node) - if isinstance(node, VariableDefinition): - return False - elif isinstance(node, Variable): - usage = VariableUsage(node, type=self.type_info.get_input_type()) - self.usages.append(usage) + def enter_VariableDefinition(self, node, key, parent, path, ancestors): + return False - def leave(self, node, key, parent, path, ancestors): - self.type_info.leave(node) + def enter_Variable(self, node, key, parent, path, ancestors): + usage = VariableUsage(node, type=self.type_info.get_input_type()) + self.usages.append(usage) class ValidationContext(object): @@ -58,7 +55,7 @@ def get_variable_usages(self, node): if usages is None: usages = [] sub_visitor = UsageVisitor(usages, self._type_info) - visit(node, sub_visitor) + visit(node, TypeInfoVisitor(self._type_info, sub_visitor)) self._variable_usages[node] = usages return usages diff --git a/graphql/core/validation/visitor.py b/graphql/core/validation/visitor.py index 4d9ddb25..70f6f630 100644 --- a/graphql/core/validation/visitor.py +++ b/graphql/core/validation/visitor.py @@ -1,24 +1,42 @@ from ..language.visitor import Visitor -class ValidationVisitor(Visitor): - __slots__ = 'context', 'instance', 'type_info' +class TypeInfoVisitor(Visitor): + __slots__ = 'visitor', 'type_info' - def __init__(self, instance, context, type_info): - self.context = context - self.instance = instance + def __init__(self, type_info, visitor): self.type_info = type_info + self.visitor = visitor def enter(self, node, key, parent, path, ancestors): self.type_info.enter(node) - result = self.instance.enter(node, key, parent, path, ancestors) + result = self.visitor.enter(node, key, parent, path, ancestors) if result is False: self.type_info.leave(node) - - return result + return False def leave(self, node, key, parent, path, ancestors): - result = self.instance.leave(node, key, parent, path, ancestors) - + self.visitor.leave(node, key, parent, path, ancestors) self.type_info.leave(node) - return result + + +class ParallelVisitor(Visitor): + __slots__ = 'skipping', 'visitors' + + def __init__(self, visitors): + self.visitors = visitors + self.skipping = [None]*len(visitors) + + def enter(self, node, key, parent, path, ancestors): + for i, visitor in enumerate(self.visitors): + if not self.skipping[i]: + result = visitor.enter(node, key, parent, path, ancestors) + if result is False: + self.skipping[i] = node + + def leave(self, node, key, parent, path, ancestors): + for i, visitor in enumerate(self.visitors): + if not self.skipping[i]: + visitor.leave(node, key, parent, path, ancestors) + else: + self.skipping[i] = None From b1bfadf35d18603861ab395300e11c8a1ec2800e Mon Sep 17 00:00:00 2001 From: Syrus Akbary Date: Wed, 13 Apr 2016 23:36:36 -0700 Subject: [PATCH 13/16] [Validation] Include variable definition node when reporting bad var type Related GraphQL-js commits: https://github.com/graphql/graphql-js/commit/1d14db78b8e38d6cb5b0698dadc774ed794f7398, https://github.com/graphql/graphql-js/commit/e81cf39750e8c2dbde7a7910e13893aa644e02dd --- .../rules/variables_in_allowed_position.py | 13 +++---- .../test_variables_in_allowed_position.py | 35 ++++++++----------- 2 files changed, 21 insertions(+), 27 deletions(-) diff --git a/graphql/core/validation/rules/variables_in_allowed_position.py b/graphql/core/validation/rules/variables_in_allowed_position.py index 2303025c..9881205c 100644 --- a/graphql/core/validation/rules/variables_in_allowed_position.py +++ b/graphql/core/validation/rules/variables_in_allowed_position.py @@ -25,12 +25,13 @@ def leave_OperationDefinition(self, operation, key, parent, path, ancestors): type = usage.type var_name = node.name.value var_def = self.var_def_map.get(var_name) - var_type = var_def and type_from_ast(self.context.get_schema(), var_def.type) - if var_type and type and not self.var_type_allowed_for_type(self.effective_type(var_type, var_def), type): - self.context.report_error(GraphQLError( - self.bad_var_pos_message(var_name, var_type, type), - [node] - )) + if var_def and type: + var_type = type_from_ast(self.context.get_schema(), var_def.type) + if var_type and not self.var_type_allowed_for_type(self.effective_type(var_type, var_def), type): + self.context.report_error(GraphQLError( + self.bad_var_pos_message(var_name, var_type, type), + [var_def, node] + )) def enter_VariableDefinition(self, node, key, parent, path, ancestors): self.var_def_map[node.variable.name.value] = node diff --git a/tests/core_validation/test_variables_in_allowed_position.py b/tests/core_validation/test_variables_in_allowed_position.py index 19eae110..56bb119c 100644 --- a/tests/core_validation/test_variables_in_allowed_position.py +++ b/tests/core_validation/test_variables_in_allowed_position.py @@ -153,15 +153,14 @@ def test_boolean_non_null_boolean_in_directive_with_default(): def test_int_non_null_int(): expect_fails_rule(VariablesInAllowedPosition, ''' - query Query($intArg: Int) - { + query Query($intArg: Int) { complicatedArgs { nonNullIntArgField(nonNullIntArg: $intArg) } } ''', [ { 'message': VariablesInAllowedPosition.bad_var_pos_message('intArg', 'Int', 'Int!'), - 'locations': [SourceLocation(5, 45)] } + 'locations': [SourceLocation(4, 45), SourceLocation(2, 19)] } ]) def test_int_non_null_int_within_fragment(): @@ -169,15 +168,14 @@ def test_int_non_null_int_within_fragment(): fragment nonNullIntArgFieldFrag on ComplicatedArgs { nonNullIntArgField(nonNullIntArg: $intArg) } - query Query($intArg: Int) - { + query Query($intArg: Int) { complicatedArgs { ...nonNullIntArgFieldFrag } } ''', [ { 'message': VariablesInAllowedPosition.bad_var_pos_message('intArg', 'Int', 'Int!'), - 'locations': [SourceLocation(3, 43)] } + 'locations': [SourceLocation(5, 19), SourceLocation(3, 43)] } ]) def test_int_non_null_int_within_nested_fragment(): @@ -188,61 +186,56 @@ def test_int_non_null_int_within_nested_fragment(): fragment nonNullIntArgFieldFrag on ComplicatedArgs { nonNullIntArgField(nonNullIntArg: $intArg) } - query Query($intArg: Int) - { + query Query($intArg: Int) { complicatedArgs { ...outerFrag } } ''', [ { 'message': VariablesInAllowedPosition.bad_var_pos_message('intArg', 'Int', 'Int!'), - 'locations': [SourceLocation(6, 43)] } + 'locations': [SourceLocation(8, 19), SourceLocation(6, 43)] } ]) def test_string_over_boolean(): expect_fails_rule(VariablesInAllowedPosition, ''' - query Query($stringVar: String) - { + query Query($stringVar: String) { complicatedArgs { booleanArgField(booleanArg: $stringVar) } } ''', [ { 'message': VariablesInAllowedPosition.bad_var_pos_message('stringVar', 'String', 'Boolean'), - 'locations': [SourceLocation(5, 39)] } + 'locations': [SourceLocation(2, 19), SourceLocation(4, 39)] } ]) def test_string_string_fail(): expect_fails_rule(VariablesInAllowedPosition, ''' - query Query($stringVar: String) - { + query Query($stringVar: String) { complicatedArgs { stringListArgField(stringListArg: $stringVar) } } ''', [ { 'message': VariablesInAllowedPosition.bad_var_pos_message('stringVar', 'String', '[String]'), - 'locations': [SourceLocation(5, 45)]} + 'locations': [SourceLocation(2, 19), SourceLocation(4, 45)]} ]) def test_boolean_non_null_boolean_in_directive(): expect_fails_rule(VariablesInAllowedPosition, ''' - query Query($boolVar: Boolean) - { + query Query($boolVar: Boolean) { dog @include(if: $boolVar) } ''', [ { 'message': VariablesInAllowedPosition.bad_var_pos_message('boolVar', 'Boolean', 'Boolean!'), - 'locations': [SourceLocation(4, 26)] + 'locations': [SourceLocation(2, 19), SourceLocation(3, 26)] }]) def test_string_non_null_boolean_in_directive(): expect_fails_rule(VariablesInAllowedPosition, ''' - query Query($stringVar: String) - { + query Query($stringVar: String) { dog @include(if: $stringVar) } ''', [ { 'message': VariablesInAllowedPosition.bad_var_pos_message('stringVar', 'String', 'Boolean!'), - 'locations': [SourceLocation(4, 26)] } + 'locations': [SourceLocation(2, 19), SourceLocation(3, 26)] } ]) From bc9a0288c4452128ea82f7472d5801c6a8864a09 Mon Sep 17 00:00:00 2001 From: Syrus Akbary Date: Thu, 14 Apr 2016 00:34:31 -0700 Subject: [PATCH 14/16] [Schema] Implementing interfaces with covariant return types. Related GraphQL-js commit: https://github.com/graphql/graphql-js/commit/edbe063718590d84594f2b8e06dc6fd67e1f3ec2 --- graphql/core/type/schema.py | 13 +- graphql/core/utils/type_comparators.py | 43 +++++ .../rules/overlapping_fields_can_be_merged.py | 4 +- .../rules/variables_in_allowed_position.py | 20 +-- tests/core_type/test_validation.py | 161 +++++++++++++++++- 5 files changed, 208 insertions(+), 33 deletions(-) create mode 100644 graphql/core/utils/type_comparators.py diff --git a/graphql/core/type/schema.py b/graphql/core/type/schema.py index 3bd6dc3b..fe7e3ec8 100644 --- a/graphql/core/type/schema.py +++ b/graphql/core/type/schema.py @@ -9,6 +9,7 @@ ) from .directives import GraphQLDirective, GraphQLIncludeDirective, GraphQLSkipDirective from .introspection import IntrospectionSchema +from ..utils.type_comparators import is_equal_type, is_type_sub_type_of class GraphQLSchema(object): @@ -144,7 +145,7 @@ def assert_object_implements_interface(object, interface): interface, field_name, object ) - assert is_equal_type(interface_field.type, object_field.type), ( + assert is_type_sub_type_of(object_field.type, interface_field.type), ( '{}.{} expects type "{}" but {}.{} provides type "{}".' ).format(interface, field_name, interface_field.type, object, field_name, object_field.type) @@ -171,13 +172,3 @@ def assert_object_implements_interface(object, interface): '"{}" but is not also provided by the ' 'interface {}.{}.' ).format(object, field_name, arg_name, object_arg.type, interface, field_name) - - -def is_equal_type(type_a, type_b): - if isinstance(type_a, GraphQLNonNull) and isinstance(type_b, GraphQLNonNull): - return is_equal_type(type_a.of_type, type_b.of_type) - - if isinstance(type_a, GraphQLList) and isinstance(type_b, GraphQLList): - return is_equal_type(type_a.of_type, type_b.of_type) - - return type_a is type_b diff --git a/graphql/core/utils/type_comparators.py b/graphql/core/utils/type_comparators.py new file mode 100644 index 00000000..570fa77b --- /dev/null +++ b/graphql/core/utils/type_comparators.py @@ -0,0 +1,43 @@ +from ..type.definition import ( + is_abstract_type, + GraphQLObjectType, + GraphQLList, + GraphQLNonNull, +) + + +def is_equal_type(type_a, type_b): + if type_a is type_b: + return True + + if isinstance(type_a, GraphQLNonNull) and isinstance(type_b, GraphQLNonNull): + return is_equal_type(type_a.of_type, type_b.of_type) + + if isinstance(type_a, GraphQLList) and isinstance(type_b, GraphQLList): + return is_equal_type(type_a.of_type, type_b.of_type) + + return False + + +def is_type_sub_type_of(maybe_subtype, super_type): + if maybe_subtype is super_type: + return True + + if isinstance(super_type, GraphQLNonNull): + if isinstance(maybe_subtype, GraphQLNonNull): + return is_type_sub_type_of(maybe_subtype.of_type, super_type.of_type) + return False + elif isinstance(maybe_subtype, GraphQLNonNull): + return is_type_sub_type_of(maybe_subtype.of_type, super_type) + + if isinstance(super_type, GraphQLList): + if isinstance(maybe_subtype, GraphQLList): + return is_type_sub_type_of(maybe_subtype.of_type, super_type.of_type) + return False + elif isinstance(maybe_subtype, GraphQLList): + return False + + if is_abstract_type(super_type) and isinstance(maybe_subtype, GraphQLObjectType) and super_type.is_possible_type(maybe_subtype): + return True + + return False diff --git a/graphql/core/validation/rules/overlapping_fields_can_be_merged.py b/graphql/core/validation/rules/overlapping_fields_can_be_merged.py index f6500ee2..735b3035 100644 --- a/graphql/core/validation/rules/overlapping_fields_can_be_merged.py +++ b/graphql/core/validation/rules/overlapping_fields_can_be_merged.py @@ -11,6 +11,7 @@ get_named_type, ) from ...utils.type_from_ast import type_from_ast +from ...utils.type_comparators import is_equal_type from .base import ValidationRule @@ -128,7 +129,8 @@ def leave_SelectionSet(self, node, key, parent, path, ancestors): @staticmethod def same_type(type1, type2): - return type1.is_same_type(type2) + return is_equal_type(type1, type2) + # return type1.is_same_type(type2) @staticmethod def same_value(value1, value2): diff --git a/graphql/core/validation/rules/variables_in_allowed_position.py b/graphql/core/validation/rules/variables_in_allowed_position.py index 9881205c..11e45580 100644 --- a/graphql/core/validation/rules/variables_in_allowed_position.py +++ b/graphql/core/validation/rules/variables_in_allowed_position.py @@ -1,9 +1,9 @@ from ...error import GraphQLError from ...type.definition import ( - GraphQLList, GraphQLNonNull ) from ...utils.type_from_ast import type_from_ast +from ...utils.type_comparators import is_type_sub_type_of from .base import ValidationRule @@ -27,7 +27,7 @@ def leave_OperationDefinition(self, operation, key, parent, path, ancestors): var_def = self.var_def_map.get(var_name) if var_def and type: var_type = type_from_ast(self.context.get_schema(), var_def.type) - if var_type and not self.var_type_allowed_for_type(self.effective_type(var_type, var_def), type): + if var_type and not is_type_sub_type_of(self.effective_type(var_type, var_def), type): self.context.report_error(GraphQLError( self.bad_var_pos_message(var_name, var_type, type), [var_def, node] @@ -43,22 +43,6 @@ def effective_type(var_type, var_def): return GraphQLNonNull(var_type) - @classmethod - def var_type_allowed_for_type(cls, var_type, expected_type): - if isinstance(expected_type, GraphQLNonNull): - if isinstance(var_type, GraphQLNonNull): - return cls.var_type_allowed_for_type(var_type.of_type, expected_type.of_type) - - return False - - if isinstance(var_type, GraphQLNonNull): - return cls.var_type_allowed_for_type(var_type.of_type, expected_type) - - if isinstance(var_type, GraphQLList) and isinstance(expected_type, GraphQLList): - return cls.var_type_allowed_for_type(var_type.of_type, expected_type.of_type) - - return var_type == expected_type - @staticmethod def bad_var_pos_message(var_name, var_type, expected_type): return 'Variable "{}" of type "{}" used in position expecting type "{}".'.format(var_name, var_type, diff --git a/tests/core_type/test_validation.py b/tests/core_type/test_validation.py index 65088a12..bad1390d 100644 --- a/tests/core_type/test_validation.py +++ b/tests/core_type/test_validation.py @@ -1459,6 +1459,98 @@ def test_rejects_an_object_with_an_incorrectly_typed_interface_argument(self): assert str(excinfo.value) == 'AnotherInterface.field(input:) expects type "String" ' \ 'but AnotherObject.field(input:) provides type "SomeScalar".' + def test_rejects_an_object_with_an_incorrectly_typed_interface_field(self): + AnotherInterface = GraphQLInterfaceType( + name='AnotherInterface', + resolve_type=_none, + fields={ + 'field': GraphQLField(GraphQLString) + } + ) + AnotherObject = GraphQLObjectType( + name='AnotherObject', + interfaces=[AnotherInterface], + fields={ + 'field': GraphQLField(SomeScalarType) + } + ) + + with raises(AssertionError) as excinfo: + schema_with_field_type(AnotherObject) + + assert str(excinfo.value) == 'AnotherInterface.field expects type "String" ' \ + 'but AnotherObject.field provides type "SomeScalar".' + + def test_rejects_an_object_with_a_differently_typed_Interface_field(self): + TypeA = GraphQLObjectType( + name='A', + fields={ + 'foo': GraphQLField(GraphQLString) + } + ) + TypeB = GraphQLObjectType( + name='B', + fields={ + 'foo': GraphQLField(GraphQLString) + } + ) + AnotherInterface = GraphQLInterfaceType( + name='AnotherInterface', + resolve_type=_none, + fields={ + 'field': GraphQLField(TypeA) + } + ) + AnotherObject = GraphQLObjectType( + name='AnotherObject', + interfaces=[AnotherInterface], + fields={ + 'field': GraphQLField(TypeB) + } + ) + + with raises(AssertionError) as excinfo: + schema_with_field_type(AnotherObject) + + assert str(excinfo.value) == 'AnotherInterface.field expects type "A" but ' \ + 'AnotherObject.field provides type "B".' + + def test_accepts_an_object_with_a_subtyped_interface_field_interface(self): + AnotherInterface = GraphQLInterfaceType( + name='AnotherInterface', + resolve_type=_none, + fields=lambda: { + 'field': GraphQLField(AnotherInterface) + } + ) + AnotherObject = GraphQLObjectType( + name='AnotherObject', + interfaces=[AnotherInterface], + fields=lambda: { + 'field': GraphQLField(AnotherObject) + } + ) + + assert schema_with_field_type(AnotherObject) + + def test_accepts_an_object_with_a_subtyped_interface_field_union(self): + AnotherInterface = GraphQLInterfaceType( + name='AnotherInterface', + resolve_type=_none, + fields=lambda: { + 'field': GraphQLField(SomeUnionType) + } + ) + AnotherObject = GraphQLObjectType( + name='AnotherObject', + interfaces=[AnotherInterface], + fields=lambda: { + 'field': GraphQLField(SomeObjectType) + } + ) + + assert schema_with_field_type(AnotherObject) + def test_accepts_an_object_with_an_equivalently_modified_interface_field_type(self): AnotherInterface = GraphQLInterfaceType( name='AnotherInterface', @@ -1478,7 +1570,29 @@ def test_accepts_an_object_with_an_equivalently_modified_interface_field_type(se assert schema_with_field_type(AnotherObject) - def test_rejects_an_object_with_an_differently_modified_interface_field_type(self): + def test_rejects_an_object_with_a_non_list_interface_field_list_type(self): + AnotherInterface = GraphQLInterfaceType( + name='AnotherInterface', + resolve_type=_none, + fields={ + 'field': GraphQLField(GraphQLList(GraphQLString)) + } + ) + AnotherObject = GraphQLObjectType( + name='AnotherObject', + interfaces=[AnotherInterface], + fields={ + 'field': GraphQLField(GraphQLString) + } + ) + + with raises(AssertionError) as excinfo: + schema_with_field_type(AnotherObject) + + assert str(excinfo.value) == 'AnotherInterface.field expects type "[String]" ' \ + 'but AnotherObject.field provides type "String".' + + def test_rejects_a_object_with_a_list_interface_field_non_list_type(self): AnotherInterface = GraphQLInterfaceType( name='AnotherInterface', resolve_type=_none, @@ -1490,7 +1604,7 @@ def test_rejects_an_object_with_an_differently_modified_interface_field_type(sel name='AnotherObject', interfaces=[AnotherInterface], fields={ - 'field': GraphQLField(GraphQLNonNull(GraphQLString)) + 'field': GraphQLField(GraphQLList(GraphQLString)) } ) @@ -1499,4 +1613,45 @@ def test_rejects_an_object_with_an_differently_modified_interface_field_type(sel schema_with_field_type(AnotherObject) assert str(excinfo.value) == 'AnotherInterface.field expects type "String" ' \ - 'but AnotherObject.field provides type "String!".' + 'but AnotherObject.field provides type "[String]".' + + def test_accepts_an_object_with_a_subset_non_null_interface_field_type(self): + AnotherInterface = GraphQLInterfaceType( + name='AnotherInterface', + resolve_type=_none, + fields={ + 'field': GraphQLField(GraphQLString) + } + ) + AnotherObject = GraphQLObjectType( + name='AnotherObject', + interfaces=[AnotherInterface], + fields={ + 'field': GraphQLField(GraphQLNonNull(GraphQLString)) + } + ) + + assert schema_with_field_type(AnotherObject) + + def test_rejects_a_object_with_a_superset_nullable_interface_field_type(self): + AnotherInterface = GraphQLInterfaceType( + name='AnotherInterface', + resolve_type=_none, + fields={ + 'field': GraphQLField(GraphQLNonNull(GraphQLString)) + } + ) + AnotherObject = GraphQLObjectType( + name='AnotherObject', + interfaces=[AnotherInterface], + fields={ + 'field': GraphQLField(GraphQLString) + + } + ) + + with raises(AssertionError) as excinfo: + schema_with_field_type(AnotherObject) + + assert str(excinfo.value) == 'AnotherInterface.field expects type "String!" but ' \ + 'AnotherObject.field provides type "String".' From c81f6a74a5f933dad47108c32b045ba77045dcfe Mon Sep 17 00:00:00 2001 From: Syrus Akbary Date: Thu, 14 Apr 2016 00:43:13 -0700 Subject: [PATCH 15/16] Fixed Python3 type mapping --- graphql/core/utils/extend_schema.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/graphql/core/utils/extend_schema.py b/graphql/core/utils/extend_schema.py index 1a444e29..2d75fde2 100644 --- a/graphql/core/utils/extend_schema.py +++ b/graphql/core/utils/extend_schema.py @@ -145,12 +145,12 @@ def extend_union_type(type): return GraphQLUnionType( name=type.name, description=type.description, - types=map(get_type_from_def, type.get_possible_types()), + types=list(map(get_type_from_def, type.get_possible_types())), resolve_type=raise_client_schema_execution_error, ) def extend_implemented_interfaces(type): - interfaces = map(get_type_from_def, type.get_interfaces()) + interfaces = list(map(get_type_from_def, type.get_interfaces())) # If there are any extensions to the interfaces, apply those here. extensions = type_extensions_map[type.name] @@ -171,7 +171,7 @@ def extend_implemented_interfaces(type): def extend_field_map(type): new_field_map = OrderedDict() old_field_map = type.get_fields() - for field_name, field in old_field_map.iteritems(): + for field_name, field in old_field_map.items(): new_field_map[field_name] = GraphQLField( extend_field_type(field.type), description=field.description, @@ -237,7 +237,7 @@ def build_interface_type(type_ast): def build_union_type(type_ast): return GraphQLUnionType( type_ast.name.value, - types=map(get_type_from_AST, type_ast.types), + types=list(map(get_type_from_AST, type_ast.types)), resolve_type=raise_client_schema_execution_error, ) @@ -267,7 +267,7 @@ def build_input_object_type(type_ast): ) def build_implemented_interfaces(type_ast): - return map(get_type_from_AST, type_ast.interfaces) + return list(map(get_type_from_AST, type_ast.interfaces)) def build_field_map(type_ast): return { @@ -324,11 +324,11 @@ def build_field_type(type_ast): # Iterate through all types, getting the type definition for each, ensuring # that any type not directly referenced by a field will get created. - for typeName, _def in schema.get_type_map().iteritems(): + for typeName, _def in schema.get_type_map().items(): get_type_from_def(_def) # Do the same with new types. - for typeName, _def in type_definition_map.iteritems(): + for typeName, _def in type_definition_map.items(): get_type_from_AST(_def) # Then produce and return a Schema with these types. From 2ea2cd2f91ca9683a5877e8fc07d4bb054b564bb Mon Sep 17 00:00:00 2001 From: Syrus Akbary Date: Thu, 14 Apr 2016 01:15:10 -0700 Subject: [PATCH 16/16] =?UTF-8?q?PEP8+Flake+isort=20=F0=9F=98=8E?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- bin/autolinter | 7 ++ graphql/core/execution/base.py | 24 ++-- graphql/core/execution/executor.py | 14 ++- graphql/core/execution/middlewares/asyncio.py | 1 + graphql/core/execution/middlewares/gevent.py | 1 + graphql/core/execution/middlewares/sync.py | 1 + graphql/core/execution/values.py | 12 +- graphql/core/language/error.py | 1 + graphql/core/language/lexer.py | 2 + graphql/core/language/parser.py | 1 + graphql/core/language/printer.py | 1 + graphql/core/language/visitor.py | 2 +- graphql/core/language/visitor_meta.py | 1 + graphql/core/pyutils/defer.py | 3 + graphql/core/type/definition.py | 4 +- graphql/core/type/introspection.py | 19 +--- graphql/core/type/scalars.py | 8 +- graphql/core/type/schema.py | 19 ++-- graphql/core/utils/ast_from_value.py | 9 +- graphql/core/utils/ast_to_code.py | 4 +- graphql/core/utils/build_ast_schema.py | 32 ++---- graphql/core/utils/build_client_schema.py | 40 +++---- graphql/core/utils/extend_schema.py | 3 +- graphql/core/utils/get_field_def.py | 10 +- graphql/core/utils/is_valid_literal_value.py | 15 +-- graphql/core/utils/is_valid_value.py | 11 +- graphql/core/utils/schema_printer.py | 11 +- graphql/core/utils/type_comparators.py | 11 +- graphql/core/utils/type_from_ast.py | 5 +- graphql/core/utils/type_info.py | 11 +- graphql/core/utils/value_from_ast.py | 3 +- graphql/core/validation/context.py | 7 +- .../rules/arguments_of_correct_type.py | 1 + .../rules/default_values_of_correct_type.py | 1 + .../rules/fields_on_correct_type.py | 1 + .../rules/fragments_on_composite_types.py | 1 + .../validation/rules/known_argument_names.py | 1 + .../core/validation/rules/known_directives.py | 1 + .../validation/rules/known_fragment_names.py | 1 + .../core/validation/rules/known_type_names.py | 1 + .../validation/rules/no_fragment_cycles.py | 4 +- .../rules/no_undefined_variables.py | 1 - .../rules/overlapping_fields_can_be_merged.py | 21 ++-- .../rules/possible_fragment_spreads.py | 8 +- .../rules/provided_non_null_arguments.py | 1 + graphql/core/validation/rules/scalar_leafs.py | 1 + .../rules/variables_are_input_types.py | 1 + .../rules/variables_in_allowed_position.py | 6 +- graphql/core/validation/visitor.py | 2 +- tests/core_execution/test_abstract.py | 10 +- .../test_concurrent_executor.py | 17 ++- tests/core_execution/test_default_executor.py | 3 +- tests/core_execution/test_deferred.py | 7 +- tests/core_execution/test_directives.py | 3 +- tests/core_execution/test_executor.py | 44 ++++---- tests/core_execution/test_executor_schema.py | 35 +++--- tests/core_execution/test_gevent.py | 15 +-- tests/core_execution/test_lists.py | 16 +-- tests/core_execution/test_middleware.py | 4 +- tests/core_execution/test_mutations.py | 9 +- tests/core_execution/test_nonnull.py | 8 +- tests/core_execution/test_union_interface.py | 17 ++- tests/core_execution/test_variables.py | 23 ++-- tests/core_execution/utils.py | 4 +- tests/core_language/fixtures.py | 2 +- tests/core_language/test_ast.py | 3 +- tests/core_language/test_lexer.py | 3 +- tests/core_language/test_parser.py | 105 +++++++++--------- tests/core_language/test_printer.py | 5 +- tests/core_language/test_schema_parser.py | 1 + tests/core_language/test_schema_printer.py | 7 +- tests/core_language/test_visitor.py | 10 +- tests/core_language/test_visitor_meta.py | 3 + .../core_pyutils/test_default_ordered_dict.py | 7 +- tests/core_starwars/starwars_schema.py | 20 ++-- tests/core_starwars/test_query.py | 17 +-- tests/core_starwars/test_validation.py | 3 +- tests/core_type/test_definition.py | 25 ++--- tests/core_type/test_enum_type.py | 19 ++-- tests/core_type/test_introspection.py | 24 ++-- tests/core_type/test_serialization.py | 9 +- tests/core_type/test_validation.py | 48 +++++--- tests/core_utils/test_ast_from_value.py | 8 +- tests/core_utils/test_ast_to_code.py | 4 +- tests/core_utils/test_build_ast_schema.py | 3 +- tests/core_utils/test_build_client_schema.py | 32 ++---- tests/core_utils/test_concat_ast.py | 1 - tests/core_utils/test_schema_printer.py | 28 ++--- .../test_arguments_of_correct_type.py | 16 ++- .../test_default_values_of_correct_type.py | 2 +- .../test_fields_on_correct_type.py | 2 +- .../test_fragments_on_composite_types.py | 2 +- .../test_known_argument_names.py | 2 +- .../core_validation/test_known_directives.py | 2 +- .../test_known_fragment_names.py | 2 +- .../core_validation/test_known_type_names.py | 2 +- .../test_lone_anonymous_operation.py | 2 +- .../test_no_fragment_cycles.py | 23 +++- .../test_no_undefined_variables.py | 2 +- .../test_no_unused_fragments.py | 2 +- .../test_no_unused_variables.py | 2 +- .../test_overlapping_fields_can_be_merged.py | 14 ++- .../test_possible_fragment_spreads.py | 41 +++++-- .../test_provided_non_null_arguments.py | 2 +- tests/core_validation/test_scalar_leafs.py | 2 +- .../test_unique_argument_names.py | 4 +- .../test_unique_fragment_names.py | 2 +- .../test_unique_input_field_names.py | 2 +- .../test_unique_operation_names.py | 2 +- tests/core_validation/test_validation.py | 3 +- .../test_variables_are_input_types.py | 2 +- .../test_variables_in_allowed_position.py | 48 +++++--- tests/core_validation/utils.py | 35 +++--- 113 files changed, 601 insertions(+), 565 deletions(-) create mode 100755 bin/autolinter diff --git a/bin/autolinter b/bin/autolinter new file mode 100755 index 00000000..cd4ef379 --- /dev/null +++ b/bin/autolinter @@ -0,0 +1,7 @@ +#!/bin/bash + +# Install the required scripts with +# pip install autoflake autopep8 isort +autoflake ./graphql/ ./tests/ -r --remove-unused-variables --in-place +autopep8 ./tests/ ./graphql/ -r --in-place --experimental --aggressive --max-line-length 120 +isort -rc ./tests/ ./graphql/ diff --git a/graphql/core/execution/base.py b/graphql/core/execution/base.py index c98a832a..66c6a269 100644 --- a/graphql/core/execution/base.py +++ b/graphql/core/execution/base.py @@ -2,19 +2,10 @@ from ..error import GraphQLError from ..language import ast from ..pyutils.defer import DeferredException -from ..type.definition import ( - GraphQLInterfaceType, - GraphQLUnionType, -) -from ..type.directives import ( - GraphQLIncludeDirective, - GraphQLSkipDirective, -) -from ..type.introspection import ( - SchemaMetaFieldDef, - TypeMetaFieldDef, - TypeNameMetaFieldDef, -) +from ..type.definition import GraphQLInterfaceType, GraphQLUnionType +from ..type.directives import GraphQLIncludeDirective, GraphQLSkipDirective +from ..type.introspection import (SchemaMetaFieldDef, TypeMetaFieldDef, + TypeNameMetaFieldDef) from ..utils.type_from_ast import type_from_ast from .values import get_argument_values, get_variable_values @@ -97,7 +88,7 @@ def __init__(self, data=None, errors=None, invalid=False): errors = [ error.value if isinstance(error, DeferredException) else error for error in errors - ] + ] self.errors = errors @@ -170,7 +161,9 @@ def collect_fields(ctx, runtime_type, selection_set, fields, prev_fragment_names fields[name].append(selection) elif isinstance(selection, ast.InlineFragment): - if not should_include_node(ctx, directives) or not does_fragment_condition_match(ctx, selection, runtime_type): + if not should_include_node( + ctx, directives) or not does_fragment_condition_match( + ctx, selection, runtime_type): continue collect_fields(ctx, runtime_type, selection.selection_set, fields, prev_fragment_names) @@ -255,6 +248,7 @@ def get_field_entry_key(node): class ResolveInfo(object): + def __init__(self, field_name, field_asts, return_type, parent_type, context): self.field_name = field_name self.field_asts = field_asts diff --git a/graphql/core/execution/executor.py b/graphql/core/execution/executor.py index 6e73387c..e2b38127 100644 --- a/graphql/core/execution/executor.py +++ b/graphql/core/execution/executor.py @@ -6,15 +6,19 @@ from ..language.parser import parse from ..language.source import Source from ..pyutils.default_ordered_dict import DefaultOrderedDict -from ..pyutils.defer import Deferred, DeferredDict, DeferredList, defer, succeed -from ..type import GraphQLEnumType, GraphQLInterfaceType, GraphQLList, GraphQLNonNull, GraphQLObjectType, \ - GraphQLScalarType, GraphQLUnionType +from ..pyutils.defer import (Deferred, DeferredDict, DeferredList, defer, + succeed) +from ..type import (GraphQLEnumType, GraphQLInterfaceType, GraphQLList, + GraphQLNonNull, GraphQLObjectType, GraphQLScalarType, + GraphQLUnionType) from ..validation import validate -from .base import ExecutionContext, ExecutionResult, ResolveInfo, Undefined, collect_fields, default_resolve_fn, \ - get_field_def, get_operation_root_type +from .base import (ExecutionContext, ExecutionResult, ResolveInfo, Undefined, + collect_fields, default_resolve_fn, get_field_def, + get_operation_root_type) class Executor(object): + def __init__(self, execution_middlewares=None, default_resolver=default_resolve_fn, map_type=dict): assert issubclass(map_type, collections.MutableMapping) diff --git a/graphql/core/execution/middlewares/asyncio.py b/graphql/core/execution/middlewares/asyncio.py index f841f708..ef95602e 100644 --- a/graphql/core/execution/middlewares/asyncio.py +++ b/graphql/core/execution/middlewares/asyncio.py @@ -17,6 +17,7 @@ def handle_future_result(future): class AsyncioExecutionMiddleware(object): + @staticmethod def run_resolve_fn(resolver, original_resolver): result = resolver() diff --git a/graphql/core/execution/middlewares/gevent.py b/graphql/core/execution/middlewares/gevent.py index 7af14963..1b000c02 100644 --- a/graphql/core/execution/middlewares/gevent.py +++ b/graphql/core/execution/middlewares/gevent.py @@ -30,6 +30,7 @@ def resolve_something(context, _*): class GeventExecutionMiddleware(object): + @staticmethod def run_resolve_fn(resolver, original_resolver): if resolver_has_tag(original_resolver, 'run_in_greenlet'): diff --git a/graphql/core/execution/middlewares/sync.py b/graphql/core/execution/middlewares/sync.py index 2440fccf..a0fa8bbc 100644 --- a/graphql/core/execution/middlewares/sync.py +++ b/graphql/core/execution/middlewares/sync.py @@ -3,6 +3,7 @@ class SynchronousExecutionMiddleware(object): + @staticmethod def run_resolve_fn(resolver, original_resolver): result = resolver() diff --git a/graphql/core/execution/values.py b/graphql/core/execution/values.py index 606e61b8..c31f2145 100644 --- a/graphql/core/execution/values.py +++ b/graphql/core/execution/values.py @@ -1,16 +1,12 @@ import collections import json + from six import string_types + from ..error import GraphQLError from ..language.printer import print_ast -from ..type import ( - GraphQLEnumType, - GraphQLInputObjectType, - GraphQLList, - GraphQLNonNull, - GraphQLScalarType, - is_input_type -) +from ..type import (GraphQLEnumType, GraphQLInputObjectType, GraphQLList, + GraphQLNonNull, GraphQLScalarType, is_input_type) from ..utils.is_valid_value import is_valid_value from ..utils.type_from_ast import type_from_ast from ..utils.value_from_ast import value_from_ast diff --git a/graphql/core/language/error.py b/graphql/core/language/error.py index 2b5cd0d7..755e9af5 100644 --- a/graphql/core/language/error.py +++ b/graphql/core/language/error.py @@ -5,6 +5,7 @@ class LanguageError(GraphQLError): + def __init__(self, source, position, description): location = get_location(source, position) super(LanguageError, self).__init__( diff --git a/graphql/core/language/lexer.py b/graphql/core/language/lexer.py index fd27a42a..9c5066c6 100644 --- a/graphql/core/language/lexer.py +++ b/graphql/core/language/lexer.py @@ -1,5 +1,7 @@ import json + from six import unichr + from .error import LanguageError __all__ = ['Token', 'Lexer', 'TokenKind', diff --git a/graphql/core/language/parser.py b/graphql/core/language/parser.py index bc344812..2d84aedc 100644 --- a/graphql/core/language/parser.py +++ b/graphql/core/language/parser.py @@ -1,4 +1,5 @@ from six import string_types + from . import ast from .error import LanguageError from .lexer import Lexer, TokenKind, get_token_desc, get_token_kind_desc diff --git a/graphql/core/language/printer.py b/graphql/core/language/printer.py index e39b33b6..446fe0f8 100644 --- a/graphql/core/language/printer.py +++ b/graphql/core/language/printer.py @@ -1,4 +1,5 @@ import json + from .visitor import Visitor, visit __all__ = ['print_ast'] diff --git a/graphql/core/language/visitor.py b/graphql/core/language/visitor.py index 201dbcdd..358dbc7d 100644 --- a/graphql/core/language/visitor.py +++ b/graphql/core/language/visitor.py @@ -1,10 +1,10 @@ from copy import copy import six + from . import ast from .visitor_meta import QUERY_DOCUMENT_KEYS, VisitorMeta - BREAK = object() REMOVE = object() diff --git a/graphql/core/language/visitor_meta.py b/graphql/core/language/visitor_meta.py index 66040c48..d540f513 100644 --- a/graphql/core/language/visitor_meta.py +++ b/graphql/core/language/visitor_meta.py @@ -46,6 +46,7 @@ class VisitorMeta(type): + def __new__(cls, name, bases, attrs): enter_handlers = {} leave_handlers = {} diff --git a/graphql/core/pyutils/defer.py b/graphql/core/pyutils/defer.py index 5ab41fca..0bafe88f 100644 --- a/graphql/core/pyutils/defer.py +++ b/graphql/core/pyutils/defer.py @@ -61,6 +61,7 @@ # THE SOFTWARE. import collections import sys + from six import reraise __all__ = ("Deferred", "AlreadyCalledDeferred", "DeferredException", @@ -511,6 +512,7 @@ def _cb_deferred(self, result, key, succeeded): class DeferredDict(_ResultCollector): + def __init__(self, mapping): super(DeferredDict, self).__init__() assert isinstance(mapping, collections.Mapping) @@ -519,6 +521,7 @@ def __init__(self, mapping): class DeferredList(_ResultCollector): + def __init__(self, sequence): super(DeferredList, self).__init__() assert isinstance(sequence, collections.Sequence) diff --git a/graphql/core/type/definition.py b/graphql/core/type/definition.py index d2281034..430fcfbe 100644 --- a/graphql/core/type/definition.py +++ b/graphql/core/type/definition.py @@ -1,6 +1,7 @@ import collections import copy import re + from ..language import ast @@ -466,7 +467,8 @@ def define_types(union_type, types): if callable(types): types = types() - assert isinstance(types, (list, tuple)) and len(types) > 0, 'Must provide types for Union {}.'.format(union_type.name) + assert isinstance(types, (list, tuple)) and len( + types) > 0, 'Must provide types for Union {}.'.format(union_type.name) has_resolve_type_fn = callable(union_type._resolve_type) for type in types: diff --git a/graphql/core/type/introspection.py b/graphql/core/type/introspection.py index 89a140fc..ed916c2b 100644 --- a/graphql/core/type/introspection.py +++ b/graphql/core/type/introspection.py @@ -1,19 +1,12 @@ from collections import OrderedDict + from ..language.printer import print_ast from ..utils.ast_from_value import ast_from_value -from .definition import ( - GraphQLArgument, - GraphQLEnumType, - GraphQLEnumValue, - GraphQLField, - GraphQLInputObjectType, - GraphQLInterfaceType, - GraphQLList, - GraphQLNonNull, - GraphQLObjectType, - GraphQLScalarType, - GraphQLUnionType, -) +from .definition import (GraphQLArgument, GraphQLEnumType, GraphQLEnumValue, + GraphQLField, GraphQLInputObjectType, + GraphQLInterfaceType, GraphQLList, GraphQLNonNull, + GraphQLObjectType, GraphQLScalarType, + GraphQLUnionType) from .scalars import GraphQLBoolean, GraphQLString __Schema = GraphQLObjectType( diff --git a/graphql/core/type/scalars.py b/graphql/core/type/scalars.py index 511c55c8..83a6009e 100644 --- a/graphql/core/type/scalars.py +++ b/graphql/core/type/scalars.py @@ -1,10 +1,6 @@ from six import text_type -from ..language.ast import ( - BooleanValue, - FloatValue, - IntValue, - StringValue, -) + +from ..language.ast import BooleanValue, FloatValue, IntValue, StringValue from .definition import GraphQLScalarType # Integers are only safe when between -(2^53 - 1) and 2^53 - 1 due to being diff --git a/graphql/core/type/schema.py b/graphql/core/type/schema.py index fe7e3ec8..d3045766 100644 --- a/graphql/core/type/schema.py +++ b/graphql/core/type/schema.py @@ -1,15 +1,12 @@ from collections import OrderedDict -from .definition import ( - GraphQLInputObjectType, - GraphQLInterfaceType, - GraphQLList, - GraphQLNonNull, - GraphQLObjectType, - GraphQLUnionType, -) -from .directives import GraphQLDirective, GraphQLIncludeDirective, GraphQLSkipDirective -from .introspection import IntrospectionSchema + from ..utils.type_comparators import is_equal_type, is_type_sub_type_of +from .definition import (GraphQLInputObjectType, GraphQLInterfaceType, + GraphQLList, GraphQLNonNull, GraphQLObjectType, + GraphQLUnionType) +from .directives import (GraphQLDirective, GraphQLIncludeDirective, + GraphQLSkipDirective) +from .introspection import IntrospectionSchema class GraphQLSchema(object): @@ -51,7 +48,7 @@ def __init__(self, query, mutation=None, subscription=None, directives=None): assert all(isinstance(d, GraphQLDirective) for d in directives), \ 'Schema directives must be List[GraphQLDirective] if provided but got: {}.'.format( directives - ) + ) self._directives = directives diff --git a/graphql/core/utils/ast_from_value.py b/graphql/core/utils/ast_from_value.py index 77b2ac25..b7c0c6d0 100644 --- a/graphql/core/utils/ast_from_value.py +++ b/graphql/core/utils/ast_from_value.py @@ -3,13 +3,10 @@ import sys from six import string_types + from ..language import ast -from ..type.definition import ( - GraphQLEnumType, - GraphQLInputObjectType, - GraphQLList, - GraphQLNonNull, -) +from ..type.definition import (GraphQLEnumType, GraphQLInputObjectType, + GraphQLList, GraphQLNonNull) from ..type.scalars import GraphQLFloat diff --git a/graphql/core/utils/ast_to_code.py b/graphql/core/utils/ast_to_code.py index 8923fc40..8c307eea 100644 --- a/graphql/core/utils/ast_to_code.py +++ b/graphql/core/utils/ast_to_code.py @@ -7,7 +7,9 @@ def ast_to_code(ast, indent=0): Converts an ast into a python code representation of the AST. """ code = [] - append = lambda line: code.append((' ' * indent) + line) + + def append(line): + code.append((' ' * indent) + line) if isinstance(ast, Node): append('ast.{}('.format(ast.__class__.__name__)) diff --git a/graphql/core/utils/build_ast_schema.py b/graphql/core/utils/build_ast_schema.py index 48e3d612..f68684a9 100644 --- a/graphql/core/utils/build_ast_schema.py +++ b/graphql/core/utils/build_ast_schema.py @@ -1,26 +1,12 @@ from collections import OrderedDict from ..language import ast -from ..type import ( - GraphQLArgument, - GraphQLBoolean, - GraphQLEnumType, - GraphQLEnumValue, - GraphQLField, - GraphQLFloat, - GraphQLID, - GraphQLInputObjectField, - GraphQLInputObjectType, - GraphQLInt, - GraphQLInterfaceType, - GraphQLList, - GraphQLNonNull, - GraphQLObjectType, - GraphQLScalarType, - GraphQLSchema, - GraphQLString, - GraphQLUnionType, -) +from ..type import (GraphQLArgument, GraphQLBoolean, GraphQLEnumType, + GraphQLEnumValue, GraphQLField, GraphQLFloat, GraphQLID, + GraphQLInputObjectField, GraphQLInputObjectType, + GraphQLInt, GraphQLInterfaceType, GraphQLList, + GraphQLNonNull, GraphQLObjectType, GraphQLScalarType, + GraphQLSchema, GraphQLString, GraphQLUnionType) from ..utils.value_from_ast import value_from_ast @@ -41,8 +27,10 @@ def _get_inner_type_name(type_ast): return type_ast.name.value -_false = lambda *_: False -_none = lambda *_: None +def _false(*_): return False + + +def _none(*_): return None def build_ast_schema(document, query_type_name, mutation_type_name=None, subscription_type_name=None): diff --git a/graphql/core/utils/build_client_schema.py b/graphql/core/utils/build_client_schema.py index 3a3770e1..649e0ca0 100644 --- a/graphql/core/utils/build_client_schema.py +++ b/graphql/core/utils/build_client_schema.py @@ -1,33 +1,22 @@ from collections import OrderedDict + from ..language.parser import parse_value -from ..type import ( - GraphQLArgument, - GraphQLBoolean, - GraphQLEnumType, - GraphQLEnumValue, - GraphQLField, - GraphQLFloat, - GraphQLID, - GraphQLInputObjectField, - GraphQLInputObjectType, - GraphQLInt, - GraphQLInterfaceType, - GraphQLList, - GraphQLNonNull, - GraphQLObjectType, - GraphQLScalarType, - GraphQLSchema, - GraphQLString, - GraphQLUnionType, - is_input_type, - is_output_type -) +from ..type import (GraphQLArgument, GraphQLBoolean, GraphQLEnumType, + GraphQLEnumValue, GraphQLField, GraphQLFloat, GraphQLID, + GraphQLInputObjectField, GraphQLInputObjectType, + GraphQLInt, GraphQLInterfaceType, GraphQLList, + GraphQLNonNull, GraphQLObjectType, GraphQLScalarType, + GraphQLSchema, GraphQLString, GraphQLUnionType, + is_input_type, is_output_type) from ..type.directives import GraphQLDirective from ..type.introspection import TypeKind from .value_from_ast import value_from_ast -_none = lambda *_: None -_false = lambda *_: False + +def _false(*_): return False + + +def _none(*_): return None def no_execution(*args): @@ -221,7 +210,8 @@ def build_directive(directive_introspection): get_named_type(type_introspection_name) query_type = get_object_type(schema_introspection['queryType']) - mutation_type = get_object_type(schema_introspection['mutationType']) if schema_introspection.get('mutationType') else None + mutation_type = get_object_type( + schema_introspection['mutationType']) if schema_introspection.get('mutationType') else None subscription_type = get_object_type(schema_introspection['subscriptionType']) if \ schema_introspection.get('subscriptionType') else None diff --git a/graphql/core/utils/extend_schema.py b/graphql/core/utils/extend_schema.py index 2d75fde2..c8aa42f2 100644 --- a/graphql/core/utils/extend_schema.py +++ b/graphql/core/utils/extend_schema.py @@ -1,8 +1,7 @@ from collections import OrderedDict, defaultdict -from graphql.core.language import ast - from ..error import GraphQLError +from ..language import ast from ..type.definition import (GraphQLArgument, GraphQLEnumType, GraphQLEnumValue, GraphQLField, GraphQLInputObjectField, GraphQLInputObjectType, diff --git a/graphql/core/utils/get_field_def.py b/graphql/core/utils/get_field_def.py index 16d9e061..4250a558 100644 --- a/graphql/core/utils/get_field_def.py +++ b/graphql/core/utils/get_field_def.py @@ -1,9 +1,7 @@ -from ..type.definition import ( - GraphQLInterfaceType, - GraphQLObjectType, - GraphQLUnionType, -) -from ..type.introspection import SchemaMetaFieldDef, TypeMetaFieldDef, TypeNameMetaFieldDef +from ..type.definition import (GraphQLInterfaceType, GraphQLObjectType, + GraphQLUnionType) +from ..type.introspection import (SchemaMetaFieldDef, TypeMetaFieldDef, + TypeNameMetaFieldDef) def get_field_def(schema, parent_type, field_ast): diff --git a/graphql/core/utils/is_valid_literal_value.py b/graphql/core/utils/is_valid_literal_value.py index 1138244e..52c9a88b 100644 --- a/graphql/core/utils/is_valid_literal_value.py +++ b/graphql/core/utils/is_valid_literal_value.py @@ -1,12 +1,7 @@ from ..language import ast from ..language.printer import print_ast -from ..type.definition import ( - GraphQLEnumType, - GraphQLInputObjectType, - GraphQLList, - GraphQLNonNull, - GraphQLScalarType, -) +from ..type.definition import (GraphQLEnumType, GraphQLInputObjectType, + GraphQLList, GraphQLNonNull, GraphQLScalarType) _empty_list = [] @@ -52,8 +47,10 @@ def is_valid_literal_value(type, value_ast): errors.append(u'In field "{}": Unknown field.'.format(provided_field_ast.name.value)) field_ast_map = {field_ast.name.value: field_ast for field_ast in field_asts} - get_field_ast_value = lambda field_name: field_ast_map[field_name].value \ - if field_name in field_ast_map else None + + def get_field_ast_value(field_name): + if field_name in field_ast_map: + return field_ast_map[field_name].value for field_name, field in fields.items(): subfield_errors = is_valid_literal_value(field.type, get_field_ast_value(field_name)) diff --git a/graphql/core/utils/is_valid_value.py b/graphql/core/utils/is_valid_value.py index bdb86326..07c02553 100644 --- a/graphql/core/utils/is_valid_value.py +++ b/graphql/core/utils/is_valid_value.py @@ -4,14 +4,11 @@ import collections import json + from six import string_types -from ..type import ( - GraphQLEnumType, - GraphQLInputObjectType, - GraphQLList, - GraphQLNonNull, - GraphQLScalarType, -) + +from ..type import (GraphQLEnumType, GraphQLInputObjectType, GraphQLList, + GraphQLNonNull, GraphQLScalarType) _empty_list = [] diff --git a/graphql/core/utils/schema_printer.py b/graphql/core/utils/schema_printer.py index 53ecf469..a6b22702 100644 --- a/graphql/core/utils/schema_printer.py +++ b/graphql/core/utils/schema_printer.py @@ -1,12 +1,7 @@ from ..language.printer import print_ast -from ..type.definition import ( - GraphQLEnumType, - GraphQLInputObjectType, - GraphQLInterfaceType, - GraphQLObjectType, - GraphQLScalarType, - GraphQLUnionType -) +from ..type.definition import (GraphQLEnumType, GraphQLInputObjectType, + GraphQLInterfaceType, GraphQLObjectType, + GraphQLScalarType, GraphQLUnionType) from .ast_from_value import ast_from_value diff --git a/graphql/core/utils/type_comparators.py b/graphql/core/utils/type_comparators.py index 570fa77b..9d7a55f7 100644 --- a/graphql/core/utils/type_comparators.py +++ b/graphql/core/utils/type_comparators.py @@ -1,9 +1,5 @@ -from ..type.definition import ( - is_abstract_type, - GraphQLObjectType, - GraphQLList, - GraphQLNonNull, -) +from ..type.definition import (GraphQLList, GraphQLNonNull, GraphQLObjectType, + is_abstract_type) def is_equal_type(type_a, type_b): @@ -37,7 +33,8 @@ def is_type_sub_type_of(maybe_subtype, super_type): elif isinstance(maybe_subtype, GraphQLList): return False - if is_abstract_type(super_type) and isinstance(maybe_subtype, GraphQLObjectType) and super_type.is_possible_type(maybe_subtype): + if is_abstract_type(super_type) and isinstance(maybe_subtype, + GraphQLObjectType) and super_type.is_possible_type(maybe_subtype): return True return False diff --git a/graphql/core/utils/type_from_ast.py b/graphql/core/utils/type_from_ast.py index e52c907e..8689f27a 100644 --- a/graphql/core/utils/type_from_ast.py +++ b/graphql/core/utils/type_from_ast.py @@ -1,8 +1,5 @@ from ..language import ast -from ..type.definition import ( - GraphQLList, - GraphQLNonNull, -) +from ..type.definition import GraphQLList, GraphQLNonNull def type_from_ast(schema, input_type_ast): diff --git a/graphql/core/utils/type_info.py b/graphql/core/utils/type_info.py index 0f68ef09..b701453d 100644 --- a/graphql/core/utils/type_info.py +++ b/graphql/core/utils/type_info.py @@ -1,14 +1,9 @@ import six from ..language import visitor_meta -from ..type.definition import ( - GraphQLInputObjectType, - GraphQLList, - get_named_type, - get_nullable_type, - is_composite_type, -) - +from ..type.definition import (GraphQLInputObjectType, GraphQLList, + get_named_type, get_nullable_type, + is_composite_type) from .get_field_def import get_field_def from .type_from_ast import type_from_ast diff --git a/graphql/core/utils/value_from_ast.py b/graphql/core/utils/value_from_ast.py index b3e240db..6ae89752 100644 --- a/graphql/core/utils/value_from_ast.py +++ b/graphql/core/utils/value_from_ast.py @@ -1,5 +1,6 @@ from ..language import ast -from ..type import (GraphQLEnumType, GraphQLInputObjectType, GraphQLList, GraphQLNonNull, GraphQLScalarType) +from ..type import (GraphQLEnumType, GraphQLInputObjectType, GraphQLList, + GraphQLNonNull, GraphQLScalarType) def value_from_ast(value_ast, type, variables=None): diff --git a/graphql/core/validation/context.py b/graphql/core/validation/context.py index eb507b90..8228b437 100644 --- a/graphql/core/validation/context.py +++ b/graphql/core/validation/context.py @@ -1,5 +1,5 @@ -from ..language.ast import FragmentDefinition, FragmentSpread, VariableDefinition, Variable, OperationDefinition -from ..utils.type_info import TypeInfo +from ..language.ast import (FragmentDefinition, FragmentSpread, + OperationDefinition) from ..language.visitor import Visitor, visit from .visitor import TypeInfoVisitor @@ -28,7 +28,8 @@ def enter_Variable(self, node, key, parent, path, ancestors): class ValidationContext(object): - __slots__ = '_schema', '_ast', '_type_info', '_errors', '_fragments', '_fragment_spreads', '_recursively_referenced_fragments', '_variable_usages', '_recursive_variable_usages' + __slots__ = ('_schema', '_ast', '_type_info', '_errors', '_fragments', '_fragment_spreads', + '_recursively_referenced_fragments', '_variable_usages', '_recursive_variable_usages') def __init__(self, schema, ast, type_info): self._schema = schema diff --git a/graphql/core/validation/rules/arguments_of_correct_type.py b/graphql/core/validation/rules/arguments_of_correct_type.py index bc0f2dfe..011fae79 100644 --- a/graphql/core/validation/rules/arguments_of_correct_type.py +++ b/graphql/core/validation/rules/arguments_of_correct_type.py @@ -5,6 +5,7 @@ class ArgumentsOfCorrectType(ValidationRule): + def enter_Argument(self, node, key, parent, path, ancestors): arg_def = self.context.get_argument() if arg_def: diff --git a/graphql/core/validation/rules/default_values_of_correct_type.py b/graphql/core/validation/rules/default_values_of_correct_type.py index 98568362..ad6346b4 100644 --- a/graphql/core/validation/rules/default_values_of_correct_type.py +++ b/graphql/core/validation/rules/default_values_of_correct_type.py @@ -6,6 +6,7 @@ class DefaultValuesOfCorrectType(ValidationRule): + def enter_VariableDefinition(self, node, key, parent, path, ancestors): name = node.variable.name.value default_value = node.default_value diff --git a/graphql/core/validation/rules/fields_on_correct_type.py b/graphql/core/validation/rules/fields_on_correct_type.py index 5d81f45a..1ec8ea8d 100644 --- a/graphql/core/validation/rules/fields_on_correct_type.py +++ b/graphql/core/validation/rules/fields_on_correct_type.py @@ -3,6 +3,7 @@ class FieldsOnCorrectType(ValidationRule): + def enter_Field(self, node, key, parent, path, ancestors): type = self.context.get_parent_type() if not type: diff --git a/graphql/core/validation/rules/fragments_on_composite_types.py b/graphql/core/validation/rules/fragments_on_composite_types.py index fa819c3b..a95e247c 100644 --- a/graphql/core/validation/rules/fragments_on_composite_types.py +++ b/graphql/core/validation/rules/fragments_on_composite_types.py @@ -5,6 +5,7 @@ class FragmentsOnCompositeTypes(ValidationRule): + def enter_InlineFragment(self, node, key, parent, path, ancestors): type = self.context.get_type() diff --git a/graphql/core/validation/rules/known_argument_names.py b/graphql/core/validation/rules/known_argument_names.py index c899fe12..41f797a7 100644 --- a/graphql/core/validation/rules/known_argument_names.py +++ b/graphql/core/validation/rules/known_argument_names.py @@ -4,6 +4,7 @@ class KnownArgumentNames(ValidationRule): + def enter_Argument(self, node, key, parent, path, ancestors): argument_of = ancestors[-1] diff --git a/graphql/core/validation/rules/known_directives.py b/graphql/core/validation/rules/known_directives.py index 4af444ed..48c87cea 100644 --- a/graphql/core/validation/rules/known_directives.py +++ b/graphql/core/validation/rules/known_directives.py @@ -4,6 +4,7 @@ class KnownDirectives(ValidationRule): + def enter_Directive(self, node, key, parent, path, ancestors): directive_def = next(( definition for definition in self.context.get_schema().get_directives() diff --git a/graphql/core/validation/rules/known_fragment_names.py b/graphql/core/validation/rules/known_fragment_names.py index 5e7d35d8..6c7375e3 100644 --- a/graphql/core/validation/rules/known_fragment_names.py +++ b/graphql/core/validation/rules/known_fragment_names.py @@ -3,6 +3,7 @@ class KnownFragmentNames(ValidationRule): + def enter_FragmentSpread(self, node, key, parent, path, ancestors): fragment_name = node.name.value fragment = self.context.get_fragment(fragment_name) diff --git a/graphql/core/validation/rules/known_type_names.py b/graphql/core/validation/rules/known_type_names.py index 66138dbe..d34e0970 100644 --- a/graphql/core/validation/rules/known_type_names.py +++ b/graphql/core/validation/rules/known_type_names.py @@ -3,6 +3,7 @@ class KnownTypeNames(ValidationRule): + def enter_NamedType(self, node, *args): type_name = node.name.value type = self.context.get_schema().get_type(type_name) diff --git a/graphql/core/validation/rules/no_fragment_cycles.py b/graphql/core/validation/rules/no_fragment_cycles.py index 94e752c2..33a49f04 100644 --- a/graphql/core/validation/rules/no_fragment_cycles.py +++ b/graphql/core/validation/rules/no_fragment_cycles.py @@ -1,6 +1,4 @@ from ...error import GraphQLError -from ...language import ast -from ...language.visitor import Visitor, visit from .base import ValidationRule @@ -50,7 +48,7 @@ def detect_cycle_recursive(self, fragment): spread_name, [s.name.value for s in cycle_path] ), - cycle_path+[spread_node] + cycle_path + [spread_node] )) self.spread_path_index_by_name[fragment_name] = None diff --git a/graphql/core/validation/rules/no_undefined_variables.py b/graphql/core/validation/rules/no_undefined_variables.py index 75493a0f..81c9c495 100644 --- a/graphql/core/validation/rules/no_undefined_variables.py +++ b/graphql/core/validation/rules/no_undefined_variables.py @@ -1,5 +1,4 @@ from ...error import GraphQLError -from ...language import ast from .base import ValidationRule diff --git a/graphql/core/validation/rules/overlapping_fields_can_be_merged.py b/graphql/core/validation/rules/overlapping_fields_can_be_merged.py index 735b3035..6bae06a6 100644 --- a/graphql/core/validation/rules/overlapping_fields_can_be_merged.py +++ b/graphql/core/validation/rules/overlapping_fields_can_be_merged.py @@ -5,13 +5,10 @@ from ...language.printer import print_ast from ...pyutils.default_ordered_dict import DefaultOrderedDict from ...pyutils.pair_set import PairSet -from ...type.definition import ( - GraphQLInterfaceType, - GraphQLObjectType, - get_named_type, -) -from ...utils.type_from_ast import type_from_ast +from ...type.definition import (GraphQLInterfaceType, GraphQLObjectType, + get_named_type) from ...utils.type_comparators import is_equal_type +from ...utils.type_from_ast import type_from_ast from .base import ValidationRule @@ -53,8 +50,8 @@ def find_conflict(self, response_name, field1, field2): # in the current state of the schema, then perhaps in some future version, # thus may not safely diverge. if parent_type1 != parent_type2 and \ - isinstance(parent_type1, GraphQLObjectType) and \ - isinstance(parent_type2, GraphQLObjectType): + isinstance(parent_type1, GraphQLObjectType) and \ + isinstance(parent_type2, GraphQLObjectType): return if self.compared_set.has(ast1, ast2): @@ -125,7 +122,13 @@ def leave_SelectionSet(self, node, key, parent, path, ancestors): conflicts = self.find_conflicts(field_map) if conflicts: for (reason_name, reason), fields1, fields2 in conflicts: - self.context.report_error(GraphQLError(self.fields_conflict_message(reason_name, reason), list(fields1)+list(fields2))) + self.context.report_error( + GraphQLError( + self.fields_conflict_message( + reason_name, + reason), + list(fields1) + + list(fields2))) @staticmethod def same_type(type1, type2): diff --git a/graphql/core/validation/rules/possible_fragment_spreads.py b/graphql/core/validation/rules/possible_fragment_spreads.py index 49fd899f..a44e3143 100644 --- a/graphql/core/validation/rules/possible_fragment_spreads.py +++ b/graphql/core/validation/rules/possible_fragment_spreads.py @@ -1,14 +1,12 @@ from ...error import GraphQLError -from ...type.definition import ( - GraphQLInterfaceType, - GraphQLObjectType, - GraphQLUnionType, -) +from ...type.definition import (GraphQLInterfaceType, GraphQLObjectType, + GraphQLUnionType) from ...utils.type_from_ast import type_from_ast from .base import ValidationRule class PossibleFragmentSpreads(ValidationRule): + def enter_InlineFragment(self, node, key, parent, path, ancestors): frag_type = self.context.get_type() parent_type = self.context.get_parent_type() diff --git a/graphql/core/validation/rules/provided_non_null_arguments.py b/graphql/core/validation/rules/provided_non_null_arguments.py index d67d75e5..2d41a5bd 100644 --- a/graphql/core/validation/rules/provided_non_null_arguments.py +++ b/graphql/core/validation/rules/provided_non_null_arguments.py @@ -4,6 +4,7 @@ class ProvidedNonNullArguments(ValidationRule): + def leave_Field(self, node, key, parent, path, ancestors): field_def = self.context.get_field_def() if not field_def: diff --git a/graphql/core/validation/rules/scalar_leafs.py b/graphql/core/validation/rules/scalar_leafs.py index d4794c2e..531a1d23 100644 --- a/graphql/core/validation/rules/scalar_leafs.py +++ b/graphql/core/validation/rules/scalar_leafs.py @@ -4,6 +4,7 @@ class ScalarLeafs(ValidationRule): + def enter_Field(self, node, key, parent, path, ancestors): type = self.context.get_type() diff --git a/graphql/core/validation/rules/variables_are_input_types.py b/graphql/core/validation/rules/variables_are_input_types.py index 4ea29022..f510fbba 100644 --- a/graphql/core/validation/rules/variables_are_input_types.py +++ b/graphql/core/validation/rules/variables_are_input_types.py @@ -6,6 +6,7 @@ class VariablesAreInputTypes(ValidationRule): + def enter_VariableDefinition(self, node, key, parent, path, ancestors): type = type_from_ast(self.context.get_schema(), node.type) diff --git a/graphql/core/validation/rules/variables_in_allowed_position.py b/graphql/core/validation/rules/variables_in_allowed_position.py index 11e45580..67d9c8ba 100644 --- a/graphql/core/validation/rules/variables_in_allowed_position.py +++ b/graphql/core/validation/rules/variables_in_allowed_position.py @@ -1,9 +1,7 @@ from ...error import GraphQLError -from ...type.definition import ( - GraphQLNonNull -) -from ...utils.type_from_ast import type_from_ast +from ...type.definition import GraphQLNonNull from ...utils.type_comparators import is_type_sub_type_of +from ...utils.type_from_ast import type_from_ast from .base import ValidationRule diff --git a/graphql/core/validation/visitor.py b/graphql/core/validation/visitor.py index 70f6f630..ea3351fb 100644 --- a/graphql/core/validation/visitor.py +++ b/graphql/core/validation/visitor.py @@ -25,7 +25,7 @@ class ParallelVisitor(Visitor): def __init__(self, visitors): self.visitors = visitors - self.skipping = [None]*len(visitors) + self.skipping = [None] * len(visitors) def enter(self, node, key, parent, path, ancestors): for i, visitor in enumerate(self.visitors): diff --git a/tests/core_execution/test_abstract.py b/tests/core_execution/test_abstract.py index bad9520b..9bf9021b 100644 --- a/tests/core_execution/test_abstract.py +++ b/tests/core_execution/test_abstract.py @@ -1,22 +1,26 @@ from graphql.core import graphql -from graphql.core.type import GraphQLString, GraphQLBoolean, GraphQLSchema -from graphql.core.type.definition import GraphQLInterfaceType, GraphQLField, GraphQLObjectType, GraphQLList, \ - GraphQLUnionType +from graphql.core.type import GraphQLBoolean, GraphQLSchema, GraphQLString +from graphql.core.type.definition import (GraphQLField, GraphQLInterfaceType, + GraphQLList, GraphQLObjectType, + GraphQLUnionType) class Dog(object): + def __init__(self, name, woofs): self.name = name self.woofs = woofs class Cat(object): + def __init__(self, name, meows): self.name = name self.meows = meows class Human(object): + def __init__(self, name): self.name = name diff --git a/tests/core_execution/test_concurrent_executor.py b/tests/core_execution/test_concurrent_executor.py index 49eea030..d62dc559 100644 --- a/tests/core_execution/test_concurrent_executor.py +++ b/tests/core_execution/test_concurrent_executor.py @@ -1,11 +1,15 @@ from collections import OrderedDict + from graphql.core.error import format_error from graphql.core.execution import Executor -from graphql.core.execution.middlewares.sync import SynchronousExecutionMiddleware -from graphql.core.pyutils.defer import succeed, Deferred, fail -from graphql.core.type import (GraphQLSchema, GraphQLObjectType, GraphQLField, - GraphQLArgument, GraphQLList, GraphQLInt, GraphQLString) +from graphql.core.execution.middlewares.sync import \ + SynchronousExecutionMiddleware +from graphql.core.pyutils.defer import Deferred, fail, succeed +from graphql.core.type import (GraphQLArgument, GraphQLField, GraphQLInt, + GraphQLList, GraphQLObjectType, GraphQLSchema, + GraphQLString) from graphql.core.type.definition import GraphQLNonNull + from .utils import raise_callback_results @@ -123,6 +127,7 @@ def handle_result(result): def test_synchronous_executor_doesnt_support_defers_with_nullable_type_getting_set_to_null(): class Data(object): + def promise(self): return succeed('i shouldn\'nt work') @@ -153,6 +158,7 @@ def notPromise(self): def test_synchronous_executor_doesnt_support_defers(): class Data(object): + def promise(self): return succeed('i shouldn\'nt work') @@ -183,6 +189,7 @@ def notPromise(self): def test_executor_defer_failure(): class Data(object): + def promise(self): return fail(Exception('Something bad happened! Sucks :(')) @@ -213,6 +220,7 @@ def notPromise(self): def test_synchronous_executor_will_synchronously_resolve(): class Data(object): + def promise(self): return 'I should work' @@ -248,6 +256,7 @@ def test_synchronous_error_nulls_out_error_subtrees(): ''' class Data: + def sync(self): return 'sync' diff --git a/tests/core_execution/test_default_executor.py b/tests/core_execution/test_default_executor.py index 7cc3edf7..20181f6d 100644 --- a/tests/core_execution/test_default_executor.py +++ b/tests/core_execution/test_default_executor.py @@ -1,4 +1,5 @@ -from graphql.core.execution import get_default_executor, set_default_executor, Executor +from graphql.core.execution import (Executor, get_default_executor, + set_default_executor) def test_get_and_set_default_executor(): diff --git a/tests/core_execution/test_deferred.py b/tests/core_execution/test_deferred.py index abb34244..25032286 100644 --- a/tests/core_execution/test_deferred.py +++ b/tests/core_execution/test_deferred.py @@ -1,7 +1,8 @@ from pytest import raises -from graphql.core.pyutils.defer import Deferred, DeferredException, succeed, fail, DeferredList, DeferredDict, \ - AlreadyCalledDeferred +from graphql.core.pyutils.defer import (AlreadyCalledDeferred, Deferred, + DeferredDict, DeferredException, + DeferredList, fail, succeed) def test_succeed(): @@ -274,4 +275,4 @@ def dummy_errback(deferred_exception): deferred = Deferred() deferred.add_errback(dummy_errback) deferred.errback(OSError()) - assert deferred.result == 'caught' \ No newline at end of file + assert deferred.result == 'caught' diff --git a/tests/core_execution/test_directives.py b/tests/core_execution/test_directives.py index 1644030c..dac626bb 100644 --- a/tests/core_execution/test_directives.py +++ b/tests/core_execution/test_directives.py @@ -1,6 +1,7 @@ from graphql.core.execution import execute from graphql.core.language.parser import parse -from graphql.core.type import GraphQLSchema, GraphQLObjectType, GraphQLField, GraphQLString +from graphql.core.type import (GraphQLField, GraphQLObjectType, GraphQLSchema, + GraphQLString) schema = GraphQLSchema( query=GraphQLObjectType( diff --git a/tests/core_execution/test_executor.py b/tests/core_execution/test_executor.py index 191e4d9a..2e252694 100644 --- a/tests/core_execution/test_executor.py +++ b/tests/core_execution/test_executor.py @@ -1,13 +1,16 @@ -from collections import OrderedDict import json +from collections import OrderedDict + from pytest import raises -from graphql.core.execution import execute, Executor -from graphql.core.execution.middlewares.sync import SynchronousExecutionMiddleware -from graphql.core.language.parser import parse -from graphql.core.type import (GraphQLSchema, GraphQLObjectType, GraphQLField, - GraphQLArgument, GraphQLList, GraphQLInt, GraphQLString, - GraphQLBoolean) + from graphql.core.error import GraphQLError +from graphql.core.execution import Executor, execute +from graphql.core.execution.middlewares.sync import \ + SynchronousExecutionMiddleware +from graphql.core.language.parser import parse +from graphql.core.type import (GraphQLArgument, GraphQLBoolean, GraphQLField, + GraphQLInt, GraphQLList, GraphQLObjectType, + GraphQLSchema, GraphQLString) def test_executes_arbitary_code(): @@ -145,17 +148,17 @@ def test_merges_parallel_fragments(): result = execute(schema, None, ast) assert not result.errors assert result.data == \ - { - 'a': 'Apple', - 'b': 'Banana', - 'c': 'Cherry', - 'deep': { - 'b': 'Banana', - 'c': 'Cherry', - 'deeper': { - 'b': 'Banana', - 'c': 'Cherry'}} - } + { + 'a': 'Apple', + 'b': 'Banana', + 'c': 'Cherry', + 'deep': { + 'b': 'Banana', + 'c': 'Cherry', + 'deeper': { + 'b': 'Banana', + 'c': 'Cherry'}} + } def test_threads_context_correctly(): @@ -219,6 +222,7 @@ def test_nulls_out_error_subtrees(): }''' class Data(object): + def ok(self): return 'ok' @@ -410,10 +414,12 @@ def test_does_not_include_arguments_that_were_not_set(): def test_fails_when_an_is_type_of_check_is_not_met(): class Special(object): + def __init__(self, value): self.value = value class NotSpecial(object): + def __init__(self, value): self.value = value @@ -471,7 +477,7 @@ def test_fails_to_execute_a_query_containing_a_type_definition(): ) with raises(GraphQLError) as excinfo: - result = execute(schema, None, query) + execute(schema, None, query) assert excinfo.value.message == 'GraphQL cannot execute a request containing a ObjectTypeDefinition.' diff --git a/tests/core_execution/test_executor_schema.py b/tests/core_execution/test_executor_schema.py index 2f1987f8..cc5def57 100644 --- a/tests/core_execution/test_executor_schema.py +++ b/tests/core_execution/test_executor_schema.py @@ -1,17 +1,10 @@ from graphql.core.execution import execute from graphql.core.language.parser import parse -from graphql.core.type import ( - GraphQLSchema, - GraphQLObjectType, - GraphQLField, - GraphQLArgument, - GraphQLList, - GraphQLNonNull, - GraphQLInt, - GraphQLString, - GraphQLBoolean, - GraphQLID, -) +from graphql.core.type import (GraphQLArgument, GraphQLBoolean, GraphQLField, + GraphQLID, GraphQLInt, GraphQLList, + GraphQLNonNull, GraphQLObjectType, + GraphQLSchema, GraphQLString) + def test_executes_using_a_schema(): BlogImage = GraphQLObjectType('BlogImage', { @@ -24,13 +17,13 @@ def test_executes_using_a_schema(): 'id': GraphQLField(GraphQLString), 'name': GraphQLField(GraphQLString), 'pic': GraphQLField(BlogImage, - args={ - 'width': GraphQLArgument(GraphQLInt), - 'height': GraphQLArgument(GraphQLInt), - }, - resolver=lambda obj, args, *_: - obj.pic(args['width'], args['height']) - ), + args={ + 'width': GraphQLArgument(GraphQLInt), + 'height': GraphQLArgument(GraphQLInt), + }, + resolver=lambda obj, args, *_: + obj.pic(args['width'], args['height']) + ), 'recentArticle': GraphQLField(BlogArticle), }) @@ -56,6 +49,7 @@ def test_executes_using_a_schema(): BlogSchema = GraphQLSchema(BlogQuery) class Article(object): + def __init__(self, id): self.id = id self.isPublished = True @@ -68,12 +62,15 @@ def __init__(self, id): class Author(object): id = 123 name = 'John Smith' + def pic(self, width, height): return Pic(123, width, height) + @property def recentArticle(self): return Article(1) class Pic(object): + def __init__(self, uid, width, height): self.url = 'cdn://{}'.format(uid) self.width = str(width) diff --git a/tests/core_execution/test_gevent.py b/tests/core_execution/test_gevent.py index 7527a0f1..976a54f4 100644 --- a/tests/core_execution/test_gevent.py +++ b/tests/core_execution/test_gevent.py @@ -1,16 +1,13 @@ # flake8: noqa +import gevent + from graphql.core.error import format_error from graphql.core.execution import Executor -from graphql.core.execution.middlewares.gevent import GeventExecutionMiddleware, run_in_greenlet +from graphql.core.execution.middlewares.gevent import (GeventExecutionMiddleware, + run_in_greenlet) from graphql.core.language.location import SourceLocation -from graphql.core.type import ( - GraphQLSchema, - GraphQLObjectType, - GraphQLField, - GraphQLString -) - -import gevent +from graphql.core.type import (GraphQLField, GraphQLObjectType, GraphQLSchema, + GraphQLString) def test_gevent_executor(): diff --git a/tests/core_execution/test_lists.py b/tests/core_execution/test_lists.py index 29a7c2ba..cc3f8ab1 100644 --- a/tests/core_execution/test_lists.py +++ b/tests/core_execution/test_lists.py @@ -1,16 +1,12 @@ from collections import namedtuple + from graphql.core.error import format_error -from graphql.core.execution import execute, Executor +from graphql.core.execution import Executor, execute from graphql.core.language.parser import parse -from graphql.core.pyutils.defer import succeed, fail -from graphql.core.type import ( - GraphQLSchema, - GraphQLObjectType, - GraphQLField, - GraphQLInt, - GraphQLList, - GraphQLNonNull, -) +from graphql.core.pyutils.defer import fail, succeed +from graphql.core.type import (GraphQLField, GraphQLInt, GraphQLList, + GraphQLNonNull, GraphQLObjectType, + GraphQLSchema) Data = namedtuple('Data', 'test') ast = parse('{ nest { test } }') diff --git a/tests/core_execution/test_middleware.py b/tests/core_execution/test_middleware.py index 32872e60..806e0ed3 100644 --- a/tests/core_execution/test_middleware.py +++ b/tests/core_execution/test_middleware.py @@ -1,4 +1,6 @@ -from graphql.core.execution.middlewares.utils import tag_resolver, resolver_has_tag, merge_resolver_tags +from graphql.core.execution.middlewares.utils import (merge_resolver_tags, + resolver_has_tag, + tag_resolver) def test_tag_resolver(): diff --git a/tests/core_execution/test_mutations.py b/tests/core_execution/test_mutations.py index 05bdb623..df8080e4 100644 --- a/tests/core_execution/test_mutations.py +++ b/tests/core_execution/test_mutations.py @@ -1,15 +1,18 @@ from graphql.core.execution import execute from graphql.core.language.parser import parse -from graphql.core.type import (GraphQLSchema, GraphQLObjectType, GraphQLField, - GraphQLArgument, GraphQLList, GraphQLInt, GraphQLString) +from graphql.core.type import (GraphQLArgument, GraphQLField, GraphQLInt, + GraphQLList, GraphQLObjectType, GraphQLSchema, + GraphQLString) class NumberHolder(object): + def __init__(self, n): self.theNumber = n class Root(object): + def __init__(self, n): self.numberHolder = NumberHolder(n) @@ -27,7 +30,7 @@ def fail_to_change_the_number(self, n): def promise_and_fail_to_change_the_number(self, n): # TODO: async self.fail_to_change_the_number(n) - + NumberHolderType = GraphQLObjectType('NumberHolder', { 'theNumber': GraphQLField(GraphQLInt) diff --git a/tests/core_execution/test_nonnull.py b/tests/core_execution/test_nonnull.py index 2d55374f..2fc4b281 100644 --- a/tests/core_execution/test_nonnull.py +++ b/tests/core_execution/test_nonnull.py @@ -1,9 +1,11 @@ from collections import OrderedDict + from graphql.core.error import format_error +from graphql.core.execution import Executor, execute from graphql.core.language.parser import parse from graphql.core.pyutils.defer import fail, succeed -from graphql.core.type import GraphQLObjectType, GraphQLField, GraphQLString, GraphQLNonNull, GraphQLSchema -from graphql.core.execution import execute, Executor +from graphql.core.type import (GraphQLField, GraphQLNonNull, GraphQLObjectType, + GraphQLSchema, GraphQLString) sync_error = Exception('sync') non_null_sync_error = Exception('nonNullSync') @@ -14,6 +16,7 @@ class ThrowingData(object): + def sync(self): raise sync_error @@ -40,6 +43,7 @@ def nonNullPromiseNest(self): class NullingData(object): + def sync(self): return None diff --git a/tests/core_execution/test_union_interface.py b/tests/core_execution/test_union_interface.py index 3355893d..e8492c13 100644 --- a/tests/core_execution/test_union_interface.py +++ b/tests/core_execution/test_union_interface.py @@ -1,30 +1,27 @@ from graphql.core.execution import execute from graphql.core.language.parser import parse -from graphql.core.type import ( - GraphQLSchema, - GraphQLField, - GraphQLObjectType, - GraphQLInterfaceType, - GraphQLUnionType, - GraphQLList, - GraphQLString, - GraphQLBoolean -) +from graphql.core.type import (GraphQLBoolean, GraphQLField, + GraphQLInterfaceType, GraphQLList, + GraphQLObjectType, GraphQLSchema, GraphQLString, + GraphQLUnionType) class Dog(object): + def __init__(self, name, barks): self.name = name self.barks = barks class Cat(object): + def __init__(self, name, meows): self.name = name self.meows = meows class Person(object): + def __init__(self, name, pets, friends): self.name = name self.pets = pets diff --git a/tests/core_execution/test_variables.py b/tests/core_execution/test_variables.py index 4c70a8e2..98df6b10 100644 --- a/tests/core_execution/test_variables.py +++ b/tests/core_execution/test_variables.py @@ -1,20 +1,14 @@ -from pytest import raises import json + +from pytest import raises + +from graphql.core.error import GraphQLError, format_error from graphql.core.execution import execute from graphql.core.language.parser import parse -from graphql.core.type import ( - GraphQLSchema, - GraphQLObjectType, - GraphQLField, - GraphQLArgument, - GraphQLInputObjectField, - GraphQLInputObjectType, - GraphQLList, - GraphQLString, - GraphQLNonNull, - GraphQLScalarType, -) -from graphql.core.error import GraphQLError, format_error +from graphql.core.type import (GraphQLArgument, GraphQLField, + GraphQLInputObjectField, GraphQLInputObjectType, + GraphQLList, GraphQLNonNull, GraphQLObjectType, + GraphQLScalarType, GraphQLSchema, GraphQLString) TestComplexScalar = GraphQLScalarType( name='ComplexScalar', @@ -628,6 +622,7 @@ def test_does_not_allow_unknown_types_to_be_used_as_values(): # noinspection PyMethodMayBeStatic class TestUsesArgumentDefaultValues: + def test_when_no_argument_provided(self): check('{ fieldWithDefaultArgumentValue }', { 'data': { diff --git a/tests/core_execution/utils.py b/tests/core_execution/utils.py index 64580adf..b87d3471 100644 --- a/tests/core_execution/utils.py +++ b/tests/core_execution/utils.py @@ -1,7 +1,9 @@ -from graphql.core.pyutils.defer import Deferred, DeferredException, _passthrough +from graphql.core.pyutils.defer import (Deferred, DeferredException, + _passthrough) class RaisingDeferred(Deferred): + def _next(self): """Process the next callback.""" if self._running or self.paused: diff --git a/tests/core_language/fixtures.py b/tests/core_language/fixtures.py index 18c25daa..25c93fce 100644 --- a/tests/core_language/fixtures.py +++ b/tests/core_language/fixtures.py @@ -98,4 +98,4 @@ extend type Foo { seven(argument: [String]): Type } -""" \ No newline at end of file +""" diff --git a/tests/core_language/test_ast.py b/tests/core_language/test_ast.py index ae5481ff..67fd0ac7 100644 --- a/tests/core_language/test_ast.py +++ b/tests/core_language/test_ast.py @@ -1,5 +1,6 @@ import copy -from graphql.core.language.visitor_meta import VisitorMeta, QUERY_DOCUMENT_KEYS + +from graphql.core.language.visitor_meta import QUERY_DOCUMENT_KEYS, VisitorMeta def test_ast_is_hashable(): diff --git a/tests/core_language/test_lexer.py b/tests/core_language/test_lexer.py index fa9ec22a..37075f6e 100644 --- a/tests/core_language/test_lexer.py +++ b/tests/core_language/test_lexer.py @@ -1,7 +1,8 @@ from pytest import raises + from graphql.core.language.error import LanguageError -from graphql.core.language.source import Source from graphql.core.language.lexer import Lexer, Token, TokenKind +from graphql.core.language.source import Source def lex_one(s): diff --git a/tests/core_language/test_parser.py b/tests/core_language/test_parser.py index bef26377..6a41e6fd 100644 --- a/tests/core_language/test_parser.py +++ b/tests/core_language/test_parser.py @@ -1,9 +1,11 @@ from pytest import raises + +from graphql.core.language import ast from graphql.core.language.error import LanguageError from graphql.core.language.location import SourceLocation +from graphql.core.language.parser import Loc, parse from graphql.core.language.source import Source -from graphql.core.language.parser import parse, Loc -from graphql.core.language import ast + from .fixtures import KITCHEN_SINK @@ -16,12 +18,12 @@ def test_parse_provides_useful_errors(): with raises(LanguageError) as excinfo: parse("""{""") assert ( - u'Syntax Error GraphQL (1:2) Expected Name, found EOF\n' - u'\n' - u'1: {\n' - u' ^\n' - u'' - ) == excinfo.value.message + u'Syntax Error GraphQL (1:2) Expected Name, found EOF\n' + u'\n' + u'1: {\n' + u' ^\n' + u'' + ) == excinfo.value.message assert excinfo.value.positions == [1] assert excinfo.value.locations == [SourceLocation(line=1, column=2)] @@ -183,50 +185,47 @@ def test_parse_creates_ast(): result = parse(source) assert result == \ - ast.Document( - loc=Loc(start=0, end=41, source=source), - definitions= - [ast.OperationDefinition( - loc=Loc(start=0, end=40, source=source), - operation='query', - name=None, - variable_definitions=None, - directives=[], - selection_set=ast.SelectionSet( - loc=Loc(start=0, end=40, source=source), - selections= - [ast.Field( - loc=Loc(start=4, end=38, source=source), - alias=None, - name=ast.Name( - loc=Loc(start=4, end=8, source=source), - value='node'), - arguments=[ast.Argument( - name=ast.Name(loc=Loc(start=9, end=11, source=source), - value='id'), - value=ast.IntValue( + ast.Document( + loc=Loc(start=0, end=41, source=source), + definitions=[ast.OperationDefinition( + loc=Loc(start=0, end=40, source=source), + operation='query', + name=None, + variable_definitions=None, + directives=[], + selection_set=ast.SelectionSet( + loc=Loc(start=0, end=40, source=source), + selections=[ast.Field( + loc=Loc(start=4, end=38, source=source), + alias=None, + name=ast.Name( + loc=Loc(start=4, end=8, source=source), + value='node'), + arguments=[ast.Argument( + name=ast.Name(loc=Loc(start=9, end=11, source=source), + value='id'), + value=ast.IntValue( loc=Loc(start=13, end=14, source=source), value='4'), - loc=Loc(start=9, end=14, source=source))], - directives=[], - selection_set=ast.SelectionSet( - loc=Loc(start=16, end=38, source=source), - selections= - [ast.Field( - loc=Loc(start=22, end=24, source=source), - alias=None, - name=ast.Name( - loc=Loc(start=22, end=24, source=source), - value='id'), - arguments=[], - directives=[], - selection_set=None), - ast.Field( - loc=Loc(start=30, end=34, source=source), - alias=None, - name=ast.Name( - loc=Loc(start=30, end=34, source=source), - value='name'), - arguments=[], - directives=[], - selection_set=None)]))]))]) + loc=Loc(start=9, end=14, source=source))], + directives=[], + selection_set=ast.SelectionSet( + loc=Loc(start=16, end=38, source=source), + selections=[ast.Field( + loc=Loc(start=22, end=24, source=source), + alias=None, + name=ast.Name( + loc=Loc(start=22, end=24, source=source), + value='id'), + arguments=[], + directives=[], + selection_set=None), + ast.Field( + loc=Loc(start=30, end=34, source=source), + alias=None, + name=ast.Name( + loc=Loc(start=30, end=34, source=source), + value='name'), + arguments=[], + directives=[], + selection_set=None)]))]))]) diff --git a/tests/core_language/test_printer.py b/tests/core_language/test_printer.py index 90a396ab..47ffc441 100644 --- a/tests/core_language/test_printer.py +++ b/tests/core_language/test_printer.py @@ -1,8 +1,11 @@ import copy + +from pytest import raises + from graphql.core.language.ast import Field, Name from graphql.core.language.parser import parse from graphql.core.language.printer import print_ast -from pytest import raises + from .fixtures import KITCHEN_SINK diff --git a/tests/core_language/test_schema_parser.py b/tests/core_language/test_schema_parser.py index eb002b56..bfe7ac67 100644 --- a/tests/core_language/test_schema_parser.py +++ b/tests/core_language/test_schema_parser.py @@ -1,4 +1,5 @@ from pytest import raises + from graphql.core import Source, parse from graphql.core.language import ast from graphql.core.language.error import LanguageError diff --git a/tests/core_language/test_schema_printer.py b/tests/core_language/test_schema_printer.py index 4be5af20..9064386f 100644 --- a/tests/core_language/test_schema_printer.py +++ b/tests/core_language/test_schema_printer.py @@ -1,8 +1,11 @@ from copy import deepcopy -from graphql.core import parse + from pytest import raises -from graphql.core.language.printer import print_ast + +from graphql.core import parse from graphql.core.language import ast +from graphql.core.language.printer import print_ast + from .fixtures import SCHEMA_KITCHEN_SINK diff --git a/tests/core_language/test_visitor.py b/tests/core_language/test_visitor.py index d5326930..293d48a2 100644 --- a/tests/core_language/test_visitor.py +++ b/tests/core_language/test_visitor.py @@ -1,6 +1,7 @@ from graphql.core.language.ast import Field, Name, SelectionSet from graphql.core.language.parser import parse -from graphql.core.language.visitor import visit, Visitor, REMOVE, BREAK +from graphql.core.language.visitor import BREAK, REMOVE, Visitor, visit + from .fixtures import KITCHEN_SINK @@ -8,6 +9,7 @@ def test_allows_for_editing_on_enter(): ast = parse('{ a, b, c { a, b, c } }', no_location=True) class TestVisitor(Visitor): + def enter(self, node, *args): if isinstance(node, Field) and node.name.value == 'b': return REMOVE @@ -22,6 +24,7 @@ def test_allows_for_editing_on_leave(): ast = parse('{ a, b, c { a, b, c } }', no_location=True) class TestVisitor(Visitor): + def leave(self, node, *args): if isinstance(node, Field) and node.name.value == 'b': return REMOVE @@ -37,6 +40,7 @@ def test_visits_edited_node(): ast = parse('{ a { x } }') class TestVisitor(Visitor): + def __init__(self): self.did_visit_added_field = False @@ -61,6 +65,7 @@ def test_allows_skipping_a_subtree(): ast = parse('{ a, b { x }, c }') class TestVisitor(Visitor): + def enter(self, node, *args): visited.append(['enter', type(node).__name__, getattr(node, 'value', None)]) if isinstance(node, Field) and node.name.value == 'b': @@ -95,6 +100,7 @@ def test_allows_early_exit_while_visiting(): ast = parse('{ a, b { x }, c }') class TestVisitor(Visitor): + def enter(self, node, *args): visited.append(['enter', type(node).__name__, getattr(node, 'value', None)]) if isinstance(node, Name) and node.value == 'x': @@ -127,6 +133,7 @@ def test_allows_a_named_functions_visitor_api(): ast = parse('{ a, b { x }, c }') class TestVisitor(Visitor): + def enter_Name(self, node, *args): visited.append(['enter', type(node).__name__, getattr(node, 'value', None)]) @@ -155,6 +162,7 @@ def test_visits_kitchen_sink(): ast = parse(KITCHEN_SINK) class TestVisitor(Visitor): + def enter(self, node, key, parent, *args): kind = parent and type(parent).__name__ if kind == 'list': diff --git a/tests/core_language/test_visitor_meta.py b/tests/core_language/test_visitor_meta.py index 32e93d79..a57f9f99 100644 --- a/tests/core_language/test_visitor_meta.py +++ b/tests/core_language/test_visitor_meta.py @@ -4,6 +4,7 @@ def test_visitor_meta_creates_enter_and_leave_handlers(): class MyVisitor(Visitor): + def enter_OperationDefinition(self): pass @@ -16,6 +17,7 @@ def leave_OperationDefinition(self): def test_visitor_inherits_parent_definitions(): class MyVisitor(Visitor): + def enter_OperationDefinition(self): pass @@ -26,6 +28,7 @@ def leave_OperationDefinition(self): assert MyVisitor._get_leave_handler(ast.OperationDefinition) class MyVisitorSubclassed(MyVisitor): + def enter_FragmentDefinition(self): pass diff --git a/tests/core_pyutils/test_default_ordered_dict.py b/tests/core_pyutils/test_default_ordered_dict.py index 1700111e..a8d5fa79 100644 --- a/tests/core_pyutils/test_default_ordered_dict.py +++ b/tests/core_pyutils/test_default_ordered_dict.py @@ -1,8 +1,11 @@ import copy -from graphql.core.pyutils.default_ordered_dict import DefaultOrderedDict -from pytest import raises import pickle +from pytest import raises + +from graphql.core.pyutils.default_ordered_dict import DefaultOrderedDict + + def test_will_missing_will_set_value_from_factory(): d = DefaultOrderedDict(list) f = d['foo'] diff --git a/tests/core_starwars/starwars_schema.py b/tests/core_starwars/starwars_schema.py index 1e446e7a..9ca43fed 100644 --- a/tests/core_starwars/starwars_schema.py +++ b/tests/core_starwars/starwars_schema.py @@ -1,16 +1,10 @@ -from graphql.core.type import ( - GraphQLEnumType, - GraphQLEnumValue, - GraphQLInterfaceType, - GraphQLObjectType, - GraphQLField, - GraphQLArgument, - GraphQLList, - GraphQLNonNull, - GraphQLSchema, - GraphQLString, -) -from .starwars_fixtures import getHero, getHuman, getFriends, getDroid +from graphql.core.type import (GraphQLArgument, GraphQLEnumType, + GraphQLEnumValue, GraphQLField, + GraphQLInterfaceType, GraphQLList, + GraphQLNonNull, GraphQLObjectType, + GraphQLSchema, GraphQLString) + +from .starwars_fixtures import getDroid, getFriends, getHero, getHuman episodeEnum = GraphQLEnumType( 'Episode', diff --git a/tests/core_starwars/test_query.py b/tests/core_starwars/test_query.py index 30142345..9b50c143 100644 --- a/tests/core_starwars/test_query.py +++ b/tests/core_starwars/test_query.py @@ -1,7 +1,8 @@ -from .starwars_schema import StarWarsSchema from graphql.core import graphql from graphql.core.error import format_error +from .starwars_schema import StarWarsSchema + def test_hero_name_query(): query = ''' @@ -70,7 +71,7 @@ def test_nested_query(): 'friends': [ { 'name': 'Luke Skywalker', - 'appearsIn': [ 'NEWHOPE', 'EMPIRE', 'JEDI' ], + 'appearsIn': ['NEWHOPE', 'EMPIRE', 'JEDI'], 'friends': [ { 'name': 'Han Solo', @@ -88,7 +89,7 @@ def test_nested_query(): }, { 'name': 'Han Solo', - 'appearsIn': [ 'NEWHOPE', 'EMPIRE', 'JEDI' ], + 'appearsIn': ['NEWHOPE', 'EMPIRE', 'JEDI'], 'friends': [ { 'name': 'Luke Skywalker', @@ -103,7 +104,7 @@ def test_nested_query(): }, { 'name': 'Leia Organa', - 'appearsIn': [ 'NEWHOPE', 'EMPIRE', 'JEDI' ], + 'appearsIn': ['NEWHOPE', 'EMPIRE', 'JEDI'], 'friends': [ { 'name': 'Luke Skywalker', @@ -265,11 +266,11 @@ def test_duplicate_fields(): 'luke': { 'name': 'Luke Skywalker', 'homePlanet': 'Tatooine', - }, + }, 'leia': { 'name': 'Leia Organa', 'homePlanet': 'Alderaan', - } + } } result = graphql(StarWarsSchema, query) assert not result.errors @@ -295,11 +296,11 @@ def test_use_fragment(): 'luke': { 'name': 'Luke Skywalker', 'homePlanet': 'Tatooine', - }, + }, 'leia': { 'name': 'Leia Organa', 'homePlanet': 'Alderaan', - } + } } result = graphql(StarWarsSchema, query) assert not result.errors diff --git a/tests/core_starwars/test_validation.py b/tests/core_starwars/test_validation.py index bb85e11d..68e16c74 100644 --- a/tests/core_starwars/test_validation.py +++ b/tests/core_starwars/test_validation.py @@ -1,6 +1,7 @@ -from graphql.core.language.source import Source from graphql.core.language.parser import parse +from graphql.core.language.source import Source from graphql.core.validation import validate + from .starwars_schema import StarWarsSchema diff --git a/tests/core_type/test_definition.py b/tests/core_type/test_definition.py index 48641fd6..0f7eb9ac 100644 --- a/tests/core_type/test_definition.py +++ b/tests/core_type/test_definition.py @@ -1,22 +1,13 @@ from collections import OrderedDict + from py.test import raises -from graphql.core.type import ( - GraphQLSchema, - GraphQLEnumType, - GraphQLEnumValue, - GraphQLInputObjectField, - GraphQLInputObjectType, - GraphQLInterfaceType, - GraphQLObjectType, - GraphQLUnionType, - GraphQLList, - GraphQLNonNull, - GraphQLInt, - GraphQLString, - GraphQLBoolean, - GraphQLField, - GraphQLArgument, -) + +from graphql.core.type import (GraphQLArgument, GraphQLBoolean, + GraphQLEnumType, GraphQLEnumValue, GraphQLField, + GraphQLInputObjectField, GraphQLInputObjectType, + GraphQLInt, GraphQLInterfaceType, GraphQLList, + GraphQLNonNull, GraphQLObjectType, + GraphQLSchema, GraphQLString, GraphQLUnionType) from graphql.core.type.definition import is_input_type, is_output_type BlogImage = GraphQLObjectType('Image', { diff --git a/tests/core_type/test_enum_type.py b/tests/core_type/test_enum_type.py index e1f29225..d076de76 100644 --- a/tests/core_type/test_enum_type.py +++ b/tests/core_type/test_enum_type.py @@ -1,16 +1,11 @@ from collections import OrderedDict + from pytest import raises -from graphql.core.type import ( - GraphQLEnumType, - GraphQLEnumValue, - GraphQLObjectType, - GraphQLField, - GraphQLArgument, - GraphQLInt, - GraphQLString, - GraphQLSchema -) + from graphql.core import graphql +from graphql.core.type import (GraphQLArgument, GraphQLEnumType, + GraphQLEnumValue, GraphQLField, GraphQLInt, + GraphQLObjectType, GraphQLSchema, GraphQLString) ColorType = GraphQLEnumType( name='Color', @@ -147,7 +142,9 @@ def test_accepts_enum_literals_as_input_arguments_to_mutations(): def test_accepts_enum_literals_as_input_arguments_to_subscriptions(): - result = graphql(Schema, 'subscription x($color: Color!) { subscribeToEnum(color: $color) }', None, {'color': 'GREEN'}) + result = graphql( + Schema, 'subscription x($color: Color!) { subscribeToEnum(color: $color) }', None, { + 'color': 'GREEN'}) assert not result.errors assert result.data == {'subscribeToEnum': 'GREEN'} diff --git a/tests/core_type/test_introspection.py b/tests/core_type/test_introspection.py index a7fafa27..4444b4ac 100644 --- a/tests/core_type/test_introspection.py +++ b/tests/core_type/test_introspection.py @@ -1,23 +1,17 @@ -from collections import OrderedDict import json +from collections import OrderedDict + from graphql.core import graphql from graphql.core.error import format_error -from graphql.core.language.parser import parse from graphql.core.execution import execute -from graphql.core.type import ( - GraphQLSchema, - GraphQLObjectType, - GraphQLField, - GraphQLArgument, - GraphQLInputObjectType, - GraphQLInputObjectField, - GraphQLString, - GraphQLList, - GraphQLEnumType, - GraphQLEnumValue, -) -from graphql.core.validation.rules import ProvidedNonNullArguments +from graphql.core.language.parser import parse +from graphql.core.type import (GraphQLArgument, GraphQLEnumType, + GraphQLEnumValue, GraphQLField, + GraphQLInputObjectField, GraphQLInputObjectType, + GraphQLList, GraphQLObjectType, GraphQLSchema, + GraphQLString) from graphql.core.utils.introspection_query import introspection_query +from graphql.core.validation.rules import ProvidedNonNullArguments def test_executes_an_introspection_query(): diff --git a/tests/core_type/test_serialization.py b/tests/core_type/test_serialization.py index 4ce160e2..bfdcde9b 100644 --- a/tests/core_type/test_serialization.py +++ b/tests/core_type/test_serialization.py @@ -1,9 +1,6 @@ -from graphql.core.type import ( - GraphQLInt, - GraphQLFloat, - GraphQLString, - GraphQLBoolean, -) +from graphql.core.type import (GraphQLBoolean, GraphQLFloat, GraphQLInt, + GraphQLString) + def test_serializes_output_int(): assert GraphQLInt.serialize(1) == 1 diff --git a/tests/core_type/test_validation.py b/tests/core_type/test_validation.py index bad1390d..0efea487 100644 --- a/tests/core_type/test_validation.py +++ b/tests/core_type/test_validation.py @@ -1,20 +1,13 @@ -from pytest import raises import re -from graphql.core.type import ( - GraphQLSchema, - GraphQLScalarType, - GraphQLObjectType, - GraphQLInterfaceType, - GraphQLUnionType, - GraphQLEnumType, - GraphQLEnumValue, - GraphQLInputObjectType, - GraphQLInputObjectField, - GraphQLList, - GraphQLNonNull, - GraphQLField, - GraphQLString -) + +from pytest import raises + +from graphql.core.type import (GraphQLEnumType, GraphQLEnumValue, GraphQLField, + GraphQLInputObjectField, GraphQLInputObjectType, + GraphQLInterfaceType, GraphQLList, + GraphQLNonNull, GraphQLObjectType, + GraphQLScalarType, GraphQLSchema, GraphQLString, + GraphQLUnionType) from graphql.core.type.definition import GraphQLArgument _none = lambda *args: None @@ -121,6 +114,7 @@ def schema_with_field_type(t): # noinspection PyMethodMayBeStatic,PyPep8Naming class TestTypeSystem_ASchemaMustHaveObjectRootTypes: + def test_accepts_a_schema_whose_query_type_is_an_object_type(self): assert GraphQLSchema(query=SomeObjectType) @@ -174,6 +168,7 @@ def test_rejects_a_schema_whose_directives_are_incorrectly_typed(self): # noinspection PyMethodMayBeStatic,PyPep8Naming class TestTypeSystem_ASchemaMustContainUniquelyNamedTypes: + def test_it_rejects_a_schema_which_defines_a_builtin_type(self): FakeString = GraphQLScalarType( name='String', @@ -232,6 +227,7 @@ def test_it_rejects_a_schema_which_have_same_named_objects_implementing_an_inter # noinspection PyMethodMayBeStatic,PyPep8Naming class TestTypeSystem_ObjectsMustHaveFields: + def test_accepts_an_object_type_with_fields_object(self): assert schema_with_field_type(GraphQLObjectType( name='SomeObject', @@ -332,6 +328,7 @@ def test_rejects_an_object_type_with_a_field_function_with_an_invalid_value(self # noinspection PyMethodMayBeStatic,PyPep8Naming class TestTypeSystem_FieldArgsMustBeProperlyNamed: + def test_accepts_field_args_with_valid_names(self): assert schema_with_field_type(GraphQLObjectType( name='SomeObject', @@ -364,6 +361,7 @@ def test_reject_field_args_with_invalid_names(self): # noinspection PyMethodMayBeStatic,PyPep8Naming class TestTypeSystem_FieldArgsMustBeObjects: + def test_accepts_an_object_with_field_args(self): assert schema_with_field_type(GraphQLObjectType( name='SomeObject', @@ -409,6 +407,7 @@ def test_rejects_an_object_with_incorrectly_typed_field_args_with_an_invalid_val # noinspection PyMethodMayBeStatic,PyPep8Naming class TestTypeSystem_ObjectInterfacesMustBeArray: + def test_accepts_an_object_type_with_array_interface(self): AnotherInterfaceType = GraphQLInterfaceType( name='AnotherInterface', @@ -460,6 +459,7 @@ def test_rejects_an_object_type_with_interfaces_as_a_function_returning_an_incor # noinspection PyMethodMayBeStatic,PyPep8Naming class TestTypeSystem_UnionTypesMustBeArray: + def test_accepts_a_union_type_with_aray_types(self): assert schema_with_field_type(GraphQLUnionType( name='SomeUnion', @@ -499,6 +499,7 @@ def test_rejects_a_union_type_with_incorrectly_typed_types(self): # noinspection PyMethodMayBeStatic,PyPep8Naming class TestTypeSystem_UnionTypesMustBeCallableThatReturnsArray: + def test_accepts_a_union_type_with_aray_types(self): assert schema_with_field_type(GraphQLUnionType( name='SomeUnion', @@ -542,6 +543,7 @@ def schema_with_input_object(input_object_type): # noinspection PyMethodMayBeStatic,PyPep8Naming class TestTypeSystem_InputObjectsMustHaveFields: + def test_accepts_an_input_object_type_with_fields(self): assert schema_with_input_object(GraphQLInputObjectType( name='SomeInputObject', @@ -629,6 +631,7 @@ def test_rejects_an_input_object_type_with_a_field_function_that_returns_empty(s # noinspection PyMethodMayBeStatic,PyPep8Naming class TestTypeSystem_ObjectTypesMustBeAssertable: + def test_accepts_an_object_type_with_an_is_type_of_function(self): assert schema_with_field_type(GraphQLObjectType( name='AnotherObject', @@ -649,6 +652,7 @@ def test_rejects_an_object_type_with_an_incorrect_type_for_is_type_of(self): # noinspection PyMethodMayBeStatic,PyPep8Naming class TestTypeSystem_InterfaceTypesMustBeResolvable: + def test_accepts_an_interface_type_defining_resolve_type(self): AnotherInterfaceType = GraphQLInterfaceType( name='AnotherInterface', @@ -718,6 +722,7 @@ def test_rejects_an_interface_type_not_defining_resolve_type_with_implementing_t # noinspection PyMethodMayBeStatic,PyPep8Naming class TestTypeSystem_UnionTypesMustBeResolvable: + def test_accepts_a_union_type_defining_resolve_type(self): assert schema_with_field_type(GraphQLUnionType( name='SomeUnion', @@ -762,6 +767,7 @@ def test_rejects_a_union_type_not_defining_resolve_type_of_object_types_not_defi # noinspection PyMethodMayBeStatic,PyPep8Naming class TestTypeSystem_ScalarTypesMustBeSerializable: + def test_accepts_a_scalar_type_defining_serialize(self): assert schema_with_field_type(GraphQLScalarType( name='SomeScalar', @@ -831,6 +837,7 @@ def test_rejects_a_scalar_type_defining_parse_literal_and_parse_value_with_an_in # noinspection PyMethodMayBeStatic,PyPep8Naming class TestTypeSystem_EnumTypesMustBeWellDefined: + def test_accepts_a_well_defined_enum_type_with_empty_value_definition(self): assert GraphQLEnumType( name='SomeEnum', @@ -919,6 +926,7 @@ def repr_type_as_syntax_safe_fn(_type): # noinspection PyMethodMayBeStatic,PyPep8Naming class TestTypeSystem_ObjectFieldsMustHaveOutputTypes: + def accepts(self, type): assert schema_with_object_field_of_type(type) @@ -964,6 +972,7 @@ def schema_with_object_implementing_type(implemented_type): # noinspection PyMethodMayBeStatic,PyPep8Naming class TestTypeSystem_ObjectsCanOnlyImplementInterfaces: + def test_accepts_an_object_implementing_an_interface(self): AnotherInterfaceType = GraphQLInterfaceType( name='AnotherInterface', @@ -1007,6 +1016,7 @@ def schema_with_union_of_type(type): # noinspection PyMethodMayBeStatic,PyPep8Naming class TestTypeSystem_UnionMustRepresentObjectTypes: + def test_accepts_a_union_of_an_object_type(self): assert schema_with_union_of_type(SomeObjectType) @@ -1035,6 +1045,7 @@ def schema_with_interface_field_of_type(field_type): # noinspection PyMethodMayBeStatic,PyPep8Naming class TestTypeSystem_InterfaceFieldsMustHaveOutputTypes: + def accepts(self, type): assert schema_with_interface_field_of_type(type) @@ -1077,6 +1088,7 @@ def schema_with_arg_of_type(arg_type): # noinspection PyMethodMayBeStatic,PyPep8Naming class TestTypeSystem_FieldArgumentsMustHaveInputTypes: + def accepts(self, type): assert schema_with_arg_of_type(type) @@ -1126,6 +1138,7 @@ def schema_with_input_field_of_type(input_field_type): # noinspection PyMethodMayBeStatic,PyPep8Naming class TestTypeSystem_InputObjectFieldsMustHaveInputTypes: + def accepts(self, type): assert schema_with_input_field_of_type(type) @@ -1171,6 +1184,7 @@ def test_rejects_an_empty_input_field_type(self): class TestTypeSystem_ListMustAcceptGraphQLTypes: + def accepts(self, type): assert GraphQLList(type) @@ -1214,6 +1228,7 @@ def rejects(self, type): # noinspection PyMethodMayBeStatic,PyPep8Naming class TestTypeSystem_NonNullMustAcceptGraphQLTypes: + def accepts(self, type): assert GraphQLNonNull(type) @@ -1236,6 +1251,7 @@ def rejects(self, type): # noinspection PyMethodMayBeStatic,PyPep8Naming class TestTypeSystem_ObjectsMustAdhereToInterfacesTheyImplement: + def test_accepts_an_object_which_implements_an_interface(self): AnotherInterface = GraphQLInterfaceType( name='AnotherInterface', diff --git a/tests/core_utils/test_ast_from_value.py b/tests/core_utils/test_ast_from_value.py index faa802c9..dd7deca4 100644 --- a/tests/core_utils/test_ast_from_value.py +++ b/tests/core_utils/test_ast_from_value.py @@ -1,9 +1,11 @@ from collections import OrderedDict -from graphql.core.type.definition import GraphQLEnumType, GraphQLEnumValue, GraphQLList, GraphQLInputObjectType, \ - GraphQLInputObjectField + +from graphql.core.language import ast +from graphql.core.type.definition import (GraphQLEnumType, GraphQLEnumValue, + GraphQLInputObjectField, + GraphQLInputObjectType, GraphQLList) from graphql.core.type.scalars import GraphQLFloat from graphql.core.utils.ast_from_value import ast_from_value -from graphql.core.language import ast def test_converts_boolean_values_to_asts(): diff --git a/tests/core_utils/test_ast_to_code.py b/tests/core_utils/test_ast_to_code.py index 4f0bf9bd..37ea399c 100644 --- a/tests/core_utils/test_ast_to_code.py +++ b/tests/core_utils/test_ast_to_code.py @@ -1,6 +1,6 @@ -from graphql.core import parse, Source -from graphql.core.language.parser import Loc +from graphql.core import Source, parse from graphql.core.language import ast +from graphql.core.language.parser import Loc from graphql.core.utils.ast_to_code import ast_to_code from tests.core_language import fixtures diff --git a/tests/core_utils/test_build_ast_schema.py b/tests/core_utils/test_build_ast_schema.py index 0c25df1f..8d9fd308 100644 --- a/tests/core_utils/test_build_ast_schema.py +++ b/tests/core_utils/test_build_ast_schema.py @@ -1,7 +1,8 @@ +from pytest import raises + from graphql.core import parse from graphql.core.utils.build_ast_schema import build_ast_schema from graphql.core.utils.schema_printer import print_schema -from pytest import raises def cycle_output(body, query_type, mutation_type=None, subscription_type=None): diff --git a/tests/core_utils/test_build_client_schema.py b/tests/core_utils/test_build_client_schema.py index 86374fa2..9b9dc9c8 100644 --- a/tests/core_utils/test_build_client_schema.py +++ b/tests/core_utils/test_build_client_schema.py @@ -1,30 +1,20 @@ from collections import OrderedDict + from pytest import raises + from graphql.core import graphql from graphql.core.error import format_error -from graphql.core.type import ( - GraphQLSchema, - GraphQLArgument, - GraphQLScalarType, - GraphQLObjectType, - GraphQLInterfaceType, - GraphQLUnionType, - GraphQLEnumType, - GraphQLEnumValue, - GraphQLField, - GraphQLInputObjectType, - GraphQLInputObjectField, - GraphQLList, - GraphQLNonNull, - GraphQLInt, - GraphQLFloat, - GraphQLString, - GraphQLBoolean, - GraphQLID, -) +from graphql.core.type import (GraphQLArgument, GraphQLBoolean, + GraphQLEnumType, GraphQLEnumValue, GraphQLField, + GraphQLFloat, GraphQLID, + GraphQLInputObjectField, GraphQLInputObjectType, + GraphQLInt, GraphQLInterfaceType, GraphQLList, + GraphQLNonNull, GraphQLObjectType, + GraphQLScalarType, GraphQLSchema, GraphQLString, + GraphQLUnionType) from graphql.core.type.directives import GraphQLDirective -from graphql.core.utils.introspection_query import introspection_query from graphql.core.utils.build_client_schema import build_client_schema +from graphql.core.utils.introspection_query import introspection_query def _test_schema(server_schema): diff --git a/tests/core_utils/test_concat_ast.py b/tests/core_utils/test_concat_ast.py index a590f081..34cfb8f6 100644 --- a/tests/core_utils/test_concat_ast.py +++ b/tests/core_utils/test_concat_ast.py @@ -1,5 +1,4 @@ from graphql.core import Source, parse - from graphql.core.language.printer import print_ast from graphql.core.utils.concat_ast import concat_ast diff --git a/tests/core_utils/test_schema_printer.py b/tests/core_utils/test_schema_printer.py index 3229c15e..29a40558 100644 --- a/tests/core_utils/test_schema_printer.py +++ b/tests/core_utils/test_schema_printer.py @@ -1,20 +1,16 @@ from collections import OrderedDict -from graphql.core.type.definition import GraphQLField, GraphQLArgument, GraphQLInputObjectField, GraphQLEnumValue -from graphql.core.utils.schema_printer import print_schema, print_introspection_schema -from graphql.core.type import ( - GraphQLSchema, - GraphQLInputObjectType, - GraphQLScalarType, - GraphQLObjectType, - GraphQLInterfaceType, - GraphQLUnionType, - GraphQLEnumType, - GraphQLString, - GraphQLInt, - GraphQLBoolean, - GraphQLList, - GraphQLNonNull, -) + +from graphql.core.type import (GraphQLBoolean, GraphQLEnumType, + GraphQLInputObjectType, GraphQLInt, + GraphQLInterfaceType, GraphQLList, + GraphQLNonNull, GraphQLObjectType, + GraphQLScalarType, GraphQLSchema, GraphQLString, + GraphQLUnionType) +from graphql.core.type.definition import (GraphQLArgument, GraphQLEnumValue, + GraphQLField, + GraphQLInputObjectField) +from graphql.core.utils.schema_printer import (print_introspection_schema, + print_schema) def print_for_test(schema): diff --git a/tests/core_validation/test_arguments_of_correct_type.py b/tests/core_validation/test_arguments_of_correct_type.py index d0ac19a7..80fd6ea8 100644 --- a/tests/core_validation/test_arguments_of_correct_type.py +++ b/tests/core_validation/test_arguments_of_correct_type.py @@ -1,6 +1,6 @@ from graphql.core.language.location import SourceLocation from graphql.core.validation.rules import ArgumentsOfCorrectType -from utils import expect_passes_rule, expect_fails_rule +from utils import expect_fails_rule, expect_passes_rule def bad_value(arg_name, type_name, value, line, column, errors=None): @@ -15,6 +15,7 @@ def bad_value(arg_name, type_name, value, line, column, errors=None): # noinspection PyMethodMayBeStatic class TestValidValues(object): + def test_good_int_value(self): expect_passes_rule(ArgumentsOfCorrectType, ''' { @@ -90,6 +91,7 @@ def test_good_enum_value(self): # noinspection PyMethodMayBeStatic class TestInvalidStringValues(object): + def test_int_into_string(self): expect_fails_rule(ArgumentsOfCorrectType, ''' { @@ -137,6 +139,7 @@ def test_unquoted_string_into_string(self): # noinspection PyMethodMayBeStatic class TestInvalidIntValues(object): + def test_string_into_int(self): expect_fails_rule(ArgumentsOfCorrectType, ''' { @@ -195,6 +198,7 @@ def test_float_into_int(self): # noinspection PyMethodMayBeStatic class TestInvalidFloatValues(object): + def test_string_into_float(self): expect_fails_rule(ArgumentsOfCorrectType, ''' { @@ -231,6 +235,7 @@ def test_unquoted_into_float(self): # noinspection PyMethodMayBeStatic class TestInvalidBooleanValues(object): + def test_int_into_boolean(self): expect_fails_rule(ArgumentsOfCorrectType, ''' { @@ -278,6 +283,7 @@ def test_unquoted_into_boolean(self): # noinspection PyMethodMayBeStatic class TestInvalidIDValues(object): + def test_float_into_id(self): expect_fails_rule(ArgumentsOfCorrectType, ''' { @@ -314,6 +320,7 @@ def test_unquoted_into_id(self): # noinspection PyMethodMayBeStatic class TestInvalidEnumValues(object): + def test_int_into_enum(self): expect_fails_rule(ArgumentsOfCorrectType, ''' { @@ -383,6 +390,7 @@ def test_different_case_enum_value_into_enum(self): # noinspection PyMethodMayBeStatic class TestValidListValues(object): + def test_good_list_value(self): expect_passes_rule(ArgumentsOfCorrectType, ''' { @@ -413,6 +421,7 @@ def test_single_value_into_list(self): # noinspection PyMethodMayBeStatic class TestInvalidListValues(object): + def test_incorrect_item_type(self): expect_fails_rule(ArgumentsOfCorrectType, ''' { @@ -440,6 +449,7 @@ def test_single_value_of_incorrect_type(self): # noinspection PyMethodMayBeStatic class TestValidNonNullableValues(object): + def test_arg_on_optional_arg(self): expect_passes_rule(ArgumentsOfCorrectType, ''' { @@ -524,6 +534,7 @@ def test_all_reqs_and_opts_on_mixed_list(self): # noinspection PyMethodMayBeStatic class TestInvalidNonNullableValues(object): + def test_incorrect_value_type(self): expect_fails_rule(ArgumentsOfCorrectType, ''' { @@ -550,6 +561,7 @@ def test_incorrect_value_and_missing_argument(self): # noinspection PyMethodMayBeStatic class TestValidInputObjectValue(object): + def test_optional_arg_despite_required_field_in_type(self): expect_passes_rule(ArgumentsOfCorrectType, ''' { @@ -619,6 +631,7 @@ def test_full_object_with_fields_in_different_order(self): # noinspection PyMethodMayBeStatic class TestInvalidInputObjectValue(object): + def test_partial_object_missing_required(self): expect_fails_rule(ArgumentsOfCorrectType, ''' { @@ -669,6 +682,7 @@ def test_partial_object_unknown_field_arg(self): # noinspection PyMethodMayBeStatic class TestDirectiveArguments(object): + def test_with_directives_of_valid_types(self): expect_passes_rule(ArgumentsOfCorrectType, ''' { diff --git a/tests/core_validation/test_default_values_of_correct_type.py b/tests/core_validation/test_default_values_of_correct_type.py index f7cc494f..42a57234 100644 --- a/tests/core_validation/test_default_values_of_correct_type.py +++ b/tests/core_validation/test_default_values_of_correct_type.py @@ -1,6 +1,6 @@ from graphql.core.language.location import SourceLocation from graphql.core.validation.rules import DefaultValuesOfCorrectType -from utils import expect_passes_rule, expect_fails_rule +from utils import expect_fails_rule, expect_passes_rule def default_for_non_null_arg(var_name, type_name, guess_type_name, line, column): diff --git a/tests/core_validation/test_fields_on_correct_type.py b/tests/core_validation/test_fields_on_correct_type.py index 1a731a17..7db819a8 100644 --- a/tests/core_validation/test_fields_on_correct_type.py +++ b/tests/core_validation/test_fields_on_correct_type.py @@ -1,6 +1,6 @@ from graphql.core.language.location import SourceLocation from graphql.core.validation.rules import FieldsOnCorrectType -from utils import expect_passes_rule, expect_fails_rule +from utils import expect_fails_rule, expect_passes_rule def undefined_field(field, type, line, column): diff --git a/tests/core_validation/test_fragments_on_composite_types.py b/tests/core_validation/test_fragments_on_composite_types.py index 42def6c7..c8161182 100644 --- a/tests/core_validation/test_fragments_on_composite_types.py +++ b/tests/core_validation/test_fragments_on_composite_types.py @@ -1,6 +1,6 @@ from graphql.core.language.location import SourceLocation from graphql.core.validation.rules import FragmentsOnCompositeTypes -from utils import expect_passes_rule, expect_fails_rule +from utils import expect_fails_rule, expect_passes_rule def fragment_on_non_composite_error(frag_name, type_name, line, column): diff --git a/tests/core_validation/test_known_argument_names.py b/tests/core_validation/test_known_argument_names.py index 5486ecf1..977a27ae 100644 --- a/tests/core_validation/test_known_argument_names.py +++ b/tests/core_validation/test_known_argument_names.py @@ -1,6 +1,6 @@ from graphql.core.language.location import SourceLocation from graphql.core.validation.rules import KnownArgumentNames -from utils import expect_passes_rule, expect_fails_rule +from utils import expect_fails_rule, expect_passes_rule def unknown_arg(arg_name, field_name, type_name, line, column): diff --git a/tests/core_validation/test_known_directives.py b/tests/core_validation/test_known_directives.py index e9cb5ee0..73078d0f 100644 --- a/tests/core_validation/test_known_directives.py +++ b/tests/core_validation/test_known_directives.py @@ -1,6 +1,6 @@ from graphql.core.language.location import SourceLocation from graphql.core.validation.rules import KnownDirectives -from utils import expect_passes_rule, expect_fails_rule +from utils import expect_fails_rule, expect_passes_rule def unknown_directive(directive_name, line, column): diff --git a/tests/core_validation/test_known_fragment_names.py b/tests/core_validation/test_known_fragment_names.py index 02b330d7..2b6cbbf8 100644 --- a/tests/core_validation/test_known_fragment_names.py +++ b/tests/core_validation/test_known_fragment_names.py @@ -1,6 +1,6 @@ from graphql.core.language.location import SourceLocation from graphql.core.validation.rules import KnownFragmentNames -from utils import expect_passes_rule, expect_fails_rule +from utils import expect_fails_rule, expect_passes_rule def undefined_fragment(fragment_name, line, column): diff --git a/tests/core_validation/test_known_type_names.py b/tests/core_validation/test_known_type_names.py index 4caed941..277ecae7 100644 --- a/tests/core_validation/test_known_type_names.py +++ b/tests/core_validation/test_known_type_names.py @@ -1,6 +1,6 @@ from graphql.core.language.location import SourceLocation from graphql.core.validation.rules import KnownTypeNames -from utils import expect_passes_rule, expect_fails_rule +from utils import expect_fails_rule, expect_passes_rule def unknown_type(type_name, line, column): diff --git a/tests/core_validation/test_lone_anonymous_operation.py b/tests/core_validation/test_lone_anonymous_operation.py index 5f9bbbba..93e13655 100644 --- a/tests/core_validation/test_lone_anonymous_operation.py +++ b/tests/core_validation/test_lone_anonymous_operation.py @@ -1,6 +1,6 @@ from graphql.core.language.location import SourceLocation from graphql.core.validation.rules import LoneAnonymousOperation -from utils import expect_passes_rule, expect_fails_rule +from utils import expect_fails_rule, expect_passes_rule def anon_not_alone(line, column): diff --git a/tests/core_validation/test_no_fragment_cycles.py b/tests/core_validation/test_no_fragment_cycles.py index d7269060..5db450ef 100644 --- a/tests/core_validation/test_no_fragment_cycles.py +++ b/tests/core_validation/test_no_fragment_cycles.py @@ -1,6 +1,6 @@ from graphql.core.language.location import SourceLocation as L from graphql.core.validation.rules import NoFragmentCycles -from utils import expect_passes_rule, expect_fails_rule +from utils import expect_fails_rule, expect_passes_rule def cycle_error_message(fragment_name, spread_names, *locations): @@ -117,7 +117,8 @@ def test_no_spreading_itself_indirectly_within_inline_fragment(): def test_no_spreading_itself_deeply(): - expect_fails_rule(NoFragmentCycles, ''' + expect_fails_rule( + NoFragmentCycles, ''' fragment fragA on Dog { ...fragB } fragment fragB on Dog { ...fragC } fragment fragC on Dog { ...fragO } @@ -127,9 +128,21 @@ def test_no_spreading_itself_deeply(): fragment fragO on Dog { ...fragP } fragment fragP on Dog { ...fragA, ...fragX } ''', [ - cycle_error_message('fragA', ['fragB', 'fragC', 'fragO', 'fragP'], L(2, 29), L(3, 29), L(4, 29), L(8, 29), L(9, 29)), - cycle_error_message('fragO', ['fragP', 'fragX', 'fragY', 'fragZ'], L(8, 29), L(9, 39), L(5, 29), L(6, 29), L(7, 29)) - ]) + cycle_error_message( + 'fragA', [ + 'fragB', 'fragC', 'fragO', 'fragP'], L( + 2, 29), L( + 3, 29), L( + 4, 29), L( + 8, 29), L( + 9, 29)), cycle_error_message( + 'fragO', [ + 'fragP', 'fragX', 'fragY', 'fragZ'], L( + 8, 29), L( + 9, 39), L( + 5, 29), L( + 6, 29), L( + 7, 29))]) def test_no_spreading_itself_deeply_two_paths(): diff --git a/tests/core_validation/test_no_undefined_variables.py b/tests/core_validation/test_no_undefined_variables.py index 39392b4b..91a3917a 100644 --- a/tests/core_validation/test_no_undefined_variables.py +++ b/tests/core_validation/test_no_undefined_variables.py @@ -1,6 +1,6 @@ from graphql.core.language.location import SourceLocation from graphql.core.validation.rules import NoUndefinedVariables -from utils import expect_passes_rule, expect_fails_rule +from utils import expect_fails_rule, expect_passes_rule def undefined_var(var_name, l1, c1, op_name, l2, c2): diff --git a/tests/core_validation/test_no_unused_fragments.py b/tests/core_validation/test_no_unused_fragments.py index 9a58151f..6c86dd7b 100644 --- a/tests/core_validation/test_no_unused_fragments.py +++ b/tests/core_validation/test_no_unused_fragments.py @@ -1,6 +1,6 @@ from graphql.core.language.location import SourceLocation from graphql.core.validation.rules import NoUnusedFragments -from utils import expect_passes_rule, expect_fails_rule +from utils import expect_fails_rule, expect_passes_rule def unused_fragment(fragment_name, line, column): diff --git a/tests/core_validation/test_no_unused_variables.py b/tests/core_validation/test_no_unused_variables.py index c028e90f..8fa840e8 100644 --- a/tests/core_validation/test_no_unused_variables.py +++ b/tests/core_validation/test_no_unused_variables.py @@ -1,6 +1,6 @@ from graphql.core.language.location import SourceLocation from graphql.core.validation.rules import NoUnusedVariables -from utils import expect_passes_rule, expect_fails_rule +from utils import expect_fails_rule, expect_passes_rule def unused_variable(variable_name, line, column): diff --git a/tests/core_validation/test_overlapping_fields_can_be_merged.py b/tests/core_validation/test_overlapping_fields_can_be_merged.py index 3fb42e8f..013b9b69 100644 --- a/tests/core_validation/test_overlapping_fields_can_be_merged.py +++ b/tests/core_validation/test_overlapping_fields_can_be_merged.py @@ -1,10 +1,12 @@ from graphql.core.language.location import SourceLocation as L -from graphql.core.type.definition import GraphQLObjectType, GraphQLArgument, GraphQLNonNull, GraphQLInterfaceType, \ - GraphQLList, GraphQLField -from graphql.core.type.scalars import GraphQLString, GraphQLInt, GraphQLID +from graphql.core.type.definition import (GraphQLArgument, GraphQLField, + GraphQLInterfaceType, GraphQLList, + GraphQLNonNull, GraphQLObjectType) +from graphql.core.type.scalars import GraphQLID, GraphQLInt, GraphQLString from graphql.core.type.schema import GraphQLSchema from graphql.core.validation.rules import OverlappingFieldsCanBeMerged -from utils import expect_passes_rule, expect_fails_rule, expect_fails_rule_with_schema, expect_passes_rule_with_schema +from utils import (expect_fails_rule, expect_fails_rule_with_schema, + expect_passes_rule, expect_passes_rule_with_schema) def fields_conflict(reason_name, reason, *locations): @@ -310,7 +312,7 @@ def test_reports_deep_conflict_to_nearest_common_ancestor(): NonNullStringBox1Impl = GraphQLObjectType('NonNullStringBox1Impl', { 'scalar': GraphQLField(GraphQLNonNull(GraphQLString)), 'unrelatedField': GraphQLField(GraphQLString) -}, interfaces=[ SomeBox, NonNullStringBox1 ]) +}, interfaces=[SomeBox, NonNullStringBox1]) NonNullStringBox2 = GraphQLInterfaceType('NonNullStringBox2', { 'scalar': GraphQLField(GraphQLNonNull(GraphQLString)) @@ -319,7 +321,7 @@ def test_reports_deep_conflict_to_nearest_common_ancestor(): NonNullStringBox2Impl = GraphQLObjectType('NonNullStringBox2Impl', { 'scalar': GraphQLField(GraphQLNonNull(GraphQLString)), 'unrelatedField': GraphQLField(GraphQLString) -}, interfaces=[ SomeBox, NonNullStringBox2 ]) +}, interfaces=[SomeBox, NonNullStringBox2]) Connection = GraphQLObjectType('Connection', { 'edges': GraphQLField(GraphQLList(GraphQLObjectType('Edge', { diff --git a/tests/core_validation/test_possible_fragment_spreads.py b/tests/core_validation/test_possible_fragment_spreads.py index 61ba7ea0..e5f0d2e9 100644 --- a/tests/core_validation/test_possible_fragment_spreads.py +++ b/tests/core_validation/test_possible_fragment_spreads.py @@ -1,19 +1,21 @@ from graphql.core.language.location import SourceLocation from graphql.core.validation.rules import PossibleFragmentSpreads -from utils import expect_passes_rule, expect_fails_rule +from utils import expect_fails_rule, expect_passes_rule + def error(frag_name, parent_type, frag_type, line, column): - return { - 'message': PossibleFragmentSpreads.type_incompatible_spread_message(frag_name, parent_type, frag_type), - 'locations': [SourceLocation(line, column)] - } + return { + 'message': PossibleFragmentSpreads.type_incompatible_spread_message(frag_name, parent_type, frag_type), + 'locations': [SourceLocation(line, column)] + } def error_anon(parent_type, frag_type, line, column): - return { - 'message': PossibleFragmentSpreads.type_incompatible_anon_spread_message(parent_type, frag_type), - 'locations': [SourceLocation(line, column)] - } + return { + 'message': PossibleFragmentSpreads.type_incompatible_anon_spread_message(parent_type, frag_type), + 'locations': [SourceLocation(line, column)] + } + def test_same_object(): expect_passes_rule(PossibleFragmentSpreads, ''' @@ -21,70 +23,82 @@ def test_same_object(): fragment dogFragment on Dog { barkVolume } ''') + def test_same_object_inline_frag(): expect_passes_rule(PossibleFragmentSpreads, ''' fragment objectWithinObjectAnon on Dog { ... on Dog { barkVolume } } ''') + def test_object_into_implemented_interface(): expect_passes_rule(PossibleFragmentSpreads, ''' fragment objectWithinInterface on Pet { ...dogFragment } fragment dogFragment on Dog { barkVolume } ''') + def test_object_into_containing_union(): expect_passes_rule(PossibleFragmentSpreads, ''' fragment objectWithinUnion on CatOrDog { ...dogFragment } fragment dogFragment on Dog { barkVolume } ''') + def test_union_into_contained_object(): expect_passes_rule(PossibleFragmentSpreads, ''' fragment unionWithinObject on Dog { ...catOrDogFragment } fragment catOrDogFragment on CatOrDog { __typename } ''') + def test_union_into_overlapping_interface(): expect_passes_rule(PossibleFragmentSpreads, ''' fragment unionWithinInterface on Pet { ...catOrDogFragment } fragment catOrDogFragment on CatOrDog { __typename } ''') + def test_union_into_overlapping_union(): expect_passes_rule(PossibleFragmentSpreads, ''' fragment unionWithinUnion on DogOrHuman { ...catOrDogFragment } fragment catOrDogFragment on CatOrDog { __typename } ''') + def test_interface_into_implemented_object(): expect_passes_rule(PossibleFragmentSpreads, ''' fragment interfaceWithinObject on Dog { ...petFragment } fragment petFragment on Pet { name } ''') + def test_interface_into_overlapping_interface(): expect_passes_rule(PossibleFragmentSpreads, ''' fragment interfaceWithinInterface on Pet { ...beingFragment } fragment beingFragment on Being { name } ''') + def test_interface_into_overlapping_interface_in_inline_fragment(): expect_passes_rule(PossibleFragmentSpreads, ''' fragment interfaceWithinInterface on Pet { ... on Being { name } } ''') + def test_interface_into_overlapping_union(): expect_passes_rule(PossibleFragmentSpreads, ''' fragment interfaceWithinUnion on CatOrDog { ...petFragment } fragment petFragment on Pet { name } ''') + def test_different_object_into_object(): expect_fails_rule(PossibleFragmentSpreads, ''' fragment invalidObjectWithinObject on Cat { ...dogFragment } fragment dogFragment on Dog { barkVolume } ''', [error('dogFragment', 'Cat', 'Dog', 2, 51)]) + def test_different_object_into_object_in_inline_fragment(): expect_fails_rule(PossibleFragmentSpreads, ''' fragment invalidObjectWithinObjectAnon on Cat { @@ -92,42 +106,49 @@ def test_different_object_into_object_in_inline_fragment(): } ''', [error_anon('Cat', 'Dog', 3, 9)]) + def test_object_into_not_implementing_interface(): expect_fails_rule(PossibleFragmentSpreads, ''' fragment invalidObjectWithinInterface on Pet { ...humanFragment } fragment humanFragment on Human { pets { name } } ''', [error('humanFragment', 'Pet', 'Human', 2, 54)]) + def test_object_into_not_containing_union(): expect_fails_rule(PossibleFragmentSpreads, ''' fragment invalidObjectWithinUnion on CatOrDog { ...humanFragment } fragment humanFragment on Human { pets { name } } ''', [error('humanFragment', 'CatOrDog', 'Human', 2, 55)]) + def test_union_into_not_contained_object(): expect_fails_rule(PossibleFragmentSpreads, ''' fragment invalidUnionWithinObject on Human { ...catOrDogFragment } fragment catOrDogFragment on CatOrDog { __typename } ''', [error('catOrDogFragment', 'Human', 'CatOrDog', 2, 52)]) + def test_union_into_non_overlapping_interface(): expect_fails_rule(PossibleFragmentSpreads, ''' fragment invalidUnionWithinInterface on Pet { ...humanOrAlienFragment } fragment humanOrAlienFragment on HumanOrAlien { __typename } ''', [error('humanOrAlienFragment', 'Pet', 'HumanOrAlien', 2, 53)]) + def test_union_into_non_overlapping_union(): expect_fails_rule(PossibleFragmentSpreads, ''' fragment invalidUnionWithinUnion on CatOrDog { ...humanOrAlienFragment } fragment humanOrAlienFragment on HumanOrAlien { __typename } ''', [error('humanOrAlienFragment', 'CatOrDog', 'HumanOrAlien', 2, 54)]) + def test_interface_into_non_implementing_object(): expect_fails_rule(PossibleFragmentSpreads, ''' fragment invalidInterfaceWithinObject on Cat { ...intelligentFragment } fragment intelligentFragment on Intelligent { iq } ''', [error('intelligentFragment', 'Cat', 'Intelligent', 2, 54)]) + def test_interface_into_non_overlapping_interface(): expect_fails_rule(PossibleFragmentSpreads, ''' fragment invalidInterfaceWithinInterface on Pet { @@ -136,6 +157,7 @@ def test_interface_into_non_overlapping_interface(): fragment intelligentFragment on Intelligent { iq } ''', [error('intelligentFragment', 'Pet', 'Intelligent', 3, 9)]) + def test_interface_into_non_overlapping_interface_in_inline_fragment(): expect_fails_rule(PossibleFragmentSpreads, ''' fragment invalidInterfaceWithinInterfaceAnon on Pet { @@ -143,6 +165,7 @@ def test_interface_into_non_overlapping_interface_in_inline_fragment(): } ''', [error_anon('Pet', 'Intelligent', 3, 9)]) + def test_interface_into_non_overlapping_union(): expect_fails_rule(PossibleFragmentSpreads, ''' fragment invalidInterfaceWithinUnion on HumanOrAlien { ...petFragment } diff --git a/tests/core_validation/test_provided_non_null_arguments.py b/tests/core_validation/test_provided_non_null_arguments.py index 97f98e28..079e1900 100644 --- a/tests/core_validation/test_provided_non_null_arguments.py +++ b/tests/core_validation/test_provided_non_null_arguments.py @@ -1,6 +1,6 @@ from graphql.core.language.location import SourceLocation from graphql.core.validation.rules import ProvidedNonNullArguments -from utils import expect_passes_rule, expect_fails_rule +from utils import expect_fails_rule, expect_passes_rule def missing_field_arg(field_name, arg_name, type_name, line, column): diff --git a/tests/core_validation/test_scalar_leafs.py b/tests/core_validation/test_scalar_leafs.py index 1283cbc1..073960be 100644 --- a/tests/core_validation/test_scalar_leafs.py +++ b/tests/core_validation/test_scalar_leafs.py @@ -1,6 +1,6 @@ from graphql.core.language.location import SourceLocation from graphql.core.validation.rules import ScalarLeafs -from utils import expect_passes_rule, expect_fails_rule +from utils import expect_fails_rule, expect_passes_rule def no_scalar_subselection(field, type, line, column): diff --git a/tests/core_validation/test_unique_argument_names.py b/tests/core_validation/test_unique_argument_names.py index 7bec3e81..5a8589d6 100644 --- a/tests/core_validation/test_unique_argument_names.py +++ b/tests/core_validation/test_unique_argument_names.py @@ -1,6 +1,6 @@ from graphql.core.language.location import SourceLocation from graphql.core.validation.rules import UniqueArgumentNames -from utils import expect_passes_rule, expect_fails_rule +from utils import expect_fails_rule, expect_passes_rule def duplicate_arg(arg_name, l1, c1, l2, c2): @@ -112,7 +112,7 @@ def test_duplicate_directive_arguments(): ''', [ duplicate_arg('arg1', 3, 24, 3, 39) ] - ) + ) def test_many_duplicate_directive_arguments(): diff --git a/tests/core_validation/test_unique_fragment_names.py b/tests/core_validation/test_unique_fragment_names.py index 91d0c03d..75d059cb 100644 --- a/tests/core_validation/test_unique_fragment_names.py +++ b/tests/core_validation/test_unique_fragment_names.py @@ -1,6 +1,6 @@ from graphql.core.language.location import SourceLocation from graphql.core.validation.rules import UniqueFragmentNames -from utils import expect_passes_rule, expect_fails_rule +from utils import expect_fails_rule, expect_passes_rule def duplicate_fragment(fragment_name, l1, c1, l2, c2): diff --git a/tests/core_validation/test_unique_input_field_names.py b/tests/core_validation/test_unique_input_field_names.py index 40be1d30..179fd924 100644 --- a/tests/core_validation/test_unique_input_field_names.py +++ b/tests/core_validation/test_unique_input_field_names.py @@ -1,6 +1,6 @@ from graphql.core.language.location import SourceLocation as L from graphql.core.validation.rules import UniqueInputFieldNames -from utils import expect_passes_rule, expect_fails_rule +from utils import expect_fails_rule, expect_passes_rule def duplicate_field(name, l1, l2): diff --git a/tests/core_validation/test_unique_operation_names.py b/tests/core_validation/test_unique_operation_names.py index 50e8705b..3b152616 100644 --- a/tests/core_validation/test_unique_operation_names.py +++ b/tests/core_validation/test_unique_operation_names.py @@ -1,6 +1,6 @@ from graphql.core.language.location import SourceLocation from graphql.core.validation.rules import UniqueOperationNames -from utils import expect_passes_rule, expect_fails_rule +from utils import expect_fails_rule, expect_passes_rule def duplicate_op(op_name, l1, c1, l2, c2): diff --git a/tests/core_validation/test_validation.py b/tests/core_validation/test_validation.py index 5479a094..45844e1e 100644 --- a/tests/core_validation/test_validation.py +++ b/tests/core_validation/test_validation.py @@ -1,5 +1,4 @@ -from graphql.core import parse -from graphql.core import validate +from graphql.core import parse, validate from graphql.core.utils.type_info import TypeInfo from graphql.core.validation import visit_using_rules from graphql.core.validation.rules import specified_rules diff --git a/tests/core_validation/test_variables_are_input_types.py b/tests/core_validation/test_variables_are_input_types.py index b894d2bc..977a8c5f 100644 --- a/tests/core_validation/test_variables_are_input_types.py +++ b/tests/core_validation/test_variables_are_input_types.py @@ -1,6 +1,6 @@ from graphql.core.language.location import SourceLocation from graphql.core.validation.rules import VariablesAreInputTypes -from utils import expect_passes_rule, expect_fails_rule +from utils import expect_fails_rule, expect_passes_rule def non_input_type_on_variable(variable_name, type_name, line, col): diff --git a/tests/core_validation/test_variables_in_allowed_position.py b/tests/core_validation/test_variables_in_allowed_position.py index 56bb119c..7f61d86a 100644 --- a/tests/core_validation/test_variables_in_allowed_position.py +++ b/tests/core_validation/test_variables_in_allowed_position.py @@ -1,6 +1,6 @@ from graphql.core.language.location import SourceLocation from graphql.core.validation.rules import VariablesInAllowedPosition -from utils import expect_passes_rule, expect_fails_rule +from utils import expect_fails_rule, expect_passes_rule def test_boolean_boolean(): @@ -40,6 +40,7 @@ def test_boolean_boolean_in_fragment(): } ''') + def test_non_null_boolean_boolean(): expect_passes_rule(VariablesInAllowedPosition, ''' query Query($nonNullBooleanArg: Boolean!) @@ -75,6 +76,7 @@ def test_int_non_null_int_with_default(): } ''') + def test_string_string(): expect_passes_rule(VariablesInAllowedPosition, ''' query Query($stringListVar: [String]) @@ -85,6 +87,7 @@ def test_string_string(): } ''') + def test_non_null_string_string(): expect_passes_rule(VariablesInAllowedPosition, ''' query Query($stringListVar: [String!]) @@ -95,6 +98,7 @@ def test_non_null_string_string(): } ''') + def test_string_string_item_position(): expect_passes_rule(VariablesInAllowedPosition, ''' query Query($stringVar: String) @@ -105,6 +109,7 @@ def test_string_string_item_position(): } ''') + def test_non_null_string_string_item_positiion(): expect_passes_rule(VariablesInAllowedPosition, ''' query Query($stringVar: String!) @@ -115,6 +120,7 @@ def test_non_null_string_string_item_positiion(): } ''') + def test_complex_input_complex_input(): expect_passes_rule(VariablesInAllowedPosition, ''' query Query($complexVar: ComplexInput) @@ -125,6 +131,7 @@ def test_complex_input_complex_input(): } ''') + def test_complex_input_complex_input_in_field_position(): expect_passes_rule(VariablesInAllowedPosition, ''' query Query($boolVar: Boolean = false) @@ -135,6 +142,7 @@ def test_complex_input_complex_input_in_field_position(): } ''') + def test_boolean_non_null_boolean_in_directive(): expect_passes_rule(VariablesInAllowedPosition, ''' query Query($boolVar: Boolean!) @@ -143,6 +151,7 @@ def test_boolean_non_null_boolean_in_directive(): } ''') + def test_boolean_non_null_boolean_in_directive_with_default(): expect_passes_rule(VariablesInAllowedPosition, ''' query Query($boolVar: Boolean = false) @@ -151,6 +160,7 @@ def test_boolean_non_null_boolean_in_directive_with_default(): } ''') + def test_int_non_null_int(): expect_fails_rule(VariablesInAllowedPosition, ''' query Query($intArg: Int) { @@ -159,10 +169,11 @@ def test_int_non_null_int(): } } ''', [ - { 'message': VariablesInAllowedPosition.bad_var_pos_message('intArg', 'Int', 'Int!'), - 'locations': [SourceLocation(4, 45), SourceLocation(2, 19)] } + {'message': VariablesInAllowedPosition.bad_var_pos_message('intArg', 'Int', 'Int!'), + 'locations': [SourceLocation(4, 45), SourceLocation(2, 19)]} ]) + def test_int_non_null_int_within_fragment(): expect_fails_rule(VariablesInAllowedPosition, ''' fragment nonNullIntArgFieldFrag on ComplicatedArgs { @@ -174,10 +185,11 @@ def test_int_non_null_int_within_fragment(): } } ''', [ - { 'message': VariablesInAllowedPosition.bad_var_pos_message('intArg', 'Int', 'Int!'), - 'locations': [SourceLocation(5, 19), SourceLocation(3, 43)] } + {'message': VariablesInAllowedPosition.bad_var_pos_message('intArg', 'Int', 'Int!'), + 'locations': [SourceLocation(5, 19), SourceLocation(3, 43)]} ]) + def test_int_non_null_int_within_nested_fragment(): expect_fails_rule(VariablesInAllowedPosition, ''' fragment outerFrag on ComplicatedArgs { @@ -192,10 +204,11 @@ def test_int_non_null_int_within_nested_fragment(): } } ''', [ - { 'message': VariablesInAllowedPosition.bad_var_pos_message('intArg', 'Int', 'Int!'), - 'locations': [SourceLocation(8, 19), SourceLocation(6, 43)] } + {'message': VariablesInAllowedPosition.bad_var_pos_message('intArg', 'Int', 'Int!'), + 'locations': [SourceLocation(8, 19), SourceLocation(6, 43)]} ]) + def test_string_over_boolean(): expect_fails_rule(VariablesInAllowedPosition, ''' query Query($stringVar: String) { @@ -204,10 +217,11 @@ def test_string_over_boolean(): } } ''', [ - { 'message': VariablesInAllowedPosition.bad_var_pos_message('stringVar', 'String', 'Boolean'), - 'locations': [SourceLocation(2, 19), SourceLocation(4, 39)] } + {'message': VariablesInAllowedPosition.bad_var_pos_message('stringVar', 'String', 'Boolean'), + 'locations': [SourceLocation(2, 19), SourceLocation(4, 39)]} ]) + def test_string_string_fail(): expect_fails_rule(VariablesInAllowedPosition, ''' query Query($stringVar: String) { @@ -216,19 +230,21 @@ def test_string_string_fail(): } } ''', [ - { 'message': VariablesInAllowedPosition.bad_var_pos_message('stringVar', 'String', '[String]'), - 'locations': [SourceLocation(2, 19), SourceLocation(4, 45)]} + {'message': VariablesInAllowedPosition.bad_var_pos_message('stringVar', 'String', '[String]'), + 'locations': [SourceLocation(2, 19), SourceLocation(4, 45)]} ]) + def test_boolean_non_null_boolean_in_directive(): expect_fails_rule(VariablesInAllowedPosition, ''' query Query($boolVar: Boolean) { dog @include(if: $boolVar) } ''', [ - { 'message': VariablesInAllowedPosition.bad_var_pos_message('boolVar', 'Boolean', 'Boolean!'), - 'locations': [SourceLocation(2, 19), SourceLocation(3, 26)] - }]) + {'message': VariablesInAllowedPosition.bad_var_pos_message('boolVar', 'Boolean', 'Boolean!'), + 'locations': [SourceLocation(2, 19), SourceLocation(3, 26)] + }]) + def test_string_non_null_boolean_in_directive(): expect_fails_rule(VariablesInAllowedPosition, ''' @@ -236,6 +252,6 @@ def test_string_non_null_boolean_in_directive(): dog @include(if: $stringVar) } ''', [ - { 'message': VariablesInAllowedPosition.bad_var_pos_message('stringVar', 'String', 'Boolean!'), - 'locations': [SourceLocation(2, 19), SourceLocation(3, 26)] } + {'message': VariablesInAllowedPosition.bad_var_pos_message('stringVar', 'String', 'Boolean!'), + 'locations': [SourceLocation(2, 19), SourceLocation(3, 26)]} ]) diff --git a/tests/core_validation/utils.py b/tests/core_validation/utils.py index dfb474c9..ac7b9c3b 100644 --- a/tests/core_validation/utils.py +++ b/tests/core_validation/utils.py @@ -1,25 +1,16 @@ -from graphql.core.type.directives import GraphQLDirective, GraphQLIncludeDirective, GraphQLSkipDirective -from graphql.core.validation import validate -from graphql.core.language.parser import parse -from graphql.core.type import ( - GraphQLSchema, - GraphQLObjectType, - GraphQLField, - GraphQLArgument, - GraphQLID, - GraphQLNonNull, - GraphQLString, - GraphQLInt, - GraphQLFloat, - GraphQLBoolean, - GraphQLInterfaceType, - GraphQLEnumType, - GraphQLEnumValue, - GraphQLInputObjectField, - GraphQLInputObjectType, - GraphQLUnionType, - GraphQLList) from graphql.core.error import format_error +from graphql.core.language.parser import parse +from graphql.core.type import (GraphQLArgument, GraphQLBoolean, + GraphQLEnumType, GraphQLEnumValue, GraphQLField, + GraphQLFloat, GraphQLID, + GraphQLInputObjectField, GraphQLInputObjectType, + GraphQLInt, GraphQLInterfaceType, GraphQLList, + GraphQLNonNull, GraphQLObjectType, + GraphQLSchema, GraphQLString, GraphQLUnionType) +from graphql.core.type.directives import (GraphQLDirective, + GraphQLIncludeDirective, + GraphQLSkipDirective) +from graphql.core.validation import validate Being = GraphQLInterfaceType('Being', { 'name': GraphQLField(GraphQLString, { @@ -207,7 +198,7 @@ def expect_invalid(schema, rules, query, expected_errors, sort_list=True): error['locations'] = [ {'line': loc.line, 'column': loc.column} for loc in error['locations'] - ] + ] if sort_list: assert sort_lists(list(map(format_error, errors))) == sort_lists(expected_errors)