@@ -56,15 +56,15 @@ def _allclose_complex_fp(z1, z2, atol, rtol, equal_nan):
56
56
mv2 = z2r [mr ]
57
57
check4 = dpt .all (
58
58
dpt .abs (mv1 - mv2 )
59
- < atol + rtol * dpt .maximum (dpt .abs (mv1 ), dpt .abs (mv2 ))
59
+ < dpt . maximum ( atol , rtol * dpt .maximum (dpt .abs (mv1 ), dpt .abs (mv2 ) ))
60
60
)
61
61
if not check4 :
62
62
return check4
63
63
mv1 = z1i [mi ]
64
64
mv2 = z2i [mi ]
65
65
check5 = dpt .all (
66
66
dpt .abs (mv1 - mv2 )
67
- <= atol + rtol * dpt .maximum (dpt .abs (mv1 ), dpt .abs (mv2 ))
67
+ <= dpt . maximum ( atol , rtol * dpt .maximum (dpt .abs (mv1 ), dpt .abs (mv2 ) ))
68
68
)
69
69
return check5
70
70
@@ -90,7 +90,7 @@ def _allclose_real_fp(r1, r2, atol, rtol, equal_nan):
90
90
mv2 = r2 [m ]
91
91
check4 = dpt .all (
92
92
dpt .abs (mv1 - mv2 )
93
- <= atol + rtol * dpt .maximum (dpt .abs (mv1 ), dpt .abs (mv2 ))
93
+ <= dpt . maximum ( atol , rtol * dpt .maximum (dpt .abs (mv1 ), dpt .abs (mv2 ) ))
94
94
)
95
95
return check4
96
96
@@ -103,6 +103,10 @@ def allclose(a1, a2, atol=1e-8, rtol=1e-5, equal_nan=False):
103
103
"""allclose(a1, a2, atol=1e-8, rtol=1e-5, equal_nan=False)
104
104
105
105
Returns True if two arrays are element-wise equal within tolerances.
106
+
107
+ The testing is based on the following elementwise comparison:
108
+
109
+ abs(a - b) <= max(atol, rtol * max(abs(a), abs(b)))
106
110
"""
107
111
if not isinstance (a1 , dpt .usm_ndarray ):
108
112
raise TypeError (
@@ -114,6 +118,10 @@ def allclose(a1, a2, atol=1e-8, rtol=1e-5, equal_nan=False):
114
118
)
115
119
atol = float (atol )
116
120
rtol = float (rtol )
121
+ if atol < 0.0 or rtol < 0.0 :
122
+ raise ValueError (
123
+ "Absolute and relative tolerances must be non-negative"
124
+ )
117
125
equal_nan = bool (equal_nan )
118
126
exec_q = du .get_execution_queue (tuple (a .sycl_queue for a in (a1 , a2 )))
119
127
if exec_q is None :
0 commit comments