Skip to content

Commit a75fff8

Browse files
tensor.allclose to use abs(a-b) < max(atol, rtol*max(abs(a), abs(b)))
1 parent f36af57 commit a75fff8

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

dpctl/tensor/_testing.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,15 +56,15 @@ def _allclose_complex_fp(z1, z2, atol, rtol, equal_nan):
5656
mv2 = z2r[mr]
5757
check4 = dpt.all(
5858
dpt.abs(mv1 - mv2)
59-
< atol + rtol * dpt.maximum(dpt.abs(mv1), dpt.abs(mv2))
59+
< dpt.maximum(atol, rtol * dpt.maximum(dpt.abs(mv1), dpt.abs(mv2)))
6060
)
6161
if not check4:
6262
return check4
6363
mv1 = z1i[mi]
6464
mv2 = z2i[mi]
6565
check5 = dpt.all(
6666
dpt.abs(mv1 - mv2)
67-
<= atol + rtol * dpt.maximum(dpt.abs(mv1), dpt.abs(mv2))
67+
<= dpt.maximum(atol, rtol * dpt.maximum(dpt.abs(mv1), dpt.abs(mv2)))
6868
)
6969
return check5
7070

@@ -90,7 +90,7 @@ def _allclose_real_fp(r1, r2, atol, rtol, equal_nan):
9090
mv2 = r2[m]
9191
check4 = dpt.all(
9292
dpt.abs(mv1 - mv2)
93-
<= atol + rtol * dpt.maximum(dpt.abs(mv1), dpt.abs(mv2))
93+
<= dpt.maximum(atol, rtol * dpt.maximum(dpt.abs(mv1), dpt.abs(mv2)))
9494
)
9595
return check4
9696

@@ -103,6 +103,10 @@ def allclose(a1, a2, atol=1e-8, rtol=1e-5, equal_nan=False):
103103
"""allclose(a1, a2, atol=1e-8, rtol=1e-5, equal_nan=False)
104104
105105
Returns True if two arrays are element-wise equal within tolerances.
106+
107+
The testing is based on the following elementwise comparison:
108+
109+
abs(a - b) <= max(atol, rtol * max(abs(a), abs(b)))
106110
"""
107111
if not isinstance(a1, dpt.usm_ndarray):
108112
raise TypeError(
@@ -114,6 +118,10 @@ def allclose(a1, a2, atol=1e-8, rtol=1e-5, equal_nan=False):
114118
)
115119
atol = float(atol)
116120
rtol = float(rtol)
121+
if atol < 0.0 or rtol < 0.0:
122+
raise ValueError(
123+
"Absolute and relative tolerances must be non-negative"
124+
)
117125
equal_nan = bool(equal_nan)
118126
exec_q = du.get_execution_queue(tuple(a.sycl_queue for a in (a1, a2)))
119127
if exec_q is None:

0 commit comments

Comments
 (0)