Skip to content

Commit 5379bd5

Browse files
committed
Allow any combination of real dtypes in comparisons
This does not change == or != because the standard is currently unclear about that so I'd like to see what happens there first.
1 parent 6b0079b commit 5379bd5

File tree

2 files changed

+13
-14
lines changed

2 files changed

+13
-14
lines changed

array_api_strict/_array_object.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,13 @@ def __array__(self, dtype: None | np.dtype[Any] = None, copy: None | bool = None
152152
# spec in places where it either deviates from or is more strict than
153153
# NumPy behavior
154154

155-
def _check_allowed_dtypes(self, other: bool | int | float | Array, dtype_category: str, op: str) -> Array:
155+
def _check_allowed_dtypes(
156+
self,
157+
other: bool | int | float | Array,
158+
dtype_category: str,
159+
op: str,
160+
check_promotion: bool = True,
161+
) -> Array:
156162
"""
157163
Helper function for operators to only allow specific input dtypes
158164
@@ -176,7 +182,8 @@ def _check_allowed_dtypes(self, other: bool | int | float | Array, dtype_categor
176182
# This will raise TypeError for type combinations that are not allowed
177183
# to promote in the spec (even if the NumPy array operator would
178184
# promote them).
179-
res_dtype = _result_type(self.dtype, other.dtype)
185+
if check_promotion:
186+
res_dtype = _result_type(self.dtype, other.dtype)
180187
if op.startswith("__i"):
181188
# Note: NumPy will allow in-place operators in some cases where
182189
# the type promoted operator does not match the left-hand side
@@ -604,7 +611,7 @@ def __ge__(self: Array, other: Union[int, float, Array], /) -> Array:
604611
"""
605612
Performs the operation __ge__.
606613
"""
607-
other = self._check_allowed_dtypes(other, "real numeric", "__ge__")
614+
other = self._check_allowed_dtypes(other, "real numeric", "__ge__", check_promotion=False)
608615
if other is NotImplemented:
609616
return other
610617
self, other = self._normalize_two_args(self, other)
@@ -638,7 +645,7 @@ def __gt__(self: Array, other: Union[int, float, Array], /) -> Array:
638645
"""
639646
Performs the operation __gt__.
640647
"""
641-
other = self._check_allowed_dtypes(other, "real numeric", "__gt__")
648+
other = self._check_allowed_dtypes(other, "real numeric", "__gt__", check_promotion=False)
642649
if other is NotImplemented:
643650
return other
644651
self, other = self._normalize_two_args(self, other)
@@ -692,7 +699,7 @@ def __le__(self: Array, other: Union[int, float, Array], /) -> Array:
692699
"""
693700
Performs the operation __le__.
694701
"""
695-
other = self._check_allowed_dtypes(other, "real numeric", "__le__")
702+
other = self._check_allowed_dtypes(other, "real numeric", "__le__", check_promotion=False)
696703
if other is NotImplemented:
697704
return other
698705
self, other = self._normalize_two_args(self, other)
@@ -714,7 +721,7 @@ def __lt__(self: Array, other: Union[int, float, Array], /) -> Array:
714721
"""
715722
Performs the operation __lt__.
716723
"""
717-
other = self._check_allowed_dtypes(other, "real numeric", "__lt__")
724+
other = self._check_allowed_dtypes(other, "real numeric", "__lt__", check_promotion=False)
718725
if other is NotImplemented:
719726
return other
720727
self, other = self._normalize_two_args(self, other)

array_api_strict/_elementwise_functions.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -439,8 +439,6 @@ def greater(x1: Array, x2: Array, /) -> Array:
439439
"""
440440
if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
441441
raise TypeError("Only real numeric dtypes are allowed in greater")
442-
# Call result type here just to raise on disallowed type combinations
443-
_result_type(x1.dtype, x2.dtype)
444442
x1, x2 = Array._normalize_two_args(x1, x2)
445443
return Array._new(np.greater(x1._array, x2._array))
446444

@@ -453,8 +451,6 @@ def greater_equal(x1: Array, x2: Array, /) -> Array:
453451
"""
454452
if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
455453
raise TypeError("Only real numeric dtypes are allowed in greater_equal")
456-
# Call result type here just to raise on disallowed type combinations
457-
_result_type(x1.dtype, x2.dtype)
458454
x1, x2 = Array._normalize_two_args(x1, x2)
459455
return Array._new(np.greater_equal(x1._array, x2._array))
460456

@@ -524,8 +520,6 @@ def less(x1: Array, x2: Array, /) -> Array:
524520
"""
525521
if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
526522
raise TypeError("Only real numeric dtypes are allowed in less")
527-
# Call result type here just to raise on disallowed type combinations
528-
_result_type(x1.dtype, x2.dtype)
529523
x1, x2 = Array._normalize_two_args(x1, x2)
530524
return Array._new(np.less(x1._array, x2._array))
531525

@@ -538,8 +532,6 @@ def less_equal(x1: Array, x2: Array, /) -> Array:
538532
"""
539533
if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
540534
raise TypeError("Only real numeric dtypes are allowed in less_equal")
541-
# Call result type here just to raise on disallowed type combinations
542-
_result_type(x1.dtype, x2.dtype)
543535
x1, x2 = Array._normalize_two_args(x1, x2)
544536
return Array._new(np.less_equal(x1._array, x2._array))
545537

0 commit comments

Comments
 (0)