Skip to content

Commit 49ce014

Browse files
committed
fix core_indices
1 parent fab197d commit 49ce014

File tree

1 file changed

+7
-20
lines changed

1 file changed

+7
-20
lines changed

src/array_api_extra/_apply.py

Lines changed: 7 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def apply_numpy_func( # type: ignore[no-any-explicit]
3232
dtypes: Sequence[DType] | None = None,
3333
xp: ModuleType | None = None,
3434
input_indices: Sequence[Sequence[Hashable]] | None = None,
35-
core_indices: Sequence[Sequence[Hashable]] | None = None,
35+
core_indices: Sequence[Hashable] | None = None,
3636
output_indices: Sequence[Sequence[Hashable]] | None = None,
3737
adjust_chunks: Sequence[dict[Hashable, Callable[[int], int]]] | None = None,
3838
new_axes: Sequence[dict[Hashable, int]] | None = None,
@@ -70,9 +70,9 @@ def apply_numpy_func( # type: ignore[no-any-explicit]
7070
ndim=3 and 1, `input_indices` could be ``['ijk', 'j']`` or ``[(0, 1, 2),
7171
(1,)]``.
7272
Default: disallow Dask.
73-
core_indices : Sequence[Sequence[Hashable]], optional
73+
core_indices : Sequence[Hashable], optional
7474
**Dask specific.**
75-
Axes labels of each input array that cannot be broken into chunks.
75+
Axes of the input arrays that cannot be broken into chunks.
7676
Default: disallow Dask.
7777
output_indices : Sequence[Sequence[Hashable]], optional
7878
**Dask specific.**
@@ -144,7 +144,7 @@ def apply_numpy_func( # type: ignore[no-any-explicit]
144144
145145
>>> apply_numpy_func(lambda x: x + x.sum(axis=0), x,
146146
... input_indices=['ij'], output_indices=['ij'],
147-
... core_indices=['i'])
147+
... core_indices='i')
148148
149149
This will cause `apply_numpy_func` to raise if the first axis of `x` is broken
150150
along multiple chunks.
@@ -177,9 +177,6 @@ def apply_numpy_func( # type: ignore[no-any-explicit]
177177
if len(input_indices) != len(args):
178178
msg = f"got {len(input_indices)} input_indices and {len(args)} args"
179179
raise ValueError(msg)
180-
if len(core_indices) != len(args):
181-
msg = f"got {len(core_indices)} input_indices and {len(args)} args"
182-
raise ValueError(msg)
183180
if len(output_indices) != len(shapes):
184181
msg = f"got {len(output_indices)} input_indices and {len(shapes)} shapes"
185182
raise NotImplementedError(msg)
@@ -197,19 +194,9 @@ def apply_numpy_func( # type: ignore[no-any-explicit]
197194
raise ValueError(msg)
198195

199196
# core_indices validation
200-
for core_idx, inp_idx, arg in zip(
201-
core_indices, input_indices, args, strict=True
202-
):
203-
for i in core_idx:
204-
try:
205-
axis = list(inp_idx).index(i)
206-
except ValueError:
207-
msg = (
208-
f"Index {i} found in core indices but not in "
209-
"matching input_indices"
210-
)
211-
raise ValueError(msg) from None
212-
if len(arg.chunks[axis]) > 1:
197+
for inp_idx, arg in zip(input_indices, args, strict=True):
198+
for i, chunks in zip(inp_idx, arg.chunks, strict=True):
199+
if i in core_indices and len(chunks) > 1:
213200
msg = f"Core index {i} is broken into multiple chunks"
214201
raise ValueError(msg)
215202

0 commit comments

Comments
 (0)