Skip to content

Commit 9b83bef

Browse files
authored
Merge pull request #1878 from IntelPython/improve-full-error-for-invalid-scalar-type
Improve `dpctl.tensor.full` error for invalid `fill_value`
2 parents 286afae + 7095358 commit 9b83bef

File tree

3 files changed

+31
-0
lines changed

3 files changed

+31
-0
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1818
### Fixed
1919
* Fix for `tensor.result_type` when all inputs are Python built-in scalars [gh-1877](https://github.com/IntelPython/dpctl/pull/1877)
2020

21+
* Improved error in constructors `tensor.full` and `tensor.full_like` when provided a non-numeric fill value [gh-1878](https://github.com/IntelPython/dpctl/pull/1878)
22+
2123
### Maintenance
2224

2325
* Update black version used in Python code style workflow [gh-1828](https://github.com/IntelPython/dpctl/pull/1828)

dpctl/tensor/_ctors.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# limitations under the License.
1616

1717
import operator
18+
from numbers import Number
1819

1920
import numpy as np
2021

@@ -1037,6 +1038,19 @@ def _cast_fill_val(fill_val, dt):
10371038
return fill_val
10381039

10391040

1041+
def _validate_fill_value(fill_val):
1042+
"""
1043+
Validates that `fill_val` is a numeric or boolean scalar.
1044+
"""
1045+
# TODO: verify if `np.True_` and `np.False_` should be instances of
1046+
# Number in NumPy, like other NumPy scalars and like Python bools
1047+
# check for `np.bool_` separately as NumPy<2 has no `np.bool`
1048+
if not isinstance(fill_val, Number) and not isinstance(fill_val, np.bool_):
1049+
raise TypeError(
1050+
f"array cannot be filled with scalar of type {type(fill_val)}"
1051+
)
1052+
1053+
10401054
def full(
10411055
shape,
10421056
fill_value,
@@ -1110,6 +1124,8 @@ def full(
11101124
sycl_queue=sycl_queue,
11111125
)
11121126
return dpt.copy(dpt.broadcast_to(X, shape), order=order)
1127+
else:
1128+
_validate_fill_value(fill_value)
11131129

11141130
sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
11151131
usm_type = usm_type if usm_type is not None else "device"
@@ -1480,6 +1496,8 @@ def full_like(
14801496
)
14811497
_manager.add_event_pair(hev, copy_ev)
14821498
return res
1499+
else:
1500+
_validate_fill_value(fill_value)
14831501

14841502
dtype = _get_dtype(dtype, sycl_queue, ref_type=type(fill_value))
14851503
res = _empty_like_orderK(x, dtype, usm_type, sycl_queue)

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2621,3 +2621,14 @@ def test_setitem_from_numpy_contig():
26212621

26222622
expected = dpt.reshape(dpt.arange(-10, 10, dtype=fp_dt), (4, 5))
26232623
assert dpt.all(dpt.flip(Xdpt, axis=-1) == expected)
2624+
2625+
2626+
def test_full_functions_raise_type_error():
2627+
get_queue_or_skip()
2628+
2629+
with pytest.raises(TypeError):
2630+
dpt.full(1, "0")
2631+
2632+
x = dpt.ones(1, dtype="i4")
2633+
with pytest.raises(TypeError):
2634+
dpt.full_like(x, "0")

0 commit comments

Comments
 (0)