Skip to content

Commit 64eb98c

Browse files
committed
Adds tests for cbrt, copysign, exp2
1 parent 9ee16d6 commit 64eb98c

File tree

3 files changed

+358
-0
lines changed

3 files changed

+358
-0
lines changed

dpctl/tests/elementwise/test_cbrt.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2023 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import numpy as np
18+
import pytest
19+
from numpy.testing import assert_allclose
20+
21+
import dpctl.tensor as dpt
22+
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
23+
24+
from .utils import _map_to_device_dtype, _real_fp_dtypes
25+
26+
27+
@pytest.mark.parametrize("dtype", _real_fp_dtypes)
28+
def test_cbrt_out_type(dtype):
29+
q = get_queue_or_skip()
30+
skip_if_dtype_not_supported(dtype, q)
31+
32+
X = dpt.asarray(0, dtype=dtype, sycl_queue=q)
33+
expected_dtype = np.cbrt(np.array(0, dtype=dtype)).dtype
34+
expected_dtype = _map_to_device_dtype(expected_dtype, q.sycl_device)
35+
assert dpt.cbrt(X).dtype == expected_dtype
36+
37+
38+
@pytest.mark.parametrize("dtype", _real_fp_dtypes)
39+
def test_cbrt_output_contig(dtype):
40+
q = get_queue_or_skip()
41+
skip_if_dtype_not_supported(dtype, q)
42+
43+
n_seq = 1027
44+
45+
X = dpt.linspace(0, 13, num=n_seq, dtype=dtype, sycl_queue=q)
46+
Xnp = dpt.asnumpy(X)
47+
48+
Y = dpt.cbrt(X)
49+
tol = 8 * dpt.finfo(Y.dtype).resolution
50+
51+
assert_allclose(dpt.asnumpy(Y), np.cbrt(Xnp), atol=tol, rtol=tol)
52+
53+
54+
@pytest.mark.parametrize("dtype", _real_fp_dtypes)
55+
def test_cbrt_output_strided(dtype):
56+
q = get_queue_or_skip()
57+
skip_if_dtype_not_supported(dtype, q)
58+
59+
n_seq = 2054
60+
61+
X = dpt.linspace(0, 13, num=n_seq, dtype=dtype, sycl_queue=q)[::-2]
62+
Xnp = dpt.asnumpy(X)
63+
64+
Y = dpt.cbrt(X)
65+
tol = 8 * dpt.finfo(Y.dtype).resolution
66+
67+
assert_allclose(dpt.asnumpy(Y), np.cbrt(Xnp), atol=tol, rtol=tol)
68+
69+
70+
@pytest.mark.usefixtures("suppress_invalid_numpy_warnings")
71+
def test_cbrt_special_cases():
72+
get_queue_or_skip()
73+
74+
X = dpt.asarray([dpt.nan, 0.0, -0.0, dpt.inf, -dpt.inf], dtype="f4")
75+
res = dpt.cbrt(X)
76+
expected = dpt.asarray([dpt.nan, 0.0, -0.0, dpt.inf, -dpt.inf], dtype="f4")
77+
tol = dpt.finfo(dpt.float32).resolution
78+
79+
assert dpt.allclose(res, expected, atol=tol, rtol=tol, equal_nan=True)
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2023 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import ctypes
18+
19+
import numpy as np
20+
import pytest
21+
22+
import dpctl.tensor as dpt
23+
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
24+
25+
from .utils import _compare_dtypes, _real_fp_dtypes
26+
27+
28+
@pytest.mark.parametrize("op1_dtype", _real_fp_dtypes)
29+
@pytest.mark.parametrize("op2_dtype", _real_fp_dtypes)
30+
def test_copysign_dtype_matrix(op1_dtype, op2_dtype):
31+
q = get_queue_or_skip()
32+
skip_if_dtype_not_supported(op1_dtype, q)
33+
skip_if_dtype_not_supported(op2_dtype, q)
34+
35+
sz = 127
36+
ar1 = dpt.ones(sz, dtype=op1_dtype)
37+
ar2 = dpt.ones_like(ar1, dtype=op2_dtype)
38+
39+
r = dpt.copysign(ar1, ar2)
40+
assert isinstance(r, dpt.usm_ndarray)
41+
expected = np.copysign(
42+
np.ones(1, dtype=op1_dtype), np.ones(1, dtype=op2_dtype)
43+
)
44+
assert _compare_dtypes(r.dtype, expected.dtype, sycl_queue=q)
45+
assert r.shape == ar1.shape
46+
assert (dpt.asnumpy(r) == expected.astype(r.dtype)).all()
47+
assert r.sycl_queue == ar1.sycl_queue
48+
49+
ar3 = dpt.ones(sz, dtype=op1_dtype)
50+
ar4 = dpt.ones(2 * sz, dtype=op2_dtype)
51+
52+
r = dpt.copysign(ar3[::-1], ar4[::2])
53+
assert isinstance(r, dpt.usm_ndarray)
54+
expected = np.copysign(
55+
np.ones(1, dtype=op1_dtype), np.ones(1, dtype=op2_dtype)
56+
)
57+
assert _compare_dtypes(r.dtype, expected.dtype, sycl_queue=q)
58+
assert r.shape == ar3.shape
59+
assert (dpt.asnumpy(r) == expected.astype(r.dtype)).all()
60+
61+
62+
@pytest.mark.parametrize("arr_dt", _real_fp_dtypes)
63+
def test_copysign_python_scalar(arr_dt):
64+
q = get_queue_or_skip()
65+
skip_if_dtype_not_supported(arr_dt, q)
66+
67+
X = dpt.ones((10, 10), dtype=arr_dt, sycl_queue=q)
68+
py_ones = (
69+
bool(1),
70+
int(1),
71+
float(1),
72+
np.float32(1),
73+
ctypes.c_int(1),
74+
)
75+
for sc in py_ones:
76+
R = dpt.copysign(X, sc)
77+
assert isinstance(R, dpt.usm_ndarray)
78+
R = dpt.copysign(sc, X)
79+
assert isinstance(R, dpt.usm_ndarray)
80+
81+
82+
@pytest.mark.parametrize("dt", _real_fp_dtypes)
83+
def test_copysign(dt):
84+
q = get_queue_or_skip()
85+
skip_if_dtype_not_supported(dt, q)
86+
87+
x = dpt.arange(100, dtype=dt, sycl_queue=q)
88+
x[1::2] *= -1
89+
y = dpt.ones(100, dtype=dt, sycl_queue=q)
90+
y[::2] *= -1
91+
res = dpt.copysign(x, y)
92+
expected = dpt.negative(x)
93+
tol = dpt.finfo(dt).resolution
94+
assert dpt.allclose(res, expected, atol=tol, rtol=tol)
95+
96+
97+
def test_copysign_special_values():
98+
get_queue_or_skip()
99+
100+
x1 = dpt.asarray([1.0, 0.0, dpt.nan, dpt.nan], dtype="f4")
101+
y1 = dpt.asarray([-1.0, -0.0, -dpt.nan, -1], dtype="f4")
102+
res = dpt.copysign(x1, y1)
103+
assert dpt.all(dpt.signbit(res))
104+
x2 = dpt.asarray([-1.0, -0.0, -dpt.nan, -dpt.nan], dtype="f4")
105+
res = dpt.copysign(x2, y1)
106+
assert dpt.all(dpt.signbit(res))
107+
y2 = dpt.asarray([0.0, 1.0, dpt.nan, 1.0], dtype="f4")
108+
res = dpt.copysign(x2, y2)
109+
assert not dpt.any(dpt.signbit(res))
110+
res = dpt.copysign(x1, y2)
111+
assert not dpt.any(dpt.signbit(res))

dpctl/tests/elementwise/test_exp2.py

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2023 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import itertools
18+
19+
import numpy as np
20+
import pytest
21+
from numpy.testing import assert_allclose
22+
23+
import dpctl.tensor as dpt
24+
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
25+
26+
from .utils import _all_dtypes, _map_to_device_dtype, _usm_types
27+
28+
29+
@pytest.mark.parametrize("dtype", _all_dtypes)
30+
def test_exp2_out_type(dtype):
31+
q = get_queue_or_skip()
32+
skip_if_dtype_not_supported(dtype, q)
33+
34+
X = dpt.asarray(0, dtype=dtype, sycl_queue=q)
35+
expected_dtype = np.exp2(np.array(0, dtype=dtype)).dtype
36+
expected_dtype = _map_to_device_dtype(expected_dtype, q.sycl_device)
37+
assert dpt.exp2(X).dtype == expected_dtype
38+
39+
40+
@pytest.mark.parametrize("dtype", ["f2", "f4", "f8", "c8", "c16"])
41+
def test_exp2_output_contig(dtype):
42+
q = get_queue_or_skip()
43+
skip_if_dtype_not_supported(dtype, q)
44+
45+
n_seq = 1027
46+
47+
X = dpt.linspace(1, 5, num=n_seq, dtype=dtype, sycl_queue=q)
48+
Xnp = dpt.asnumpy(X)
49+
50+
Y = dpt.exp2(X)
51+
tol = 8 * dpt.finfo(Y.dtype).resolution
52+
53+
assert_allclose(dpt.asnumpy(Y), np.exp2(Xnp), atol=tol, rtol=tol)
54+
55+
56+
@pytest.mark.parametrize("dtype", ["f2", "f4", "f8", "c8", "c16"])
57+
def test_exp2_output_strided(dtype):
58+
q = get_queue_or_skip()
59+
skip_if_dtype_not_supported(dtype, q)
60+
61+
n_seq = 2 * 1027
62+
63+
X = dpt.linspace(1, 5, num=n_seq, dtype=dtype, sycl_queue=q)[::-2]
64+
Xnp = dpt.asnumpy(X)
65+
66+
Y = dpt.exp2(X)
67+
tol = 8 * dpt.finfo(Y.dtype).resolution
68+
69+
assert_allclose(dpt.asnumpy(Y), np.exp2(Xnp), atol=tol, rtol=tol)
70+
71+
72+
@pytest.mark.parametrize("usm_type", _usm_types)
73+
def test_exp2_usm_type(usm_type):
74+
q = get_queue_or_skip()
75+
76+
arg_dt = np.dtype("f4")
77+
input_shape = (10, 10, 10, 10)
78+
X = dpt.empty(input_shape, dtype=arg_dt, usm_type=usm_type, sycl_queue=q)
79+
X[..., 0::2] = 1 / 4
80+
X[..., 1::2] = 1 / 2
81+
82+
Y = dpt.exp2(X)
83+
assert Y.usm_type == X.usm_type
84+
assert Y.sycl_queue == X.sycl_queue
85+
assert Y.flags.c_contiguous
86+
87+
expected_Y = np.empty(input_shape, dtype=arg_dt)
88+
expected_Y[..., 0::2] = np.exp2(np.float32(1 / 4))
89+
expected_Y[..., 1::2] = np.exp2(np.float32(1 / 2))
90+
tol = 8 * dpt.finfo(Y.dtype).resolution
91+
92+
assert_allclose(dpt.asnumpy(Y), expected_Y, atol=tol, rtol=tol)
93+
94+
95+
@pytest.mark.parametrize("dtype", _all_dtypes)
96+
def test_exp2_order(dtype):
97+
q = get_queue_or_skip()
98+
skip_if_dtype_not_supported(dtype, q)
99+
100+
arg_dt = np.dtype(dtype)
101+
input_shape = (10, 10, 10, 10)
102+
X = dpt.empty(input_shape, dtype=arg_dt, sycl_queue=q)
103+
X[..., 0::2] = 1 / 4
104+
X[..., 1::2] = 1 / 2
105+
106+
for ord in ["C", "F", "A", "K"]:
107+
for perms in itertools.permutations(range(4)):
108+
U = dpt.permute_dims(X[:, ::-1, ::-1, :], perms)
109+
Y = dpt.exp2(U, order=ord)
110+
expected_Y = np.exp2(dpt.asnumpy(U))
111+
tol = 8 * max(
112+
dpt.finfo(Y.dtype).resolution,
113+
np.finfo(expected_Y.dtype).resolution,
114+
)
115+
assert_allclose(dpt.asnumpy(Y), expected_Y, atol=tol, rtol=tol)
116+
117+
118+
def test_exp2_special_cases():
119+
get_queue_or_skip()
120+
121+
X = dpt.asarray([dpt.nan, 0.0, -0.0, dpt.inf, -dpt.inf], dtype="f4")
122+
res = np.asarray([np.nan, 1.0, 1.0, np.inf, 0.0], dtype="f4")
123+
124+
tol = dpt.finfo(X.dtype).resolution
125+
assert_allclose(dpt.asnumpy(dpt.exp2(X)), res, atol=tol, rtol=tol)
126+
127+
# special cases for complex variant
128+
num_finite = 1.0
129+
vals = [
130+
complex(0.0, 0.0),
131+
complex(num_finite, dpt.inf),
132+
complex(num_finite, dpt.nan),
133+
complex(dpt.inf, 0.0),
134+
complex(-dpt.inf, num_finite),
135+
complex(dpt.inf, num_finite),
136+
complex(-dpt.inf, dpt.inf),
137+
complex(dpt.inf, dpt.inf),
138+
complex(-dpt.inf, dpt.nan),
139+
complex(dpt.inf, dpt.nan),
140+
complex(dpt.nan, 0.0),
141+
complex(dpt.nan, num_finite),
142+
complex(dpt.nan, dpt.nan),
143+
]
144+
X = dpt.asarray(vals, dtype=dpt.complex64)
145+
cis_1 = complex(np.cos(num_finite), np.sin(num_finite))
146+
c_nan = complex(np.nan, np.nan)
147+
res = np.asarray(
148+
[
149+
complex(1.0, 0.0),
150+
c_nan,
151+
c_nan,
152+
complex(np.inf, 0.0),
153+
0.0,
154+
np.inf * cis_1,
155+
complex(0.0, 0.0),
156+
complex(np.inf, np.nan),
157+
complex(0.0, 0.0),
158+
complex(np.inf, np.nan),
159+
complex(np.nan, 0.0),
160+
c_nan,
161+
c_nan,
162+
],
163+
dtype=np.complex64,
164+
)
165+
166+
tol = dpt.finfo(X.dtype).resolution
167+
with np.errstate(invalid="ignore"):
168+
assert_allclose(dpt.asnumpy(dpt.exp2(X)), res, atol=tol, rtol=tol)

0 commit comments

Comments
 (0)