Skip to content

Commit 0dbabcc

Browse files
committed
Fix result device
1 parent 9c5436c commit 0dbabcc

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

array_api_strict/_manipulation_functions.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,10 @@ def concat(
2727
dtype = result_type(*arrays)
2828
if len({a.device for a in arrays}) > 1:
2929
raise ValueError("concat inputs must all be on the same device")
30+
result_device = arrays[0].device
3031

3132
arrays = tuple(a._array for a in arrays)
32-
return Array._new(np.concatenate(arrays, axis=axis, dtype=dtype._np_dtype), device=arrays[0].device)
33+
return Array._new(np.concatenate(arrays, axis=axis, dtype=dtype._np_dtype), device=result_device)
3334

3435

3536
def expand_dims(x: Array, /, *, axis: int) -> Array:
@@ -157,8 +158,9 @@ def stack(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: int = 0) ->
157158
result_type(*arrays)
158159
if len({a.device for a in arrays}) > 1:
159160
raise ValueError("concat inputs must all be on the same device")
161+
result_device = arrays[0].device
160162
arrays = tuple(a._array for a in arrays)
161-
return Array._new(np.stack(arrays, axis=axis), device=arrays[0].device)
163+
return Array._new(np.stack(arrays, axis=axis), device=result_device)
162164

163165

164166
@requires_api_version('2023.12')

0 commit comments

Comments
 (0)