15
15
# limitations under the License.
16
16
17
17
import itertools
18
+ import warnings
18
19
19
20
import numpy as np
20
21
import pytest
23
24
import dpctl .tensor as dpt
24
25
from dpctl .tests .helper import get_queue_or_skip , skip_if_dtype_not_supported
25
26
26
- from .utils import _all_dtypes , _map_to_device_dtype , _usm_types
27
+ from .utils import (
28
+ _all_dtypes ,
29
+ _complex_fp_dtypes ,
30
+ _map_to_device_dtype ,
31
+ _real_fp_dtypes ,
32
+ _usm_types ,
33
+ )
27
34
28
35
29
36
@pytest .mark .parametrize ("dtype" , _all_dtypes )
@@ -115,18 +122,6 @@ def test_sqrt_order(dtype):
115
122
assert_allclose (dpt .asnumpy (Y ), expected_Y , atol = tol , rtol = tol )
116
123
117
124
118
- @pytest .mark .usefixtures ("suppress_invalid_numpy_warnings" )
119
- def test_sqrt_special_cases ():
120
- q = get_queue_or_skip ()
121
-
122
- X = dpt .asarray (
123
- [dpt .nan , - 1.0 , 0.0 , - 0.0 , dpt .inf , - dpt .inf ], dtype = "f4" , sycl_queue = q
124
- )
125
- Xnp = dpt .asnumpy (X )
126
-
127
- assert_equal (dpt .asnumpy (dpt .sqrt (X )), np .sqrt (Xnp ))
128
-
129
-
130
125
@pytest .mark .parametrize ("dtype" , ["f2" , "f4" , "f8" , "c8" , "c16" ])
131
126
def test_sqrt_out_overlap (dtype ):
132
127
q = get_queue_or_skip ()
@@ -149,3 +144,62 @@ def test_sqrt_out_overlap(dtype):
149
144
assert Y is not X
150
145
assert_allclose (dpt .asnumpy (X ), Xnp , atol = tol , rtol = tol )
151
146
assert_allclose (dpt .asnumpy (Y ), Ynp , atol = tol , rtol = tol )
147
+
148
+
149
+ @pytest .mark .usefixtures ("suppress_invalid_numpy_warnings" )
150
+ def test_sqrt_special_cases ():
151
+ q = get_queue_or_skip ()
152
+
153
+ X = dpt .asarray (
154
+ [dpt .nan , - 1.0 , 0.0 , - 0.0 , dpt .inf , - dpt .inf ], dtype = "f4" , sycl_queue = q
155
+ )
156
+ Xnp = dpt .asnumpy (X )
157
+
158
+ assert_equal (dpt .asnumpy (dpt .sqrt (X )), np .sqrt (Xnp ))
159
+
160
+
161
+ @pytest .mark .parametrize ("dtype" , _real_fp_dtypes )
162
+ def test_sqrt_real_fp_special_values (dtype ):
163
+ q = get_queue_or_skip ()
164
+ skip_if_dtype_not_supported (dtype , q )
165
+
166
+ nans_ = [dpt .nan , - dpt .nan ]
167
+ infs_ = [dpt .inf , - dpt .inf ]
168
+ finites_ = [- 1.0 , - 0.0 , 0.0 , 1.0 ]
169
+ inps_ = nans_ + infs_ + finites_
170
+
171
+ x = dpt .asarray (inps_ , dtype = dtype )
172
+ r = dpt .sqrt (x )
173
+
174
+ with warnings .catch_warnings ():
175
+ warnings .simplefilter ("ignore" )
176
+ expected_np = np .sqrt (np .asarray (inps_ , dtype = dtype ))
177
+
178
+ expected = dpt .asarray (expected_np , dtype = dtype )
179
+ tol = dpt .finfo (r .dtype ).resolution
180
+
181
+ assert dpt .allclose (r , expected , atol = tol , rtol = tol , equal_nan = True )
182
+
183
+
184
+ @pytest .mark .parametrize ("dtype" , _complex_fp_dtypes )
185
+ def test_sqrt_complex_fp_special_values (dtype ):
186
+ q = get_queue_or_skip ()
187
+ skip_if_dtype_not_supported (dtype , q )
188
+
189
+ nans_ = [dpt .nan , - dpt .nan ]
190
+ infs_ = [dpt .inf , - dpt .inf ]
191
+ finites_ = [- 1.0 , - 0.0 , 0.0 , 1.0 ]
192
+ inps_ = nans_ + infs_ + finites_
193
+ c_ = [complex (* v ) for v in itertools .product (inps_ , repeat = 2 )]
194
+
195
+ z = dpt .asarray (c_ , dtype = dtype )
196
+ r = dpt .sqrt (z )
197
+
198
+ with warnings .catch_warnings ():
199
+ warnings .simplefilter ("ignore" )
200
+ expected_np = np .sqrt (np .asarray (c_ , dtype = dtype ))
201
+
202
+ expected = dpt .asarray (expected_np , dtype = dtype )
203
+ tol = dpt .finfo (r .dtype ).resolution
204
+
205
+ assert dpt .allclose (r , expected , atol = tol , rtol = tol , equal_nan = True )
0 commit comments