Skip to content

Commit f8e058c

Browse files
committed
Improve dpnp_fill array/scalar path logic
1 parent f2b3cfa commit f8e058c

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

dpnp/dpnp_algo/dpnp_fill.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434
from dpctl.utils import SequentialOrderManager
3535

3636
import dpnp
37-
from dpnp.dpnp_array import dpnp_array
3837

3938

4039
def dpnp_fill(arr, val):
@@ -43,7 +42,7 @@ def dpnp_fill(arr, val):
4342

4443
dpnp.check_supported_arrays_type(val, scalar_type=True, all_scalars=True)
4544
# if val is an array, process it
46-
if isinstance(val, (dpnp_array, dpt.usm_ndarray)):
45+
if dpnp.is_supported_array_type(val):
4746
val = dpnp.get_usm_ndarray(val)
4847
if val.shape != ():
4948
raise ValueError("`val` must be a scalar")
@@ -61,6 +60,10 @@ def dpnp_fill(arr, val):
6160
)
6261
_manager.add_event_pair(h_ev, c_ev)
6362
return
63+
elif not dpnp.isscalar(val):
64+
raise TypeError(
65+
f"Expected `val` to be an array or Python scalar, got {type(val)}"
66+
)
6467

6568
dt = arr.dtype
6669
val_type = type(val)

0 commit comments

Comments
 (0)