Skip to content

Commit 78def19

Browse files
committed
Add device check to repeat()
1 parent 635e14d commit 78def19

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

array_api_strict/_manipulation_functions.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,8 @@ def repeat(
9393
raise RuntimeError("repeat() with repeats as an array requires data-dependent shapes, but the data_dependent_shapes flag has been disabled for array-api-strict")
9494
if repeats.dtype not in _integer_dtypes:
9595
raise TypeError("The repeats array must have an integer dtype")
96+
if x.device != repeats.device:
97+
raise RuntimeError(f"Arrays from two different devices ({x.device} and {repeats.device}) can not be combined.")
9698
elif isinstance(repeats, int):
9799
repeats = asarray(repeats)
98100
else:

0 commit comments

Comments
 (0)