Skip to content

Commit 26862b4

Browse files
Implements dpctl.tensor.allclose
This utility function is based on symmetric check, unlike numpy.allclose, and verifies that abs(x1-x2) < atol + rtol * max(abs(x1), abs(x2)) This way allclose(x1, x2) is symmetric, and allclose(x1,x2) implies allclose(x2, x1).
1 parent 4a2578f commit 26862b4

File tree

3 files changed

+232
-0
lines changed

3 files changed

+232
-0
lines changed

dpctl/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@
158158
trunc,
159159
)
160160
from ._reduction import sum
161+
from ._testing import allclose
161162

162163
__all__ = [
163164
"Device",
@@ -301,4 +302,5 @@
301302
"tan",
302303
"tanh",
303304
"trunc",
305+
"allclose",
304306
]

dpctl/tensor/_testing.py

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2023 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import numpy as np
18+
19+
import dpctl.tensor as dpt
20+
import dpctl.utils as du
21+
22+
from ._manipulation_functions import _broadcast_shape_impl
23+
from ._type_utils import _to_device_supported_dtype
24+
25+
26+
def _allclose_complex_fp(z1, z2, atol, rtol, equal_nan):
27+
z1r = dpt.real(z1)
28+
z1i = dpt.imag(z1)
29+
z2r = dpt.real(z2)
30+
z2i = dpt.imag(z2)
31+
if equal_nan:
32+
check1 = dpt.all(dpt.isnan(z1r) == dpt.isnan(z2r)) and dpt.all(
33+
dpt.isnan(z1i) == dpt.isnan(z2i)
34+
)
35+
else:
36+
check1 = (
37+
dpt.logical_not(dpt.any(dpt.isnan(z1r)))
38+
and dpt.logical_not(dpt.any(dpt.isnan(z1i)))
39+
) and (
40+
dpt.logical_not(dpt.any(dpt.isnan(z2r)))
41+
and dpt.logical_not(dpt.any(dpt.isnan(z2i)))
42+
)
43+
if not check1:
44+
return check1
45+
mr = dpt.isinf(z1r)
46+
mi = dpt.isinf(z1i)
47+
check2 = dpt.all(mr == dpt.isinf(z2r)) and dpt.all(mi == dpt.isinf(z2i))
48+
if not check2:
49+
return check2
50+
check3 = dpt.all(z1r[mr] == z2r[mr]) and dpt.all(z1i[mi] == z2i[mi])
51+
if not check3:
52+
return check3
53+
mr = dpt.isfinite(z1r)
54+
mi = dpt.isfinite(z1i)
55+
mv1 = z1r[mr]
56+
mv2 = z2r[mr]
57+
check4 = dpt.all(
58+
dpt.abs(mv1 - mv2)
59+
< atol + rtol * dpt.maximum(dpt.abs(mv1), dpt.abs(mv2))
60+
)
61+
if not check4:
62+
return check4
63+
mv1 = z1i[mi]
64+
mv2 = z2i[mi]
65+
check5 = dpt.all(
66+
dpt.abs(mv1 - mv2)
67+
< atol + rtol * dpt.maximum(dpt.abs(mv1), dpt.abs(mv2))
68+
)
69+
return check5
70+
71+
72+
def _allclose_real_fp(r1, r2, atol, rtol, equal_nan):
73+
if equal_nan:
74+
check1 = dpt.all(dpt.isnan(r1) == dpt.isnan(r2))
75+
else:
76+
check1 = dpt.logical_not(dpt.any(dpt.isnan(r1))) and dpt.logical_not(
77+
dpt.any(dpt.isnan(r2))
78+
)
79+
if not check1:
80+
return check1
81+
mr = dpt.isinf(r1)
82+
check2 = dpt.all(mr == dpt.isinf(r2))
83+
if not check2:
84+
return check2
85+
check3 = dpt.all(r1[mr] == r2[mr])
86+
if not check3:
87+
return check3
88+
m = dpt.isfinite(r1)
89+
mv1 = r1[m]
90+
mv2 = r2[m]
91+
check4 = dpt.all(
92+
dpt.abs(mv1 - mv2)
93+
< atol + rtol * dpt.maximum(dpt.abs(mv1), dpt.abs(mv2))
94+
)
95+
return check4
96+
97+
98+
def _allclose_others(r1, r2):
99+
return dpt.all(r1 == r2)
100+
101+
102+
def allclose(a1, a2, atol=1e-5, rtol=1e-8, equal_nan=False):
103+
"""allclose(a1, a2, atol=1e-5, rtol=1e-8)
104+
105+
Returns True if two arrays are element-wise equal within tolerance.
106+
"""
107+
if not isinstance(a1, dpt.usm_ndarray):
108+
raise TypeError(
109+
f"Expected dpctl.tensor.usm_ndarray type, got {type(a1)}."
110+
)
111+
if not isinstance(a2, dpt.usm_ndarray):
112+
raise TypeError(
113+
f"Expected dpctl.tensor.usm_ndarray type, got {type(a2)}."
114+
)
115+
atol = float(atol)
116+
rtol = float(rtol)
117+
equal_nan = bool(equal_nan)
118+
exec_q = du.get_execution_queue(tuple(a.sycl_queue for a in (a1, a2)))
119+
if exec_q is None:
120+
raise du.ExecutionPlacementError(
121+
"Execution placement can not be unambiguously inferred "
122+
"from input arguments."
123+
)
124+
res_sh = _broadcast_shape_impl([a1.shape, a2.shape])
125+
b1 = a1
126+
b2 = a2
127+
if b1.dtype == b2.dtype:
128+
res_dt = b1.dtype
129+
else:
130+
res_dt = np.promote_types(b1.dtype, b2.dtype)
131+
res_dt = _to_device_supported_dtype(res_dt, exec_q.sycl_device)
132+
b1 = dpt.astype(b1, res_dt)
133+
b2 = dpt.astype(b2, res_dt)
134+
135+
b1 = dpt.broadcast_to(b1, res_sh)
136+
b2 = dpt.broadcast_to(b2, res_sh)
137+
138+
k = b1.dtype.kind
139+
if k == "c":
140+
return _allclose_complex_fp(b1, b2, atol, rtol, equal_nan)
141+
elif k == "f":
142+
return _allclose_real_fp(b1, b2, atol, rtol, equal_nan)
143+
else:
144+
return _allclose_others(b1, b2)

dpctl/tests/test_tensor_testing.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
import itertools
2+
3+
import pytest
4+
5+
import dpctl.tensor as dpt
6+
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
7+
8+
_all_dtypes = [
9+
"?",
10+
"i1",
11+
"u1",
12+
"i2",
13+
"u2",
14+
"i4",
15+
"u4",
16+
"i8",
17+
"u8",
18+
"f2",
19+
"f4",
20+
"f8",
21+
"c8",
22+
"c16",
23+
]
24+
25+
26+
@pytest.mark.parametrize("dtype", _all_dtypes)
27+
def test_allclose(dtype):
28+
q = get_queue_or_skip()
29+
skip_if_dtype_not_supported(dtype, q)
30+
31+
a1 = dpt.ones(10, dtype=dtype)
32+
a2 = dpt.ones(10, dtype=dtype)
33+
34+
assert dpt.allclose(a1, a2)
35+
36+
37+
@pytest.mark.parametrize("dtype", ["f2", "f4", "f8"])
38+
def test_allclose_real_fp(dtype):
39+
q = get_queue_or_skip()
40+
skip_if_dtype_not_supported(dtype, q)
41+
42+
v = [dpt.nan, -dpt.nan, dpt.inf, -dpt.inf, -0.0, 0.0, 1.0, -1.0]
43+
a1 = dpt.asarray(v[2:], dtype=dtype)
44+
a2 = dpt.asarray(v[2:], dtype=dtype)
45+
46+
tol = dpt.finfo(a1.dtype).resolution
47+
assert dpt.allclose(a1, a2, atol=tol, rtol=tol)
48+
49+
a1 = dpt.asarray(v, dtype=dtype)
50+
a2 = dpt.asarray(v, dtype=dtype)
51+
52+
assert not dpt.allclose(a1, a2, atol=tol, rtol=tol)
53+
assert dpt.allclose(a1, a2, atol=tol, rtol=tol, equal_nan=True)
54+
55+
56+
@pytest.mark.parametrize("dtype", ["c8", "c16"])
57+
def test_allclose_complex_fp(dtype):
58+
q = get_queue_or_skip()
59+
skip_if_dtype_not_supported(dtype, q)
60+
61+
v = [dpt.nan, -dpt.nan, dpt.inf, -dpt.inf, -0.0, 0.0, 1.0, -1.0]
62+
63+
not_nans = [complex(*xy) for xy in itertools.product(v[2:], repeat=2)]
64+
z1 = dpt.asarray(not_nans, dtype=dtype)
65+
z2 = dpt.asarray(not_nans, dtype=dtype)
66+
67+
tol = dpt.finfo(z1.dtype).resolution
68+
assert dpt.allclose(z1, z2, atol=tol, rtol=tol)
69+
70+
both = [complex(*xy) for xy in itertools.product(v, repeat=2)]
71+
z1 = dpt.asarray(both, dtype=dtype)
72+
z2 = dpt.asarray(both, dtype=dtype)
73+
74+
tol = dpt.finfo(z1.dtype).resolution
75+
assert not dpt.allclose(z1, z2, atol=tol, rtol=tol)
76+
assert dpt.allclose(z1, z2, atol=tol, rtol=tol, equal_nan=True)
77+
78+
79+
def test_allclose_validation():
80+
with pytest.raises(TypeError):
81+
dpt.allclose(True, False)
82+
83+
get_queue_or_skip()
84+
x = dpt.asarray(True)
85+
with pytest.raises(TypeError):
86+
dpt.allclose(x, False)

0 commit comments

Comments
 (0)