Skip to content

Commit 556a5c6

Browse files
authored
Merge pull request #2097 from IntelPython/resolve-gh-1046
Resolve gh-1046
2 parents a26cac1 + 676e418 commit 556a5c6

File tree

2 files changed

+26
-2
lines changed

2 files changed

+26
-2
lines changed

dpctl/tensor/_ctors.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,8 +111,6 @@ def _asarray_from_usm_ndarray(
111111
raise TypeError(
112112
f"Expected dpctl.tensor.usm_ndarray, got {type(usm_ndary)}"
113113
)
114-
if dtype is None:
115-
dtype = usm_ndary.dtype
116114
if usm_type is None:
117115
usm_type = usm_ndary.usm_type
118116
if sycl_queue is not None:
@@ -122,6 +120,8 @@ def _asarray_from_usm_ndarray(
122120
copy_q = normalize_queue_device(sycl_queue=sycl_queue, device=exec_q)
123121
else:
124122
copy_q = usm_ndary.sycl_queue
123+
if dtype is None:
124+
dtype = _map_to_device_dtype(usm_ndary.dtype, copy_q)
125125
# Conditions for zero copy:
126126
can_zero_copy = copy is not True
127127
# dtype is unchanged

dpctl/tests/test_tensor_asarray.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -623,3 +623,27 @@ def test_asarray_support_for_usm_ndarray_protocol(usm_type):
623623
assert x.dtype == y3.dtype
624624
assert y3.usm_data.reference_obj is None
625625
assert dpt.all(x[dpt.newaxis, :] == y3)
626+
627+
628+
@pytest.mark.parametrize("dt", [dpt.float16, dpt.float64, dpt.complex128])
629+
def test_asarray_to_device_with_unsupported_dtype(dt):
630+
aspect = "fp16" if dt == dpt.float16 else "fp64"
631+
try:
632+
d0 = dpctl.select_device_with_aspects(aspect)
633+
except dpctl.SyclDeviceCreationError:
634+
pytest.skip("No device with aspect for test")
635+
d1 = None
636+
for d in dpctl.get_devices():
637+
if d.default_selector_score < 0:
638+
pass
639+
try:
640+
d1 = dpctl.select_device_with_aspects(
641+
d.device_type.name, excluded_aspects=[aspect]
642+
)
643+
except dpctl.SyclDeviceCreationError:
644+
pass
645+
if d1 is None:
646+
pytest.skip("No device with missing aspect for test")
647+
x = dpt.ones(10, dtype=dt, device=d0)
648+
y = dpt.asarray(x, device=d1)
649+
assert y.sycl_device == d1

0 commit comments

Comments
 (0)