Skip to content

Commit f58d19d

Browse files
npolina4antonwolfy
andauthored
Implemented dpnp.can_cast function (#1600)
* Implemented dpnp.can_cast function * address comments * Update tests/third_party/cupy/test_type_routines.py Co-authored-by: Anton <100830759+antonwolfy@users.noreply.github.com> * Update tests/third_party/cupy/test_type_routines.py Co-authored-by: Anton <100830759+antonwolfy@users.noreply.github.com> --------- Co-authored-by: Anton <100830759+antonwolfy@users.noreply.github.com>
1 parent 7bbbf1a commit f58d19d

File tree

3 files changed

+156
-1
lines changed

3 files changed

+156
-1
lines changed

dpnp/dpnp_iface_manipulation.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
"atleast_3d",
5757
"broadcast_arrays",
5858
"broadcast_to",
59+
"can_cast",
5960
"concatenate",
6061
"copyto",
6162
"expand_dims",
@@ -402,6 +403,47 @@ def broadcast_to(array, /, shape, subok=False):
402403
return dpnp_array._create_from_usm_ndarray(new_array)
403404

404405

406+
def can_cast(from_, to, casting="safe"):
407+
"""
408+
Returns ``True`` if cast between data types can occur according to the casting rule.
409+
410+
If `from` is a scalar or array scalar, also returns ``True`` if the scalar value can
411+
be cast without overflow or truncation to an integer.
412+
413+
For full documentation refer to :obj:`numpy.can_cast`.
414+
415+
Parameters
416+
----------
417+
from : dpnp.array, dtype
418+
Source data type.
419+
to : dtype
420+
Target data type.
421+
casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional
422+
Controls what kind of data casting may occur.
423+
424+
Returns
425+
-------
426+
out: bool
427+
True if cast can occur according to the casting rule.
428+
429+
See Also
430+
--------
431+
:obj:`dpnp.result_type` : Returns the type that results from applying the NumPy
432+
type promotion rules to the arguments.
433+
434+
"""
435+
436+
if dpnp.is_supported_array_type(to):
437+
raise TypeError("Cannot construct a dtype from an array")
438+
439+
dtype_from = (
440+
from_.dtype
441+
if dpnp.is_supported_array_type(from_)
442+
else dpnp.dtype(from_)
443+
)
444+
return dpt.can_cast(dtype_from, to, casting)
445+
446+
405447
def concatenate(
406448
arrays, /, *, axis=0, out=None, dtype=None, casting="same_kind"
407449
):
@@ -519,7 +561,7 @@ def copyto(dst, src, casting="same_kind", where=True):
519561
elif not dpnp.is_supported_array_type(src):
520562
src = dpnp.array(src, sycl_queue=dst.sycl_queue)
521563

522-
if not dpt.can_cast(src.dtype, dst.dtype, casting=casting):
564+
if not dpnp.can_cast(src.dtype, dst.dtype, casting=casting):
523565
raise TypeError(
524566
f"Cannot cast from {src.dtype} to {dst.dtype} "
525567
f"according to the rule {casting}."

tests/test_arraymanipulation.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -928,3 +928,14 @@ def test_subok_error():
928928
with pytest.raises(NotImplementedError):
929929
dpnp.broadcast_arrays(x, subok=True)
930930
dpnp.broadcast_to(x, (4, 4), subok=True)
931+
932+
933+
def test_can_cast():
934+
X = dpnp.ones((2, 2), dtype=dpnp.int64)
935+
pytest.raises(TypeError, dpnp.can_cast, X, 1)
936+
pytest.raises(TypeError, dpnp.can_cast, X, X)
937+
938+
X_np = numpy.ones((2, 2), dtype=numpy.int64)
939+
assert dpnp.can_cast(X, "float32") == numpy.can_cast(X_np, "float32")
940+
assert dpnp.can_cast(X, dpnp.int32) == numpy.can_cast(X_np, numpy.int32)
941+
assert dpnp.can_cast(X, dpnp.int64) == numpy.can_cast(X_np, numpy.int64)
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
import unittest
2+
3+
import numpy
4+
import pytest
5+
6+
import dpnp as cupy
7+
from tests.third_party.cupy import testing
8+
9+
10+
def _generate_type_routines_input(xp, dtype, obj_type):
11+
dtype = numpy.dtype(dtype)
12+
if obj_type == "dtype":
13+
return dtype
14+
if obj_type == "specifier":
15+
return str(dtype)
16+
if obj_type == "scalar":
17+
return dtype.type(3)
18+
if obj_type == "array":
19+
return xp.zeros(3, dtype=dtype)
20+
if obj_type == "primitive":
21+
return type(dtype.type(3).tolist())
22+
assert False
23+
24+
25+
@testing.parameterize(
26+
*testing.product(
27+
{
28+
"obj_type": ["dtype", "specifier", "scalar", "array", "primitive"],
29+
}
30+
)
31+
)
32+
class TestCanCast(unittest.TestCase):
33+
@testing.for_all_dtypes_combination(names=("from_dtype", "to_dtype"))
34+
@testing.numpy_cupy_equal()
35+
def test_can_cast(self, xp, from_dtype, to_dtype):
36+
if self.obj_type == "scalar":
37+
pytest.skip("to be aligned with NEP-50")
38+
39+
from_obj = _generate_type_routines_input(xp, from_dtype, self.obj_type)
40+
41+
ret = xp.can_cast(from_obj, to_dtype)
42+
assert isinstance(ret, bool)
43+
return ret
44+
45+
46+
@pytest.mark.skip("dpnp.common_type() is not implemented yet")
47+
class TestCommonType(unittest.TestCase):
48+
@testing.numpy_cupy_equal()
49+
def test_common_type_empty(self, xp):
50+
ret = xp.common_type()
51+
assert type(ret) == type
52+
return ret
53+
54+
@testing.for_all_dtypes(no_bool=True)
55+
@testing.numpy_cupy_equal()
56+
def test_common_type_single_argument(self, xp, dtype):
57+
array = _generate_type_routines_input(xp, dtype, "array")
58+
ret = xp.common_type(array)
59+
assert type(ret) == type
60+
return ret
61+
62+
@testing.for_all_dtypes_combination(
63+
names=("dtype1", "dtype2"), no_bool=True
64+
)
65+
@testing.numpy_cupy_equal()
66+
def test_common_type_two_arguments(self, xp, dtype1, dtype2):
67+
array1 = _generate_type_routines_input(xp, dtype1, "array")
68+
array2 = _generate_type_routines_input(xp, dtype2, "array")
69+
ret = xp.common_type(array1, array2)
70+
assert type(ret) == type
71+
return ret
72+
73+
@testing.for_all_dtypes()
74+
def test_common_type_bool(self, dtype):
75+
for xp in (numpy, cupy):
76+
array1 = _generate_type_routines_input(xp, dtype, "array")
77+
array2 = _generate_type_routines_input(xp, "bool_", "array")
78+
with pytest.raises(TypeError):
79+
xp.common_type(array1, array2)
80+
81+
82+
@testing.parameterize(
83+
*testing.product(
84+
{
85+
"obj_type1": ["dtype", "specifier", "scalar", "array", "primitive"],
86+
"obj_type2": ["dtype", "specifier", "scalar", "array", "primitive"],
87+
}
88+
)
89+
)
90+
class TestResultType(unittest.TestCase):
91+
@testing.for_all_dtypes_combination(names=("dtype1", "dtype2"))
92+
@testing.numpy_cupy_equal()
93+
def test_result_type(self, xp, dtype1, dtype2):
94+
if "scalar" in {self.obj_type1, self.obj_type2}:
95+
pytest.skip("to be aligned with NEP-50")
96+
97+
input1 = _generate_type_routines_input(xp, dtype1, self.obj_type1)
98+
99+
input2 = _generate_type_routines_input(xp, dtype2, self.obj_type2)
100+
ret = xp.result_type(input1, input2)
101+
assert isinstance(ret, numpy.dtype)
102+
return ret

0 commit comments

Comments
 (0)