Skip to content

Commit 0208b1f

Browse files
committed
Loose assertion of infinities to very large floats
1 parent 37bfae4 commit 0208b1f

File tree

1 file changed

+23
-12
lines changed

1 file changed

+23
-12
lines changed

array_api_tests/test_operators_and_elementwise_functions.py

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -90,14 +90,25 @@ def mock_int_dtype(n: int, dtype: DataType) -> int:
9090
return n
9191

9292

93-
def isclose(a: float, b: float, *, rel_tol: float = 0.25, abs_tol: float = 1) -> bool:
93+
def isclose(
94+
a: float,
95+
b: float,
96+
M: float,
97+
*,
98+
rel_tol: float = 0.25,
99+
abs_tol: float = 1,
100+
) -> bool:
94101
"""Wraps math.isclose with very generous defaults.
95102
96103
This is useful for many floating-point operations where the spec does not
97104
make accuracy requirements.
98105
"""
99-
if not (math.isfinite(a) and math.isfinite(b)):
100-
raise ValueError(f"{a=} and {b=}, but input must be finite")
106+
if math.isnan(a) or math.isnan(b):
107+
raise ValueError(f"{a=} and {b=}, but input must be non-NaN")
108+
if math.isinf(a):
109+
return math.isinf(b) or abs(b) > math.log(M)
110+
elif math.isinf(b):
111+
return math.isinf(a) or abs(a) > math.log(M)
101112
return math.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol)
102113

103114

@@ -288,10 +299,10 @@ def unary_assert_against_refimpl(
288299
f"{f_i}={scalar_i}"
289300
)
290301
if res.dtype in dh.complex_dtypes:
291-
assert isclose(scalar_o.real, expected.real), msg
292-
assert isclose(scalar_o.imag, expected.imag), msg
302+
assert isclose(scalar_o.real, expected.real, M), msg
303+
assert isclose(scalar_o.imag, expected.imag, M), msg
293304
else:
294-
assert isclose(scalar_o, expected), msg
305+
assert isclose(scalar_o, expected, M), msg
295306
else:
296307
assert scalar_o == expected, (
297308
f"{f_o}={scalar_o}, but should be {expr} [{func_name}()]\n"
@@ -364,10 +375,10 @@ def binary_assert_against_refimpl(
364375
f"{f_l}={scalar_l}, {f_r}={scalar_r}"
365376
)
366377
if res.dtype in dh.complex_dtypes:
367-
assert isclose(scalar_o.real, expected.real), msg
368-
assert isclose(scalar_o.imag, expected.imag), msg
378+
assert isclose(scalar_o.real, expected.real, M), msg
379+
assert isclose(scalar_o.imag, expected.imag, M), msg
369380
else:
370-
assert isclose(scalar_o, expected), msg
381+
assert isclose(scalar_o, expected, M), msg
371382
else:
372383
assert scalar_o == expected, (
373384
f"{f_o}={scalar_o}, but should be {expr} [{func_name}()]\n"
@@ -437,10 +448,10 @@ def right_scalar_assert_against_refimpl(
437448
f"{f_l}={scalar_l}"
438449
)
439450
if res.dtype in dh.complex_dtypes:
440-
assert isclose(scalar_o.real, expected.real), msg
441-
assert isclose(scalar_o.imag, expected.imag), msg
451+
assert isclose(scalar_o.real, expected.real, M), msg
452+
assert isclose(scalar_o.imag, expected.imag, M), msg
442453
else:
443-
assert isclose(scalar_o, expected), msg
454+
assert isclose(scalar_o, expected, M), msg
444455
else:
445456
assert scalar_o == expected, (
446457
f"{f_o}={scalar_o}, but should be {expr} [{func_name}()]\n"

0 commit comments

Comments
 (0)