Skip to content

Commit 7c6a178

Browse files
committed
lint
1 parent fd87c2f commit 7c6a178

File tree

4 files changed

+17
-17
lines changed

4 files changed

+17
-17
lines changed

src/array_api_extra/_lib/_funcs.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -126,22 +126,21 @@ def apply_where( # type: ignore[no-any-explicit] # numpydoc ignore=PR01,PR02
126126
if (f2 is None) == (fill_value is None):
127127
msg = "Exactly one of `fill_value` or `f2` must be given."
128128
raise TypeError(msg)
129-
if not isinstance(args, tuple):
130-
args = (args,)
131-
args = cast(tuple[Array, ...], args)
129+
args_ = list(args) if isinstance(args, tuple) else [args]
130+
del args
132131

133-
xp = array_namespace(cond, *args) if xp is None else xp
132+
xp = array_namespace(cond, *args_) if xp is None else xp
134133

135134
if getattr(fill_value, "ndim", 0):
136-
cond, fill_value, *args = xp.broadcast_arrays(cond, fill_value, *args)
135+
cond, fill_value, *args_ = xp.broadcast_arrays(cond, fill_value, *args_)
137136
else:
138-
cond, *args = xp.broadcast_arrays(cond, *args)
137+
cond, *args_ = xp.broadcast_arrays(cond, *args_)
139138

140139
if is_dask_namespace(xp):
141-
meta_xp = meta_namespace(cond, *args, fill_value, xp=xp)
140+
meta_xp = meta_namespace(cond, fill_value, *args_, xp=xp)
142141
# map_blocks doesn't descend into tuples of Arrays
143-
return xp.map_blocks(_apply_where, cond, f1, f2, fill_value, *args, xp=meta_xp)
144-
return _apply_where(cond, f1, f2, fill_value, *args, xp=xp)
142+
return xp.map_blocks(_apply_where, cond, f1, f2, fill_value, *args_, xp=meta_xp)
143+
return _apply_where(cond, f1, f2, fill_value, *args_, xp=xp)
145144

146145

147146
def _apply_where( # type: ignore[no-any-explicit] # numpydoc ignore=PR01,RT01

src/array_api_extra/_lib/_utils/_helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,5 +264,5 @@ def meta_namespace(
264264
if not is_dask_namespace(xp):
265265
return xp
266266
# Quietly skip scalars and None's
267-
metas = [getattr(a, "_meta", None) for a in arrays]
267+
metas = [cast(Array | None, getattr(a, "_meta", None)) for a in arrays]
268268
return array_namespace(*metas)

tests/test_funcs.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import math
33
import warnings
44
from types import ModuleType
5+
from typing import Any, cast
56

67
import hypothesis
78
import hypothesis.extra.numpy as npst
@@ -29,7 +30,7 @@
2930
from array_api_extra._lib._testing import xp_assert_close, xp_assert_equal
3031
from array_api_extra._lib._utils._compat import device as get_device
3132
from array_api_extra._lib._utils._helpers import asarrays, eager_shape, ndindex
32-
from array_api_extra._lib._utils._typing import Array, Device, DType
33+
from array_api_extra._lib._utils._typing import Array, Device
3334
from array_api_extra.testing import lazy_xp_function
3435

3536
# some xp backends are untyped
@@ -120,7 +121,7 @@ def test_dtype_propagation(self, xp: ModuleType, library: Backend):
120121
cond,
121122
(x, y),
122123
self.f1,
123-
lambda x, y: mxp.astype(x - y, xp.int64), # pyright: ignore[reportUnknownArgumentType]
124+
lambda x, y: mxp.astype(x - y, xp.int64), # pyright: ignore[reportArgumentType,reportUnknownArgumentType]
124125
)
125126
assert actual.dtype == xp.int64
126127

@@ -212,11 +213,11 @@ def test_device(self, xp: ModuleType, device: Device):
212213
p=st.floats(min_value=0, max_value=1),
213214
data=st.data(),
214215
)
215-
def test_hypothesis( # type: ignore[no-any-decorated]
216+
def test_hypothesis( # type: ignore[no-any-explicit,no-any-decorated]
216217
self,
217218
n_arrays: int,
218219
rng_seed: int,
219-
dtype: DType,
220+
dtype: np.dtype[Any],
220221
p: float,
221222
data: st.DataObject,
222223
xp: ModuleType,
@@ -242,10 +243,10 @@ def test_hypothesis( # type: ignore[no-any-decorated]
242243
)
243244

244245
def f1(*args: Array) -> Array:
245-
return sum(args)
246+
return cast(Array, sum(args))
246247

247248
def f2(*args: Array) -> Array:
248-
return sum(args) / 2
249+
return cast(Array, sum(args) / 2)
249250

250251
rng = np.random.default_rng(rng_seed)
251252
cond = xp.asarray(rng.random(size=cond_shape) > p)

tests/test_helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
# mypy: disable-error-code=no-untyped-usage
2222

23-
np_compat = array_namespace(np.empty(0))
23+
np_compat = array_namespace(np.empty(0)) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
2424

2525
# FIXME calls xp.unique_values without size
2626
lazy_xp_function(in1d, jax_jit=False, static_argnames=("assume_unique", "invert", "xp"))

0 commit comments

Comments
 (0)