@@ -32,7 +32,7 @@ def apply_numpy_func( # type: ignore[no-any-explicit]
32
32
dtypes : Sequence [DType ] | None = None ,
33
33
xp : ModuleType | None = None ,
34
34
input_indices : Sequence [Sequence [Hashable ]] | None = None ,
35
- core_indices : Sequence [Sequence [ Hashable ] ] | None = None ,
35
+ core_indices : Sequence [Hashable ] | None = None ,
36
36
output_indices : Sequence [Sequence [Hashable ]] | None = None ,
37
37
adjust_chunks : Sequence [dict [Hashable , Callable [[int ], int ]]] | None = None ,
38
38
new_axes : Sequence [dict [Hashable , int ]] | None = None ,
@@ -70,9 +70,9 @@ def apply_numpy_func( # type: ignore[no-any-explicit]
70
70
ndim=3 and 1, `input_indices` could be ``['ijk', 'j']`` or ``[(0, 1, 2),
71
71
(1,)]``.
72
72
Default: disallow Dask.
73
- core_indices : Sequence[Sequence[ Hashable] ], optional
73
+ core_indices : Sequence[Hashable], optional
74
74
**Dask specific.**
75
- Axes labels of each input array that cannot be broken into chunks.
75
+ Axes of the input arrays that cannot be broken into chunks.
76
76
Default: disallow Dask.
77
77
output_indices : Sequence[Sequence[Hashable]], optional
78
78
**Dask specific.**
@@ -144,7 +144,7 @@ def apply_numpy_func( # type: ignore[no-any-explicit]
144
144
145
145
>>> apply_numpy_func(lambda x: x + x.sum(axis=0), x,
146
146
... input_indices=['ij'], output_indices=['ij'],
147
- ... core_indices=[ 'i'] )
147
+ ... core_indices='i')
148
148
149
149
This will cause `apply_numpy_func` to raise if the first axis of `x` is broken
150
150
along multiple chunks.
@@ -177,9 +177,6 @@ def apply_numpy_func( # type: ignore[no-any-explicit]
177
177
if len (input_indices ) != len (args ):
178
178
msg = f"got { len (input_indices )} input_indices and { len (args )} args"
179
179
raise ValueError (msg )
180
- if len (core_indices ) != len (args ):
181
- msg = f"got { len (core_indices )} input_indices and { len (args )} args"
182
- raise ValueError (msg )
183
180
if len (output_indices ) != len (shapes ):
184
181
msg = f"got { len (output_indices )} input_indices and { len (shapes )} shapes"
185
182
raise NotImplementedError (msg )
@@ -197,19 +194,9 @@ def apply_numpy_func( # type: ignore[no-any-explicit]
197
194
raise ValueError (msg )
198
195
199
196
# core_indices validation
200
- for core_idx , inp_idx , arg in zip (
201
- core_indices , input_indices , args , strict = True
202
- ):
203
- for i in core_idx :
204
- try :
205
- axis = list (inp_idx ).index (i )
206
- except ValueError :
207
- msg = (
208
- f"Index { i } found in core indices but not in "
209
- "matching input_indices"
210
- )
211
- raise ValueError (msg ) from None
212
- if len (arg .chunks [axis ]) > 1 :
197
+ for inp_idx , arg in zip (input_indices , args , strict = True ):
198
+ for i , chunks in zip (inp_idx , arg .chunks , strict = True ):
199
+ if i in core_indices and len (chunks ) > 1 :
213
200
msg = f"Core index { i } is broken into multiple chunks"
214
201
raise ValueError (msg )
215
202
0 commit comments