File tree Expand file tree Collapse file tree 3 files changed +8
-8
lines changed Expand file tree Collapse file tree 3 files changed +8
-8
lines changed Original file line number Diff line number Diff line change 21
21
import dpctl .tensor as dpt
22
22
from dpctl .tests .helper import get_queue_or_skip , skip_if_dtype_not_supported
23
23
24
- from .utils import _map_to_device_dtype , _real_fp_dtypes
24
+ from .utils import _map_to_device_dtype , _no_complex_dtypes , _real_fp_dtypes
25
25
26
26
27
- @pytest .mark .parametrize ("dtype" , _real_fp_dtypes )
27
+ @pytest .mark .parametrize ("dtype" , _no_complex_dtypes )
28
28
def test_cbrt_out_type (dtype ):
29
29
q = get_queue_or_skip ()
30
30
skip_if_dtype_not_supported (dtype , q )
Original file line number Diff line number Diff line change 22
22
import dpctl .tensor as dpt
23
23
from dpctl .tests .helper import get_queue_or_skip , skip_if_dtype_not_supported
24
24
25
- from .utils import _compare_dtypes , _real_fp_dtypes
25
+ from .utils import _compare_dtypes , _no_complex_dtypes , _real_fp_dtypes
26
26
27
27
28
- @pytest .mark .parametrize ("op1_dtype" , _real_fp_dtypes )
29
- @pytest .mark .parametrize ("op2_dtype" , _real_fp_dtypes )
28
+ @pytest .mark .parametrize ("op1_dtype" , _no_complex_dtypes )
29
+ @pytest .mark .parametrize ("op2_dtype" , _no_complex_dtypes )
30
30
def test_copysign_dtype_matrix (op1_dtype , op2_dtype ):
31
31
q = get_queue_or_skip ()
32
32
skip_if_dtype_not_supported (op1_dtype , q )
Original file line number Diff line number Diff line change 21
21
import dpctl .tensor as dpt
22
22
from dpctl .tests .helper import get_queue_or_skip , skip_if_dtype_not_supported
23
23
24
- from .utils import _map_to_device_dtype , _real_fp_dtypes
24
+ from .utils import _map_to_device_dtype , _no_complex_dtypes , _real_fp_dtypes
25
25
26
26
27
- @pytest .mark .parametrize ("dtype" , _real_fp_dtypes )
27
+ @pytest .mark .parametrize ("dtype" , _no_complex_dtypes )
28
28
def test_rsqrt_out_type (dtype ):
29
29
q = get_queue_or_skip ()
30
30
skip_if_dtype_not_supported (dtype , q )
31
31
32
32
x = dpt .asarray (1 , dtype = dtype , sycl_queue = q )
33
- expected_dtype = np .reciprocal (np .sqrt (1 , dtype = dtype )).dtype
33
+ expected_dtype = np .reciprocal (np .sqrt (np . array ( 1 , dtype = dtype ) )).dtype
34
34
expected_dtype = _map_to_device_dtype (expected_dtype , q .sycl_device )
35
35
assert dpt .rsqrt (x ).dtype == expected_dtype
36
36
You can’t perform that action at this time.
0 commit comments