1
1
from ...error import GraphQLError
2
- from ...type .definition import (GraphQLInterfaceType , GraphQLObjectType ,
3
- GraphQLUnionType )
4
2
from ...utils .type_from_ast import type_from_ast
3
+ from ...utils .type_comparators import do_types_overlap
5
4
from .base import ValidationRule
6
5
7
6
@@ -10,7 +9,7 @@ class PossibleFragmentSpreads(ValidationRule):
10
9
def enter_InlineFragment (self , node , key , parent , path , ancestors ):
11
10
frag_type = self .context .get_type ()
12
11
parent_type = self .context .get_parent_type ()
13
- if frag_type and parent_type and not self . do_types_overlap (frag_type , parent_type ):
12
+ if frag_type and parent_type and not do_types_overlap (frag_type , parent_type ):
14
13
self .context .report_error (GraphQLError (
15
14
self .type_incompatible_anon_spread_message (parent_type , frag_type ),
16
15
[node ]
@@ -20,7 +19,7 @@ def enter_FragmentSpread(self, node, key, parent, path, ancestors):
20
19
frag_name = node .name .value
21
20
frag_type = self .get_fragment_type (self .context , frag_name )
22
21
parent_type = self .context .get_parent_type ()
23
- if frag_type and parent_type and not self . do_types_overlap (frag_type , parent_type ):
22
+ if frag_type and parent_type and not do_types_overlap (frag_type , parent_type ):
24
23
self .context .report_error (GraphQLError (
25
24
self .type_incompatible_spread_message (frag_name , parent_type , frag_type ),
26
25
[node ]
@@ -31,21 +30,6 @@ def get_fragment_type(context, name):
31
30
frag = context .get_fragment (name )
32
31
return frag and type_from_ast (context .get_schema (), frag .type_condition )
33
32
34
- @staticmethod
35
- def do_types_overlap (t1 , t2 ):
36
- if t1 == t2 :
37
- return True
38
- if isinstance (t1 , GraphQLObjectType ):
39
- if isinstance (t2 , GraphQLObjectType ):
40
- return False
41
- return t1 in t2 .get_possible_types ()
42
- if isinstance (t1 , GraphQLInterfaceType ) or isinstance (t1 , GraphQLUnionType ):
43
- if isinstance (t2 , GraphQLObjectType ):
44
- return t2 in t1 .get_possible_types ()
45
-
46
- t1_type_names = {possible_type .name : possible_type for possible_type in t1 .get_possible_types ()}
47
- return any (t .name in t1_type_names for t in t2 .get_possible_types ())
48
-
49
33
@staticmethod
50
34
def type_incompatible_spread_message (frag_name , parent_type , frag_type ):
51
35
return 'Fragment {} cannot be spread here as objects of type {} can never be of type {}' .format (frag_name ,
0 commit comments