@@ -152,7 +152,13 @@ def __array__(self, dtype: None | np.dtype[Any] = None, copy: None | bool = None
152
152
# spec in places where it either deviates from or is more strict than
153
153
# NumPy behavior
154
154
155
- def _check_allowed_dtypes (self , other : bool | int | float | Array , dtype_category : str , op : str ) -> Array :
155
+ def _check_allowed_dtypes (
156
+ self ,
157
+ other : bool | int | float | Array ,
158
+ dtype_category : str ,
159
+ op : str ,
160
+ check_promotion : bool = True ,
161
+ ) -> Array :
156
162
"""
157
163
Helper function for operators to only allow specific input dtypes
158
164
@@ -176,7 +182,8 @@ def _check_allowed_dtypes(self, other: bool | int | float | Array, dtype_categor
176
182
# This will raise TypeError for type combinations that are not allowed
177
183
# to promote in the spec (even if the NumPy array operator would
178
184
# promote them).
179
- res_dtype = _result_type (self .dtype , other .dtype )
185
+ if check_promotion :
186
+ res_dtype = _result_type (self .dtype , other .dtype )
180
187
if op .startswith ("__i" ):
181
188
# Note: NumPy will allow in-place operators in some cases where
182
189
# the type promoted operator does not match the left-hand side
@@ -604,7 +611,7 @@ def __ge__(self: Array, other: Union[int, float, Array], /) -> Array:
604
611
"""
605
612
Performs the operation __ge__.
606
613
"""
607
- other = self ._check_allowed_dtypes (other , "real numeric" , "__ge__" )
614
+ other = self ._check_allowed_dtypes (other , "real numeric" , "__ge__" , check_promotion = False )
608
615
if other is NotImplemented :
609
616
return other
610
617
self , other = self ._normalize_two_args (self , other )
@@ -638,7 +645,7 @@ def __gt__(self: Array, other: Union[int, float, Array], /) -> Array:
638
645
"""
639
646
Performs the operation __gt__.
640
647
"""
641
- other = self ._check_allowed_dtypes (other , "real numeric" , "__gt__" )
648
+ other = self ._check_allowed_dtypes (other , "real numeric" , "__gt__" , check_promotion = False )
642
649
if other is NotImplemented :
643
650
return other
644
651
self , other = self ._normalize_two_args (self , other )
@@ -692,7 +699,7 @@ def __le__(self: Array, other: Union[int, float, Array], /) -> Array:
692
699
"""
693
700
Performs the operation __le__.
694
701
"""
695
- other = self ._check_allowed_dtypes (other , "real numeric" , "__le__" )
702
+ other = self ._check_allowed_dtypes (other , "real numeric" , "__le__" , check_promotion = False )
696
703
if other is NotImplemented :
697
704
return other
698
705
self , other = self ._normalize_two_args (self , other )
@@ -714,7 +721,7 @@ def __lt__(self: Array, other: Union[int, float, Array], /) -> Array:
714
721
"""
715
722
Performs the operation __lt__.
716
723
"""
717
- other = self ._check_allowed_dtypes (other , "real numeric" , "__lt__" )
724
+ other = self ._check_allowed_dtypes (other , "real numeric" , "__lt__" , check_promotion = False )
718
725
if other is NotImplemented :
719
726
return other
720
727
self , other = self ._normalize_two_args (self , other )
0 commit comments