Skip to content

Commit 3e08194

Browse files
committed
Modified tests for cbrt, copysign, and rsqrt
Now test more type combinations/output types
1 parent 0695fde commit 3e08194

File tree

3 files changed

+8
-8
lines changed

3 files changed

+8
-8
lines changed

dpctl/tests/elementwise/test_cbrt.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,10 @@
2121
import dpctl.tensor as dpt
2222
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
2323

24-
from .utils import _map_to_device_dtype, _real_fp_dtypes
24+
from .utils import _map_to_device_dtype, _no_complex_dtypes, _real_fp_dtypes
2525

2626

27-
@pytest.mark.parametrize("dtype", _real_fp_dtypes)
27+
@pytest.mark.parametrize("dtype", _no_complex_dtypes)
2828
def test_cbrt_out_type(dtype):
2929
q = get_queue_or_skip()
3030
skip_if_dtype_not_supported(dtype, q)

dpctl/tests/elementwise/test_copysign.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,11 @@
2222
import dpctl.tensor as dpt
2323
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
2424

25-
from .utils import _compare_dtypes, _real_fp_dtypes
25+
from .utils import _compare_dtypes, _no_complex_dtypes, _real_fp_dtypes
2626

2727

28-
@pytest.mark.parametrize("op1_dtype", _real_fp_dtypes)
29-
@pytest.mark.parametrize("op2_dtype", _real_fp_dtypes)
28+
@pytest.mark.parametrize("op1_dtype", _no_complex_dtypes)
29+
@pytest.mark.parametrize("op2_dtype", _no_complex_dtypes)
3030
def test_copysign_dtype_matrix(op1_dtype, op2_dtype):
3131
q = get_queue_or_skip()
3232
skip_if_dtype_not_supported(op1_dtype, q)

dpctl/tests/elementwise/test_rsqrt.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,16 @@
2121
import dpctl.tensor as dpt
2222
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
2323

24-
from .utils import _map_to_device_dtype, _real_fp_dtypes
24+
from .utils import _map_to_device_dtype, _no_complex_dtypes, _real_fp_dtypes
2525

2626

27-
@pytest.mark.parametrize("dtype", _real_fp_dtypes)
27+
@pytest.mark.parametrize("dtype", _no_complex_dtypes)
2828
def test_rsqrt_out_type(dtype):
2929
q = get_queue_or_skip()
3030
skip_if_dtype_not_supported(dtype, q)
3131

3232
x = dpt.asarray(1, dtype=dtype, sycl_queue=q)
33-
expected_dtype = np.reciprocal(np.sqrt(1, dtype=dtype)).dtype
33+
expected_dtype = np.reciprocal(np.sqrt(np.array(1, dtype=dtype))).dtype
3434
expected_dtype = _map_to_device_dtype(expected_dtype, q.sycl_device)
3535
assert dpt.rsqrt(x).dtype == expected_dtype
3636

0 commit comments

Comments
 (0)