Skip to content

Commit 8ed8ef2

Browse files
Merge pull request #1504 from IntelPython/fix-copy-of-src-with-offset
Fix copy of src with offset
2 parents 40f130e + 5cd7ba2 commit 8ed8ef2

File tree

2 files changed

+21
-2
lines changed

2 files changed

+21
-2
lines changed

dpctl/tensor/_copy_utils.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -300,14 +300,22 @@ def _copy_from_usm_ndarray_to_usm_ndarray(dst, src):
300300
src.shape, src.strides, len(common_shape)
301301
)
302302
src_same_shape = dpt.usm_ndarray(
303-
common_shape, dtype=src.dtype, buffer=src, strides=new_src_strides
303+
common_shape,
304+
dtype=src.dtype,
305+
buffer=src,
306+
strides=new_src_strides,
307+
offset=src._element_offset,
304308
)
305309
elif src.ndim == len(common_shape):
306310
new_src_strides = _broadcast_strides(
307311
src.shape, src.strides, len(common_shape)
308312
)
309313
src_same_shape = dpt.usm_ndarray(
310-
common_shape, dtype=src.dtype, buffer=src, strides=new_src_strides
314+
common_shape,
315+
dtype=src.dtype,
316+
buffer=src,
317+
strides=new_src_strides,
318+
offset=src._element_offset,
311319
)
312320
else:
313321
# since broadcasting succeeded, src.ndim is greater because of

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1036,6 +1036,7 @@ def test_setitem_same_dtype(dtype, src_usm_type, dst_usm_type):
10361036

10371037

10381038
def test_setitem_broadcasting():
1039+
"See gh-1503"
10391040
get_queue_or_skip()
10401041
dst = dpt.ones((2, 3, 4), dtype="u4")
10411042
src = dpt.zeros((3, 1), dtype=dst.dtype)
@@ -1044,6 +1045,16 @@ def test_setitem_broadcasting():
10441045
assert np.array_equal(dpt.asnumpy(dst), expected)
10451046

10461047

1048+
def test_setitem_broadcasting_offset():
1049+
get_queue_or_skip()
1050+
dt = dpt.int32
1051+
x = dpt.asarray([[1, 2, 3], [6, 7, 8]], dtype=dt)
1052+
y = dpt.asarray([4, 5], dtype=dt)
1053+
x[0] = y[1]
1054+
expected = dpt.asarray([[5, 5, 5], [6, 7, 8]], dtype=dt)
1055+
assert dpt.all(x == expected)
1056+
1057+
10471058
def test_setitem_broadcasting_empty_dst_validation():
10481059
"Broadcasting rules apply, except exception"
10491060
get_queue_or_skip()

0 commit comments

Comments
 (0)