diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index 0af265e2..384a6db1 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -335,7 +335,13 @@ def isclose( atol = int(atol) if rtol == 0: return xp.abs(a - b) <= atol - nrtol = int(1.0 / rtol) + + try: + nrtol = xp.asarray(int(1.0 / rtol), dtype=b.dtype) + except OverflowError: + # rtol * max_int(dtype) < 1, so it's inconsequential + return xp.abs(a - b) <= atol + return xp.abs(a - b) <= (atol + xp.abs(b) // nrtol) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index ef1a1fc2..eee65145 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -354,6 +354,13 @@ def test_tolerance(self, dtype: str, xp: ModuleType): xp_assert_equal(isclose(a, b, rtol=0), xp.asarray([False, False])) xp_assert_equal(isclose(a, b, atol=1, rtol=0), xp.asarray([True, False])) + @pytest.mark.parametrize("dtype", ["int8", "uint8"]) + def test_tolerance_integer_overflow(self, dtype: str, xp: ModuleType): + """1/rtol is too large for dtype""" + a = xp.asarray([100, 100], dtype=getattr(xp, dtype)) + b = xp.asarray([100, 101], dtype=getattr(xp, dtype)) + xp_assert_equal(isclose(a, b), xp.asarray([True, False])) + def test_very_small_numbers(self, xp: ModuleType): a = xp.asarray([1e-9, 1e-9]) b = xp.asarray([1.0001e-9, 1.00001e-9])