Skip to content

Commit bbe05a5

Browse files
committed
Merge branch 'main' into lazywhere
2 parents b3808e7 + 308fc1f commit bbe05a5

File tree

15 files changed

+326
-159
lines changed

15 files changed

+326
-159
lines changed

pixi.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ enable_error_code = ["ignore-without-code", "truthy-bool"]
205205
# https://github.com/data-apis/array-api-typing
206206
disallow_any_expr = false
207207
# false positives with input validation
208-
disable_error_code = ["redundant-expr", "unreachable"]
208+
disable_error_code = ["redundant-expr", "unreachable", "no-any-return"]
209209

210210
[[tool.mypy.overrides]]
211211
# slow/unavailable on Windows; do not add to the lint env

src/array_api_extra/_lib/_at.py

Lines changed: 34 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
is_jax_array,
1616
is_writeable_array,
1717
)
18+
from ._utils._typing import Array, SetIndex
1819
from ._utils._helpers import meta_namespace
1920
from ._utils._typing import Array, Index
2021

@@ -44,7 +45,13 @@ def __str__(self) -> str: # type: ignore[explicit-override] # pyright: ignore[
4445
return self.value
4546

4647

47-
_undef = object()
48+
class Undef(Enum):
49+
"""Sentinel for undefined values."""
50+
51+
UNDEF = 0
52+
53+
54+
_undef = Undef.UNDEF
4855

4956

5057
class at: # pylint: disable=invalid-name # numpydoc ignore=PR02
@@ -189,16 +196,16 @@ class at: # pylint: disable=invalid-name # numpydoc ignore=PR02
189196
"""
190197

191198
_x: Array
192-
_idx: Index
199+
_idx: SetIndex | Undef
193200
__slots__: ClassVar[tuple[str, ...]] = ("_idx", "_x")
194201

195202
def __init__(
196-
self, x: Array, idx: Index = _undef, /
203+
self, x: Array, idx: SetIndex | Undef = _undef, /
197204
) -> None: # numpydoc ignore=GL08
198205
self._x = x
199206
self._idx = idx
200207

201-
def __getitem__(self, idx: Index, /) -> at: # numpydoc ignore=PR01,RT01
208+
def __getitem__(self, idx: SetIndex, /) -> at: # numpydoc ignore=PR01,RT01
202209
"""
203210
Allow for the alternate syntax ``at(x)[start:stop:step]``.
204211
@@ -213,9 +220,9 @@ def __getitem__(self, idx: Index, /) -> at: # numpydoc ignore=PR01,RT01
213220
def _op(
214221
self,
215222
at_op: _AtOp,
216-
in_place_op: Callable[[Array, Array | object], Array] | None,
223+
in_place_op: Callable[[Array, Array | complex], Array] | None,
217224
out_of_place_op: Callable[[Array, Array], Array] | None,
218-
y: Array | object,
225+
y: Array | complex,
219226
/,
220227
copy: bool | None,
221228
xp: ModuleType | None,
@@ -227,7 +234,7 @@ def _op(
227234
----------
228235
at_op : _AtOp
229236
Method of JAX's Array.at[].
230-
in_place_op : Callable[[Array, Array | object], Array] | None
237+
in_place_op : Callable[[Array, Array | complex], Array] | None
231238
In-place operation to apply on mutable backends::
232239
233240
x[idx] = in_place_op(x[idx], y)
@@ -246,7 +253,7 @@ def _op(
246253
247254
x = xp.where(idx, y, x)
248255
249-
y : array or object
256+
y : array or complex
250257
Right-hand side of the operation.
251258
copy : bool or None
252259
Whether to copy the input array. See the class docstring for details.
@@ -263,7 +270,7 @@ def _op(
263270
x, idx = self._x, self._idx
264271
xp = array_namespace(x, y) if xp is None else xp
265272

266-
if idx is _undef:
273+
if isinstance(idx, Undef):
267274
msg = (
268275
"Index has not been set.\n"
269276
"Usage: either\n"
@@ -311,7 +318,10 @@ def _op(
311318
if copy or (copy is None and not writeable):
312319
if is_jax_array(x):
313320
# Use JAX's at[]
314-
func = cast(Callable[[Array], Array], getattr(x.at[idx], at_op.value))
321+
func = cast(
322+
Callable[[Array | complex], Array],
323+
getattr(x.at[idx], at_op.value), # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue,reportUnknownArgumentType]
324+
)
315325
out = func(y)
316326
# Undo int->float promotion on JAX after _AtOp.DIVIDE
317327
return xp.astype(out, x.dtype, copy=False)
@@ -320,10 +330,10 @@ def _op(
320330
# with a copy followed by an update
321331

322332
x = xp.asarray(x, copy=True)
323-
if writeable is False:
324-
# A copy of a read-only numpy array is writeable
325-
# Note: this assumes that a copy of a writeable array is writeable
326-
writeable = None
333+
# A copy of a read-only numpy array is writeable
334+
# Note: this assumes that a copy of a writeable array is writeable
335+
assert not writeable
336+
writeable = None
327337

328338
if writeable is None:
329339
writeable = is_writeable_array(x)
@@ -333,14 +343,14 @@ def _op(
333343
raise ValueError(msg)
334344

335345
if in_place_op: # add(), subtract(), ...
336-
x[self._idx] = in_place_op(x[self._idx], y)
346+
x[idx] = in_place_op(x[idx], y)
337347
else: # set()
338-
x[self._idx] = y
348+
x[idx] = y
339349
return x
340350

341351
def set(
342352
self,
343-
y: Array | object,
353+
y: Array | complex,
344354
/,
345355
copy: bool | None = None,
346356
xp: ModuleType | None = None,
@@ -350,7 +360,7 @@ def set(
350360

351361
def add(
352362
self,
353-
y: Array | object,
363+
y: Array | complex,
354364
/,
355365
copy: bool | None = None,
356366
xp: ModuleType | None = None,
@@ -364,7 +374,7 @@ def add(
364374

365375
def subtract(
366376
self,
367-
y: Array | object,
377+
y: Array | complex,
368378
/,
369379
copy: bool | None = None,
370380
xp: ModuleType | None = None,
@@ -376,7 +386,7 @@ def subtract(
376386

377387
def multiply(
378388
self,
379-
y: Array | object,
389+
y: Array | complex,
380390
/,
381391
copy: bool | None = None,
382392
xp: ModuleType | None = None,
@@ -388,7 +398,7 @@ def multiply(
388398

389399
def divide(
390400
self,
391-
y: Array | object,
401+
y: Array | complex,
392402
/,
393403
copy: bool | None = None,
394404
xp: ModuleType | None = None,
@@ -400,7 +410,7 @@ def divide(
400410

401411
def power(
402412
self,
403-
y: Array | object,
413+
y: Array | complex,
404414
/,
405415
copy: bool | None = None,
406416
xp: ModuleType | None = None,
@@ -410,7 +420,7 @@ def power(
410420

411421
def min(
412422
self,
413-
y: Array | object,
423+
y: Array | complex,
414424
/,
415425
copy: bool | None = None,
416426
xp: ModuleType | None = None,
@@ -429,7 +439,7 @@ def min(
429439

430440
def max(
431441
self,
432-
y: Array | object,
442+
y: Array | complex,
433443
/,
434444
copy: bool | None = None,
435445
xp: ModuleType | None = None,

src/array_api_extra/_lib/_funcs.py

Lines changed: 25 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,12 @@
77
import warnings
88
from collections.abc import Callable, Sequence
99
from types import ModuleType
10-
from typing import TYPE_CHECKING, cast, overload
10+
from typing import cast
1111

1212
from ._at import at
1313
from ._utils import _compat, _helpers
14+
from ._utils._compat import array_namespace, is_jax_array
15+
from ._utils._helpers import asarrays, eager_shape, ndindex
1416
from ._utils._compat import (
1517
array_namespace,
1618
is_dask_namespace,
@@ -358,11 +360,13 @@ def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:
358360
m = xp.astype(m, dtype)
359361

360362
avg = _helpers.mean(m, axis=1, xp=xp)
361-
fact = m.shape[1] - 1
363+
364+
m_shape = eager_shape(m)
365+
fact = m_shape[1] - 1
362366

363367
if fact <= 0:
364368
warnings.warn("Degrees of freedom <= 0 for slice", RuntimeWarning, stacklevel=2)
365-
fact = 0.0
369+
fact = 0
366370

367371
m -= avg[:, None]
368372
m_transpose = m.T
@@ -421,8 +425,10 @@ def create_diagonal(
421425
if x.ndim == 0:
422426
err_msg = "`x` must be at least 1-dimensional."
423427
raise ValueError(err_msg)
424-
batch_dims = x.shape[:-1]
425-
n = x.shape[-1] + abs(offset)
428+
429+
x_shape = eager_shape(x)
430+
batch_dims = x_shape[:-1]
431+
n = x_shape[-1] + abs(offset)
426432
diag = xp.zeros((*batch_dims, n**2), dtype=x.dtype, device=_compat.device(x))
427433

428434
target_slice = slice(
@@ -532,10 +538,6 @@ def isclose(
532538
) -> Array: # numpydoc ignore=PR01,RT01
533539
"""See docstring in array_api_extra._delegation."""
534540
a, b = asarrays(a, b, xp=xp)
535-
# FIXME https://github.com/microsoft/pyright/issues/10085
536-
if TYPE_CHECKING: # pragma: nocover
537-
assert _compat.is_array_api_obj(a)
538-
assert _compat.is_array_api_obj(b)
539541

540542
a_inexact = xp.isdtype(a.dtype, ("real floating", "complex floating"))
541543
b_inexact = xp.isdtype(b.dtype, ("real floating", "complex floating"))
@@ -655,24 +657,17 @@ def kron(
655657
if xp is None:
656658
xp = array_namespace(a, b)
657659
a, b = asarrays(a, b, xp=xp)
658-
# FIXME https://github.com/microsoft/pyright/issues/10085
659-
if TYPE_CHECKING: # pragma: nocover
660-
assert _compat.is_array_api_obj(a)
661-
assert _compat.is_array_api_obj(b)
662660

663661
singletons = (1,) * (b.ndim - a.ndim)
664-
a = xp.broadcast_to(a, singletons + a.shape)
665-
# FIXME https://github.com/microsoft/pyright/issues/10085
666-
if TYPE_CHECKING: # pragma: nocover
667-
assert _compat.is_array_api_obj(a)
662+
a = cast(Array, xp.broadcast_to(a, singletons + a.shape))
668663

669664
nd_b, nd_a = b.ndim, a.ndim
670665
nd_max = max(nd_b, nd_a)
671666
if nd_a == 0 or nd_b == 0:
672667
return xp.multiply(a, b)
673668

674-
a_shape = a.shape
675-
b_shape = b.shape
669+
a_shape = eager_shape(a)
670+
b_shape = eager_shape(b)
676671

677672
# Equalise the shapes by prepending smaller one with 1s
678673
a_shape = (1,) * max(0, nd_b - nd_a) + a_shape
@@ -737,16 +732,14 @@ def pad(
737732
) -> Array: # numpydoc ignore=PR01,RT01
738733
"""See docstring in `array_api_extra._delegation.py`."""
739734
# make pad_width a list of length-2 tuples of ints
740-
x_ndim = cast(int, x.ndim)
741-
742735
if isinstance(pad_width, int):
743-
pad_width_seq = [(pad_width, pad_width)] * x_ndim
736+
pad_width_seq = [(pad_width, pad_width)] * x.ndim
744737
elif (
745738
isinstance(pad_width, tuple)
746739
and len(pad_width) == 2
747740
and all(isinstance(i, int) for i in pad_width)
748741
):
749-
pad_width_seq = [cast(tuple[int, int], pad_width)] * x_ndim
742+
pad_width_seq = [cast(tuple[int, int], pad_width)] * x.ndim
750743
else:
751744
pad_width_seq = cast(list[tuple[int, int]], list(pad_width))
752745

@@ -758,7 +751,8 @@ def pad(
758751
msg = f"expect a 2-tuple (before, after), got {w_tpl}."
759752
raise ValueError(msg)
760753

761-
sh = x.shape[ax]
754+
sh = eager_shape(x)[ax]
755+
762756
if w_tpl[0] == 0 and w_tpl[1] == 0:
763757
sl = slice(None, None, None)
764758
else:
@@ -824,20 +818,17 @@ def setdiff1d(
824818
"""
825819
if xp is None:
826820
xp = array_namespace(x1, x2)
827-
x1, x2 = asarrays(x1, x2, xp=xp)
821+
# https://github.com/microsoft/pyright/issues/10103
822+
x1_, x2_ = asarrays(x1, x2, xp=xp)
828823

829824
if assume_unique:
830-
x1 = xp.reshape(x1, (-1,))
831-
x2 = xp.reshape(x2, (-1,))
825+
x1_ = xp.reshape(x1_, (-1,))
826+
x2_ = xp.reshape(x2_, (-1,))
832827
else:
833-
x1 = xp.unique_values(x1)
834-
x2 = xp.unique_values(x2)
835-
836-
# FIXME https://github.com/microsoft/pyright/issues/10085
837-
if TYPE_CHECKING: # pragma: nocover
838-
assert _compat.is_array_api_obj(x1)
828+
x1_ = xp.unique_values(x1_)
829+
x2_ = xp.unique_values(x2_)
839830

840-
return x1[_helpers.in1d(x1, x2, assume_unique=True, invert=True, xp=xp)]
831+
return x1_[_helpers.in1d(x1_, x2_, assume_unique=True, invert=True, xp=xp)]
841832

842833

843834
def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:

0 commit comments

Comments
 (0)