Skip to content

Commit 33450f3

Browse files
committed
Use ValueError for different device errors
1 parent 78def19 commit 33450f3

8 files changed

+42
-42
lines changed

array_api_strict/_array_object.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ def _check_device(self, other):
221221
return
222222
elif isinstance(other, Array):
223223
if self.device != other.device:
224-
raise RuntimeError(f"Arrays from two different devices ({self.device} and {other.device}) can not be combined.")
224+
raise ValueError(f"Arrays from two different devices ({self.device} and {other.device}) can not be combined.")
225225

226226
# Helper function to match the type promotion rules in the spec
227227
def _promote_scalar(self, scalar):

array_api_strict/_elementwise_functions.py

Lines changed: 29 additions & 29 deletions
Large diffs are not rendered by default.

array_api_strict/_indexing_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,5 +23,5 @@ def take(x: Array, indices: Array, /, *, axis: Optional[int] = None) -> Array:
2323
if indices.ndim != 1:
2424
raise ValueError("Only 1-dim indices array is supported")
2525
if x.device != indices.device:
26-
raise RuntimeError(f"Arrays from two different devices ({x.device} and {indices.device}) can not be combined.")
26+
raise ValueError(f"Arrays from two different devices ({x.device} and {indices.device}) can not be combined.")
2727
return Array._new(np.take(x._array, indices._array, axis=axis), device=x.device)

array_api_strict/_linalg.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def cross(x1: Array, x2: Array, /, *, axis: int = -1) -> Array:
8484
raise ValueError('cross() dimension must equal 3')
8585

8686
if x1.device != x2.device:
87-
raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
87+
raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
8888

8989
if get_array_api_strict_flags()['api_version'] >= '2023.12':
9090
if axis >= 0:
@@ -246,7 +246,7 @@ def outer(x1: Array, x2: Array, /) -> Array:
246246
raise ValueError('The input arrays to outer must be 1-dimensional')
247247

248248
if x1.device != x2.device:
249-
raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
249+
raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
250250

251251
return Array._new(np.outer(x1._array, x2._array), device=x1.device)
252252

@@ -357,7 +357,7 @@ def solve(x1: Array, x2: Array, /) -> Array:
357357
raise TypeError('Only floating-point dtypes are allowed in solve')
358358

359359
if x1.device != x2.device:
360-
raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
360+
raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
361361

362362
return Array._new(_solve(x1._array, x2._array), device=x1.device)
363363

array_api_strict/_linear_algebra_functions.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def matmul(x1: Array, x2: Array, /) -> Array:
3131
raise TypeError('Only numeric dtypes are allowed in matmul')
3232

3333
if x1.device != x2.device:
34-
raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
34+
raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
3535

3636
return Array._new(np.matmul(x1._array, x2._array), device=x1.device)
3737

@@ -45,7 +45,7 @@ def tensordot(x1: Array, x2: Array, /, *, axes: Union[int, Tuple[Sequence[int],
4545
raise TypeError('Only numeric dtypes are allowed in tensordot')
4646

4747
if x1.device != x2.device:
48-
raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
48+
raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
4949

5050
return Array._new(np.tensordot(x1._array, x2._array, axes=axes), device=x1.device)
5151

@@ -68,7 +68,7 @@ def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array:
6868
raise ValueError("axis is out of bounds for x1 and x2")
6969

7070
if x1.device != x2.device:
71-
raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
71+
raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
7272

7373
# In versions of the standard prior to 2023.12, vecdot applied axis after
7474
# broadcasting. This is different from applying it before broadcasting

array_api_strict/_manipulation_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def repeat(
9494
if repeats.dtype not in _integer_dtypes:
9595
raise TypeError("The repeats array must have an integer dtype")
9696
if x.device != repeats.device:
97-
raise RuntimeError(f"Arrays from two different devices ({x.device} and {repeats.device}) can not be combined.")
97+
raise ValueError(f"Arrays from two different devices ({x.device} and {repeats.device}) can not be combined.")
9898
elif isinstance(repeats, int):
9999
repeats = asarray(repeats)
100100
else:

array_api_strict/_searching_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def searchsorted(
6363
raise TypeError("Only real numeric dtypes are allowed in searchsorted")
6464

6565
if x1.device != x2.device:
66-
raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
66+
raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
6767

6868
sorter = sorter._array if sorter is not None else None
6969
# TODO: The sort order of nans and signed zeros is implementation

array_api_strict/tests/test_indexing_functions.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,14 @@ def test_take_device():
3131

3232
x = xp.asarray([2, 3])
3333
indices = xp.asarray([1, 1, 0], device=xp.Device("device1"))
34-
with pytest.raises(RuntimeError, match="Arrays from two different devices"):
34+
with pytest.raises(ValueError, match="Arrays from two different devices"):
3535
xp.take(x, indices)
3636

3737
x = xp.asarray([2, 3], device=xp.Device("device1"))
3838
indices = xp.asarray([1, 1, 0])
39-
with pytest.raises(RuntimeError, match="Arrays from two different devices"):
39+
with pytest.raises(ValueError, match="Arrays from two different devices"):
4040
xp.take(x, indices)
4141

4242
x = xp.asarray([2, 3], device=xp.Device("device1"))
4343
indices = xp.asarray([1, 1, 0], device=xp.Device("device1"))
44-
xp.take(x, indices)
44+
xp.take(x, indices)

0 commit comments

Comments
 (0)