Skip to content

Commit b121d67

Browse files
Adds tests for special FP values for dpt.abs and dpt.sqrt
1 parent 26862b4 commit b121d67

File tree

2 files changed

+122
-14
lines changed

2 files changed

+122
-14
lines changed

dpctl/tests/elementwise/test_abs.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,21 @@
1515
# limitations under the License.
1616

1717
import itertools
18+
import warnings
1819

1920
import numpy as np
2021
import pytest
2122

2223
import dpctl.tensor as dpt
2324
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
2425

25-
from .utils import _all_dtypes, _no_complex_dtypes, _usm_types
26+
from .utils import (
27+
_all_dtypes,
28+
_complex_fp_dtypes,
29+
_no_complex_dtypes,
30+
_real_fp_dtypes,
31+
_usm_types,
32+
)
2633

2734

2835
@pytest.mark.parametrize("dtype", _all_dtypes)
@@ -135,3 +142,50 @@ def test_abs_out_overlap(dtype):
135142
assert Y is not X
136143
assert np.allclose(dpt.asnumpy(X), Xnp)
137144
assert np.allclose(dpt.asnumpy(Y), Ynp)
145+
146+
147+
@pytest.mark.parametrize("dtype", _real_fp_dtypes)
148+
def test_abs_real_fp_special_values(dtype):
149+
q = get_queue_or_skip()
150+
skip_if_dtype_not_supported(dtype, q)
151+
152+
nans_ = [dpt.nan, -dpt.nan]
153+
infs_ = [dpt.inf, -dpt.inf]
154+
finites_ = [-1.0, -0.0, 0.0, 1.0]
155+
inps_ = nans_ + infs_ + finites_
156+
157+
x = dpt.asarray(inps_, dtype=dtype)
158+
r = dpt.abs(x)
159+
160+
with warnings.catch_warnings():
161+
warnings.simplefilter("ignore")
162+
expected_np = np.abs(np.asarray(inps_, dtype=dtype))
163+
164+
expected = dpt.asarray(expected_np, dtype=dtype)
165+
tol = dpt.finfo(r.dtype).resolution
166+
167+
assert dpt.allclose(r, expected, atol=tol, rtol=tol, equal_nan=True)
168+
169+
170+
@pytest.mark.parametrize("dtype", _complex_fp_dtypes)
171+
def test_abs_complex_fp_special_values(dtype):
172+
q = get_queue_or_skip()
173+
skip_if_dtype_not_supported(dtype, q)
174+
175+
nans_ = [dpt.nan, -dpt.nan]
176+
infs_ = [dpt.inf, -dpt.inf]
177+
finites_ = [-1.0, -0.0, 0.0, 1.0]
178+
inps_ = nans_ + infs_ + finites_
179+
c_ = [complex(*v) for v in itertools.product(inps_, repeat=2)]
180+
181+
z = dpt.asarray(c_, dtype=dtype)
182+
r = dpt.abs(z)
183+
184+
with warnings.catch_warnings():
185+
warnings.simplefilter("ignore")
186+
expected_np = np.abs(np.asarray(c_, dtype=dtype))
187+
188+
expected = dpt.asarray(expected_np, dtype=dtype)
189+
tol = dpt.finfo(r.dtype).resolution
190+
191+
assert dpt.allclose(r, expected, atol=tol, rtol=tol, equal_nan=True)

dpctl/tests/elementwise/test_sqrt.py

Lines changed: 67 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# limitations under the License.
1616

1717
import itertools
18+
import warnings
1819

1920
import numpy as np
2021
import pytest
@@ -23,7 +24,13 @@
2324
import dpctl.tensor as dpt
2425
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
2526

26-
from .utils import _all_dtypes, _map_to_device_dtype, _usm_types
27+
from .utils import (
28+
_all_dtypes,
29+
_complex_fp_dtypes,
30+
_map_to_device_dtype,
31+
_real_fp_dtypes,
32+
_usm_types,
33+
)
2734

2835

2936
@pytest.mark.parametrize("dtype", _all_dtypes)
@@ -115,18 +122,6 @@ def test_sqrt_order(dtype):
115122
assert_allclose(dpt.asnumpy(Y), expected_Y, atol=tol, rtol=tol)
116123

117124

118-
@pytest.mark.usefixtures("suppress_invalid_numpy_warnings")
119-
def test_sqrt_special_cases():
120-
q = get_queue_or_skip()
121-
122-
X = dpt.asarray(
123-
[dpt.nan, -1.0, 0.0, -0.0, dpt.inf, -dpt.inf], dtype="f4", sycl_queue=q
124-
)
125-
Xnp = dpt.asnumpy(X)
126-
127-
assert_equal(dpt.asnumpy(dpt.sqrt(X)), np.sqrt(Xnp))
128-
129-
130125
@pytest.mark.parametrize("dtype", ["f2", "f4", "f8", "c8", "c16"])
131126
def test_sqrt_out_overlap(dtype):
132127
q = get_queue_or_skip()
@@ -149,3 +144,62 @@ def test_sqrt_out_overlap(dtype):
149144
assert Y is not X
150145
assert_allclose(dpt.asnumpy(X), Xnp, atol=tol, rtol=tol)
151146
assert_allclose(dpt.asnumpy(Y), Ynp, atol=tol, rtol=tol)
147+
148+
149+
@pytest.mark.usefixtures("suppress_invalid_numpy_warnings")
150+
def test_sqrt_special_cases():
151+
q = get_queue_or_skip()
152+
153+
X = dpt.asarray(
154+
[dpt.nan, -1.0, 0.0, -0.0, dpt.inf, -dpt.inf], dtype="f4", sycl_queue=q
155+
)
156+
Xnp = dpt.asnumpy(X)
157+
158+
assert_equal(dpt.asnumpy(dpt.sqrt(X)), np.sqrt(Xnp))
159+
160+
161+
@pytest.mark.parametrize("dtype", _real_fp_dtypes)
162+
def test_sqrt_real_fp_special_values(dtype):
163+
q = get_queue_or_skip()
164+
skip_if_dtype_not_supported(dtype, q)
165+
166+
nans_ = [dpt.nan, -dpt.nan]
167+
infs_ = [dpt.inf, -dpt.inf]
168+
finites_ = [-1.0, -0.0, 0.0, 1.0]
169+
inps_ = nans_ + infs_ + finites_
170+
171+
x = dpt.asarray(inps_, dtype=dtype)
172+
r = dpt.sqrt(x)
173+
174+
with warnings.catch_warnings():
175+
warnings.simplefilter("ignore")
176+
expected_np = np.sqrt(np.asarray(inps_, dtype=dtype))
177+
178+
expected = dpt.asarray(expected_np, dtype=dtype)
179+
tol = dpt.finfo(r.dtype).resolution
180+
181+
assert dpt.allclose(r, expected, atol=tol, rtol=tol, equal_nan=True)
182+
183+
184+
@pytest.mark.parametrize("dtype", _complex_fp_dtypes)
185+
def test_sqrt_complex_fp_special_values(dtype):
186+
q = get_queue_or_skip()
187+
skip_if_dtype_not_supported(dtype, q)
188+
189+
nans_ = [dpt.nan, -dpt.nan]
190+
infs_ = [dpt.inf, -dpt.inf]
191+
finites_ = [-1.0, -0.0, 0.0, 1.0]
192+
inps_ = nans_ + infs_ + finites_
193+
c_ = [complex(*v) for v in itertools.product(inps_, repeat=2)]
194+
195+
z = dpt.asarray(c_, dtype=dtype)
196+
r = dpt.sqrt(z)
197+
198+
with warnings.catch_warnings():
199+
warnings.simplefilter("ignore")
200+
expected_np = np.sqrt(np.asarray(c_, dtype=dtype))
201+
202+
expected = dpt.asarray(expected_np, dtype=dtype)
203+
tol = dpt.finfo(r.dtype).resolution
204+
205+
assert dpt.allclose(r, expected, atol=tol, rtol=tol, equal_nan=True)

0 commit comments

Comments
 (0)