Skip to content

Commit 635e14d

Browse files
committed
Merge branch 'main' into betatim-multiple-devices
2 parents 8e6365b + c0c7303 commit 635e14d

File tree

2 files changed

+12
-6
lines changed

2 files changed

+12
-6
lines changed

array_api_strict/_elementwise_functions.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,7 @@ def conj(x: Array, /) -> Array:
342342
"""
343343
if x.dtype not in _complex_floating_dtypes:
344344
raise TypeError("Only complex floating-point dtypes are allowed in conj")
345-
return Array._new(np.conj(x), device=x.device)
345+
return Array._new(np.conj(x._array), device=x.device)
346346

347347
@requires_api_version('2023.12')
348348
def copysign(x1: Array, x2: Array, /) -> Array:
@@ -520,7 +520,7 @@ def imag(x: Array, /) -> Array:
520520
"""
521521
if x.dtype not in _complex_floating_dtypes:
522522
raise TypeError("Only complex floating-point dtypes are allowed in imag")
523-
return Array._new(np.imag(x), device=x.device)
523+
return Array._new(np.imag(x._array), device=x.device)
524524

525525

526526
def isfinite(x: Array, /) -> Array:
@@ -817,7 +817,7 @@ def real(x: Array, /) -> Array:
817817
"""
818818
if x.dtype not in _complex_floating_dtypes:
819819
raise TypeError("Only complex floating-point dtypes are allowed in real")
820-
return Array._new(np.real(x), device=x.device)
820+
return Array._new(np.real(x._array), device=x.device)
821821

822822

823823
def remainder(x1: Array, x2: Array, /) -> Array:

array_api_strict/_manipulation_functions.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22

33
from ._array_object import Array
44
from ._creation_functions import asarray
5-
from ._data_type_functions import result_type
6-
from ._dtypes import _integer_dtypes
5+
from ._data_type_functions import astype, result_type
6+
from ._dtypes import _integer_dtypes, int64, uint64
77
from ._flags import requires_api_version, get_array_api_strict_flags
88

99
from typing import TYPE_CHECKING
@@ -98,7 +98,13 @@ def repeat(
9898
else:
9999
raise TypeError("repeats must be an int or array")
100100

101-
return Array._new(np.repeat(x._array, repeats, axis=axis), device=x.device)
101+
if repeats.dtype == uint64:
102+
# NumPy does not allow uint64 because can't be cast down to x.dtype
103+
# with 'safe' casting. However, repeats values larger than 2**63 are
104+
# infeasable, and even if they are present by mistake, this will
105+
# lead to underflow and an error.
106+
repeats = astype(repeats, int64)
107+
return Array._new(np.repeat(x._array, repeats._array, axis=axis), device=x.device)
102108

103109
# Note: the optional argument is called 'shape', not 'newshape'
104110
def reshape(x: Array,

0 commit comments

Comments
 (0)