Skip to content

Commit 03c7615

Browse files
authored
dpctl.tensor.tile returns a scalar for 0D (scalar) input and empty repetitions (#1628)
Previously, this case would return a 1D array of size 1, which did not match Numpy or the array API spec's expected behavior
1 parent 65bb9ef commit 03c7615

File tree

2 files changed

+12
-20
lines changed

2 files changed

+12
-20
lines changed

dpctl/tensor/_manipulation_functions.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -964,18 +964,6 @@ def tile(x, repetitions):
964964
f"Expected tuple or integer type, got {type(repetitions)}."
965965
)
966966

967-
# case of scalar
968-
if x.size == 1:
969-
if not repetitions:
970-
# handle empty tuple
971-
repetitions = (1,)
972-
return dpt.full(
973-
repetitions,
974-
x,
975-
dtype=x.dtype,
976-
usm_type=x.usm_type,
977-
sycl_queue=x.sycl_queue,
978-
)
979967
rep_dims = len(repetitions)
980968
x_dims = x.ndim
981969
if rep_dims < x_dims:

dpctl/tests/test_usm_ndarray_manipulation.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1378,20 +1378,24 @@ def test_tile_size_1():
13781378

13791379
reps = 5
13801380
# test for 0d array
1381-
x = dpt.asarray(2, dtype="i4")
1382-
res = dpt.tile(x, reps)
1381+
x1 = dpt.asarray(2, dtype="i4")
1382+
res = dpt.tile(x1, reps)
13831383
assert dpt.all(res == dpt.full(reps, 2, dtype="i4"))
13841384

13851385
# test for 1d array with single element
1386-
x = dpt.asarray([2], dtype="i4")
1387-
res = dpt.tile(x, reps)
1386+
x2 = dpt.asarray([2], dtype="i4")
1387+
res = dpt.tile(x2, reps)
13881388
assert dpt.all(res == dpt.full(reps, 2, dtype="i4"))
13891389

1390-
# test empty reps returns copy of input
13911390
reps = ()
1392-
res = dpt.tile(x, reps)
1393-
assert x.shape == res.shape
1394-
assert x == res
1391+
# test for gh-1627 behavior
1392+
res = dpt.tile(x1, reps)
1393+
assert x1.shape == res.shape
1394+
assert x1 == res
1395+
1396+
res = dpt.tile(x2, reps)
1397+
assert x2.shape == res.shape
1398+
assert x2 == res
13951399

13961400

13971401
def test_tile_prepends_axes():

0 commit comments

Comments
 (0)