Skip to content

Commit 48faeac

Browse files
committed
Rearrange fill_value type check logic and implement in full_like
1 parent 9aa1842 commit 48faeac

File tree

2 files changed

+15
-5
lines changed

2 files changed

+15
-5
lines changed

dpctl/tensor/_ctors.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1111,14 +1111,15 @@ def full(
11111111
sycl_queue=sycl_queue,
11121112
)
11131113
return dpt.copy(dpt.broadcast_to(X, shape), order=order)
1114-
1115-
sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
1116-
usm_type = usm_type if usm_type is not None else "device"
1117-
if not isinstance(fill_value, Number):
1114+
elif not isinstance(fill_value, Number):
11181115
raise TypeError(
11191116
"`full` array cannot be constructed with value of type "
11201117
f"{type(fill_value)}"
11211118
)
1119+
1120+
sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
1121+
usm_type = usm_type if usm_type is not None else "device"
1122+
11221123
dtype = _get_dtype(dtype, sycl_queue, ref_type=type(fill_value))
11231124
res = dpt.usm_ndarray(
11241125
shape,
@@ -1486,6 +1487,11 @@ def full_like(
14861487
)
14871488
_manager.add_event_pair(hev, copy_ev)
14881489
return res
1490+
elif not isinstance(fill_value, Number):
1491+
raise TypeError(
1492+
"`full` array cannot be constructed with value of type "
1493+
f"{type(fill_value)}"
1494+
)
14891495

14901496
dtype = _get_dtype(dtype, sycl_queue, ref_type=type(fill_value))
14911497
res = _empty_like_orderK(x, dtype, usm_type, sycl_queue)

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2623,8 +2623,12 @@ def test_setitem_from_numpy_contig():
26232623
assert dpt.all(dpt.flip(Xdpt, axis=-1) == expected)
26242624

26252625

2626-
def test_full_raises_type_error():
2626+
def test_full_functions_raise_type_error():
26272627
get_queue_or_skip()
26282628

26292629
with pytest.raises(TypeError):
26302630
dpt.full(1, "0")
2631+
2632+
x = dpt.ones(1, dtype="i4")
2633+
with pytest.raises(TypeError):
2634+
dpt.full_like(x, "0")

0 commit comments

Comments
 (0)