Skip to content

Commit 1f34c0d

Browse files
authored
Changes to elementwise function tests that use complex data types (#1412)
* Removed redundant overlap tests Also rewrote overlap test for abs to only use a single dtype * Implements a pytest marker for tests currently broken for complex data types on some platforms The marker "broken_complex" can be used to skip these broken tests. Running the tests with `pytest --runcomplex` permits running the test. * "broken_complex" marker now skips only on Windows * Small change to broken_complex marker description "on Windows" added to make it clear that this marker only skips for Windows platforms
1 parent 34523b1 commit 1f34c0d

File tree

10 files changed

+44
-206
lines changed

10 files changed

+44
-206
lines changed

dpctl/tests/conftest.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import os
2121
import sys
2222

23+
import pytest
2324
from _device_attributes_checks import (
2425
check,
2526
device_selector,
@@ -38,3 +39,31 @@
3839
"suppress_invalid_numpy_warnings",
3940
"valid_filter",
4041
]
42+
43+
44+
def pytest_configure(config):
45+
config.addinivalue_line(
46+
"markers",
47+
"broken_complex: Specified again to remove warnings ",
48+
)
49+
50+
51+
def pytest_addoption(parser):
52+
parser.addoption(
53+
"--runcomplex",
54+
action="store_true",
55+
default=False,
56+
help="run broken complex tests on Windows",
57+
)
58+
59+
60+
def pytest_collection_modifyitems(config, items):
61+
if config.getoption("--runcomplex"):
62+
return
63+
skip_complex = pytest.mark.skipif(
64+
os.name == "nt",
65+
reason="need --runcomplex option to run on Windows",
66+
)
67+
for item in items:
68+
if "broken_complex" in item.keywords:
69+
item.add_marker(skip_complex)

dpctl/tests/elementwise/test_abs.py

Lines changed: 10 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,7 @@
2323
import dpctl.tensor as dpt
2424
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
2525

26-
from .utils import (
27-
_all_dtypes,
28-
_complex_fp_dtypes,
29-
_no_complex_dtypes,
30-
_real_fp_dtypes,
31-
_usm_types,
32-
)
26+
from .utils import _all_dtypes, _complex_fp_dtypes, _real_fp_dtypes, _usm_types
3327

3428

3529
@pytest.mark.parametrize("dtype", _all_dtypes)
@@ -131,26 +125,21 @@ def test_abs_complex(dtype):
131125
)
132126

133127

134-
@pytest.mark.parametrize("dtype", _no_complex_dtypes)
135-
def test_abs_out_overlap(dtype):
136-
q = get_queue_or_skip()
137-
skip_if_dtype_not_supported(dtype, q)
138-
139-
X = dpt.linspace(0, 35, 60, dtype=dtype, sycl_queue=q)
140-
X = dpt.reshape(X, (3, 5, 4))
141-
142-
Xnp = dpt.asnumpy(X)
143-
Ynp = np.abs(Xnp, out=Xnp)
128+
def test_abs_out_overlap():
129+
get_queue_or_skip()
144130

131+
X = dpt.arange(-3, 3, 1, dtype="i4")
132+
expected = dpt.asarray([3, 2, 1, 0, 1, 2], dtype="i4")
145133
Y = dpt.abs(X, out=X)
134+
146135
assert Y is X
147-
assert np.allclose(dpt.asnumpy(X), Xnp)
136+
assert dpt.all(expected == X)
148137

149-
Ynp = np.abs(Xnp, out=Xnp[::-1])
138+
X = dpt.arange(-3, 3, 1, dtype="i4")
139+
expected = expected[::-1]
150140
Y = dpt.abs(X, out=X[::-1])
151141
assert Y is not X
152-
assert np.allclose(dpt.asnumpy(X), Xnp)
153-
assert np.allclose(dpt.asnumpy(Y), Ynp)
142+
assert dpt.all(expected == X)
154143

155144

156145
@pytest.mark.parametrize("dtype", _real_fp_dtypes)

dpctl/tests/elementwise/test_exp.py

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -216,26 +216,3 @@ def test_exp_complex_special_cases(dtype):
216216
tol = 8 * dpt.finfo(dtype).resolution
217217
assert_allclose(dpt.asnumpy(dpt.real(Y)), np.real(Ynp), atol=tol, rtol=tol)
218218
assert_allclose(dpt.asnumpy(dpt.imag(Y)), np.imag(Ynp), atol=tol, rtol=tol)
219-
220-
221-
@pytest.mark.parametrize("dtype", ["f2", "f4", "f8", "c8", "c16"])
222-
def test_exp_out_overlap(dtype):
223-
q = get_queue_or_skip()
224-
skip_if_dtype_not_supported(dtype, q)
225-
226-
X = dpt.linspace(0, 1, 15, dtype=dtype, sycl_queue=q)
227-
X = dpt.reshape(X, (3, 5))
228-
229-
Xnp = dpt.asnumpy(X)
230-
Ynp = np.exp(Xnp, out=Xnp)
231-
232-
Y = dpt.exp(X, out=X)
233-
tol = 8 * dpt.finfo(Y.dtype).resolution
234-
assert Y is X
235-
assert_allclose(dpt.asnumpy(X), Xnp, atol=tol, rtol=tol)
236-
237-
Ynp = np.exp(Xnp, out=Xnp[::-1])
238-
Y = dpt.exp(X, out=X[::-1])
239-
assert Y is not X
240-
assert_allclose(dpt.asnumpy(X), Xnp, atol=tol, rtol=tol)
241-
assert_allclose(dpt.asnumpy(Y), Ynp, atol=tol, rtol=tol)

dpctl/tests/elementwise/test_hyperbolic.py

Lines changed: 1 addition & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
# limitations under the License.
1616

1717
import itertools
18-
import os
1918

2019
import numpy as np
2120
import pytest
@@ -271,7 +270,7 @@ def test_hyper_real_special_cases(np_call, dpt_call, dtype):
271270
assert_allclose(dpt.asnumpy(dpt_call(yf)), Y_np, atol=tol, rtol=tol)
272271

273272

274-
@pytest.mark.skipif(os.name == "nt", reason="Known problems on Windows")
273+
@pytest.mark.broken_complex
275274
@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
276275
@pytest.mark.parametrize("dtype", ["c8", "c16"])
277276
def test_hyper_complex_special_cases(np_call, dpt_call, dtype):
@@ -294,29 +293,3 @@ def test_hyper_complex_special_cases(np_call, dpt_call, dtype):
294293
assert_allclose(
295294
dpt.asnumpy(dpt.imag(dpt_call(Xc))), np.imag(Ynp), atol=tol, rtol=tol
296295
)
297-
298-
299-
@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
300-
@pytest.mark.parametrize("dtype", ["f2", "f4", "f8", "c8", "c16"])
301-
def test_hyper_out_overlap(np_call, dpt_call, dtype):
302-
q = get_queue_or_skip()
303-
skip_if_dtype_not_supported(dtype, q)
304-
305-
X = dpt.linspace(-np.pi / 2, np.pi / 2, 60, dtype=dtype, sycl_queue=q)
306-
X = dpt.reshape(X, (3, 5, 4))
307-
308-
tol = 8 * dpt.finfo(dtype).resolution
309-
Xnp = dpt.asnumpy(X)
310-
with np.errstate(all="ignore"):
311-
Ynp = np_call(Xnp, out=Xnp)
312-
313-
Y = dpt_call(X, out=X)
314-
assert Y is X
315-
assert_allclose(dpt.asnumpy(X), Xnp, atol=tol, rtol=tol)
316-
317-
with np.errstate(all="ignore"):
318-
Ynp = np_call(Xnp, out=Xnp[::-1])
319-
Y = dpt_call(X, out=X[::-1])
320-
assert Y is not X
321-
assert_allclose(dpt.asnumpy(X), Xnp, atol=tol, rtol=tol)
322-
assert_allclose(dpt.asnumpy(Y), Ynp, atol=tol, rtol=tol)

dpctl/tests/elementwise/test_log.py

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -128,27 +128,3 @@ def test_log_special_cases():
128128
)
129129

130130
assert_equal(dpt.asnumpy(Y), expected)
131-
132-
133-
@pytest.mark.parametrize("dtype", ["f2", "f4", "f8", "c8", "c16"])
134-
def test_log_out_overlap(dtype):
135-
q = get_queue_or_skip()
136-
skip_if_dtype_not_supported(dtype, q)
137-
138-
X = dpt.linspace(5, 35, 60, dtype=dtype, sycl_queue=q)
139-
X = dpt.reshape(X, (3, 5, 4))
140-
141-
Xnp = dpt.asnumpy(X)
142-
Ynp = np.log(Xnp, out=Xnp)
143-
144-
Y = dpt.log(X, out=X)
145-
assert Y is X
146-
147-
tol = 8 * dpt.finfo(Y.dtype).resolution
148-
assert_allclose(dpt.asnumpy(X), Xnp, atol=tol, rtol=tol)
149-
150-
Ynp = np.log(Xnp, out=Xnp[::-1])
151-
Y = dpt.log(X, out=X[::-1])
152-
assert Y is not X
153-
assert_allclose(dpt.asnumpy(X), Xnp, atol=tol, rtol=tol)
154-
assert_allclose(dpt.asnumpy(Y), Ynp, atol=tol, rtol=tol)

dpctl/tests/elementwise/test_round.py

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -213,26 +213,3 @@ def test_round_complex_special_cases(dtype):
213213
tol = 8 * dpt.finfo(dtype).resolution
214214
assert_allclose(dpt.asnumpy(dpt.real(Y)), np.real(Ynp), atol=tol, rtol=tol)
215215
assert_allclose(dpt.asnumpy(dpt.imag(Y)), np.imag(Ynp), atol=tol, rtol=tol)
216-
217-
218-
@pytest.mark.parametrize("dtype", ["f2", "f4", "f8", "c8", "c16"])
219-
def test_round_out_overlap(dtype):
220-
q = get_queue_or_skip()
221-
skip_if_dtype_not_supported(dtype, q)
222-
223-
X = dpt.linspace(0, 1, 15, dtype=dtype, sycl_queue=q)
224-
X = dpt.reshape(X, (3, 5))
225-
226-
Xnp = dpt.asnumpy(X)
227-
Ynp = np.round(Xnp, out=Xnp)
228-
229-
Y = dpt.round(X, out=X)
230-
tol = 8 * dpt.finfo(Y.dtype).resolution
231-
assert Y is X
232-
assert_allclose(dpt.asnumpy(X), Xnp, atol=tol, rtol=tol)
233-
234-
Ynp = np.round(Xnp, out=Xnp[::-1])
235-
Y = dpt.round(X, out=X[::-1])
236-
assert Y is not X
237-
assert_allclose(dpt.asnumpy(X), Xnp, atol=tol, rtol=tol)
238-
assert_allclose(dpt.asnumpy(Y), Ynp, atol=tol, rtol=tol)

dpctl/tests/elementwise/test_sqrt.py

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -122,30 +122,6 @@ def test_sqrt_order(dtype):
122122
assert_allclose(dpt.asnumpy(Y), expected_Y, atol=tol, rtol=tol)
123123

124124

125-
@pytest.mark.parametrize("dtype", ["f2", "f4", "f8", "c8", "c16"])
126-
def test_sqrt_out_overlap(dtype):
127-
q = get_queue_or_skip()
128-
skip_if_dtype_not_supported(dtype, q)
129-
130-
X = dpt.linspace(0, 35, 60, dtype=dtype, sycl_queue=q)
131-
X = dpt.reshape(X, (3, 5, 4))
132-
133-
Xnp = dpt.asnumpy(X)
134-
Ynp = np.sqrt(Xnp, out=Xnp)
135-
136-
Y = dpt.sqrt(X, out=X)
137-
assert Y is X
138-
139-
tol = 8 * dpt.finfo(Y.dtype).resolution
140-
assert_allclose(dpt.asnumpy(X), Xnp, atol=tol, rtol=tol)
141-
142-
Ynp = np.sqrt(Xnp, out=Xnp[::-1])
143-
Y = dpt.sqrt(X, out=X[::-1])
144-
assert Y is not X
145-
assert_allclose(dpt.asnumpy(X), Xnp, atol=tol, rtol=tol)
146-
assert_allclose(dpt.asnumpy(Y), Ynp, atol=tol, rtol=tol)
147-
148-
149125
@pytest.mark.usefixtures("suppress_invalid_numpy_warnings")
150126
def test_sqrt_special_cases():
151127
q = get_queue_or_skip()

dpctl/tests/elementwise/test_square.py

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -97,29 +97,3 @@ def test_square_special_cases(dtype):
9797
rtol=tol,
9898
equal_nan=True,
9999
)
100-
101-
102-
@pytest.mark.parametrize("dtype", ["f2", "f4", "f8", "c8", "c16"])
103-
def test_square_out_overlap(dtype):
104-
q = get_queue_or_skip()
105-
skip_if_dtype_not_supported(dtype, q)
106-
107-
X = dpt.linspace(0, 35, 60, dtype=dtype, sycl_queue=q)
108-
X = dpt.reshape(X, (3, 5, 4))
109-
110-
Xnp = dpt.asnumpy(X)
111-
Ynp = np.square(Xnp, out=Xnp)
112-
113-
Y = dpt.square(X, out=X)
114-
assert Y is X
115-
assert np.allclose(dpt.asnumpy(X), Xnp)
116-
117-
X = dpt.linspace(0, 35, 60, dtype=dtype, sycl_queue=q)
118-
X = dpt.reshape(X, (3, 5, 4))
119-
Xnp = dpt.asnumpy(X)
120-
121-
Ynp = np.square(Xnp, out=Xnp[::-1])
122-
Y = dpt.square(X, out=X[::-1])
123-
assert Y is not X
124-
assert np.allclose(dpt.asnumpy(X), Xnp)
125-
assert np.allclose(dpt.asnumpy(Y), Ynp)

dpctl/tests/elementwise/test_trigonometric.py

Lines changed: 1 addition & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
# limitations under the License.
1616

1717
import itertools
18-
import os
1918

2019
import numpy as np
2120
import pytest
@@ -268,7 +267,7 @@ def test_trig_real_special_cases(np_call, dpt_call, dtype):
268267
assert_allclose(dpt.asnumpy(dpt_call(yf)), Y_np, atol=tol, rtol=tol)
269268

270269

271-
@pytest.mark.skipif(os.name == "nt", reason="Known problem on Windows")
270+
@pytest.mark.broken_complex
272271
@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
273272
@pytest.mark.parametrize("dtype", ["c8", "c16"])
274273
def test_trig_complex_special_cases(np_call, dpt_call, dtype):
@@ -291,38 +290,3 @@ def test_trig_complex_special_cases(np_call, dpt_call, dtype):
291290
assert_allclose(
292291
dpt.asnumpy(dpt.imag(dpt_call(Xc))), np.imag(Ynp), atol=tol, rtol=tol
293292
)
294-
295-
296-
@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
297-
@pytest.mark.parametrize("dtype", ["f2", "f4", "f8", "c8", "c16"])
298-
def test_trig_out_overlap(np_call, dpt_call, dtype):
299-
q = get_queue_or_skip()
300-
skip_if_dtype_not_supported(dtype, q)
301-
302-
if os.name == "nt" and dpt.isdtype(dpt.dtype(dtype), "complex floating"):
303-
pytest.skip("Know problems on Windows")
304-
305-
if np_call == np.tan:
306-
X = dpt.linspace(-np.pi / 2, np.pi / 2, 64, dtype=dtype, sycl_queue=q)[
307-
2:-2
308-
]
309-
tol = 50 * dpt.finfo(dtype).resolution
310-
else:
311-
X = dpt.linspace(-np.pi / 2, np.pi / 2, 60, dtype=dtype, sycl_queue=q)
312-
tol = 8 * dpt.finfo(dtype).resolution
313-
X = dpt.reshape(X, (3, 5, 4))
314-
315-
Xnp = dpt.asnumpy(X)
316-
with np.errstate(all="ignore"):
317-
Ynp = np_call(Xnp, out=Xnp)
318-
319-
Y = dpt_call(X, out=X)
320-
assert Y is X
321-
assert_allclose(dpt.asnumpy(X), Xnp, atol=tol, rtol=tol)
322-
323-
with np.errstate(all="ignore"):
324-
Ynp = np_call(Xnp, out=Xnp[::-1])
325-
Y = dpt_call(X, out=X[::-1])
326-
assert Y is not X
327-
assert_allclose(dpt.asnumpy(X), Xnp, atol=tol, rtol=tol)
328-
assert_allclose(dpt.asnumpy(Y), Ynp, atol=tol, rtol=tol)

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ omit = [
3333
]
3434

3535
[tool.pytest.ini.options]
36+
markers = [
37+
"broken_complex: mark a test that is skipped on Windows due to complex implementation",
38+
]
3639
minversion = "6.0"
3740
norecursedirs= [
3841
".*", "*.egg*", "build", "dist", "conda-recipe",

0 commit comments

Comments
 (0)