@@ -191,26 +191,23 @@ def take_1d(arr, indexer, out=None, fill_value=np.nan):
191
191
192
192
n = len (indexer )
193
193
194
- if not isinstance (indexer , np .ndarray ):
195
- # Cython methods expects 32-bit integers
196
- indexer = np .array (indexer , dtype = np .int32 )
197
-
198
194
indexer = _ensure_int64 (indexer )
195
+
199
196
out_passed = out is not None
200
197
take_f = _take1d_dict .get (dtype_str )
201
198
202
199
if dtype_str in ('int32' , 'int64' , 'bool' ):
203
200
try :
204
201
if out is None :
205
202
out = np .empty (n , dtype = arr .dtype )
206
- take_f (arr , indexer , out = out , fill_value = fill_value )
203
+ take_f (arr , _ensure_int64 ( indexer ) , out = out , fill_value = fill_value )
207
204
except ValueError :
208
205
mask = indexer == - 1
209
206
if len (arr ) == 0 :
210
207
if not out_passed :
211
208
out = np .empty (n , dtype = arr .dtype )
212
209
else :
213
- out = arr . take ( indexer , out = out )
210
+ out = _ndtake ( arr , indexer , out = out )
214
211
if mask .any ():
215
212
if out_passed :
216
213
raise Exception ('out with dtype %s does not support NA' %
@@ -220,9 +217,9 @@ def take_1d(arr, indexer, out=None, fill_value=np.nan):
220
217
elif dtype_str in ('float64' , 'object' , 'datetime64[us]' ):
221
218
if out is None :
222
219
out = np .empty (n , dtype = arr .dtype )
223
- take_f (arr , indexer , out = out , fill_value = fill_value )
220
+ take_f (arr , _ensure_int64 ( indexer ) , out = out , fill_value = fill_value )
224
221
else :
225
- out = arr . take ( indexer , out = out )
222
+ out = _ndtake ( arr , indexer , out = out )
226
223
mask = indexer == - 1
227
224
if mask .any ():
228
225
if out_passed :
@@ -239,9 +236,6 @@ def take_2d_multi(arr, row_idx, col_idx, fill_value=np.nan):
239
236
240
237
take_f = _get_take2d_function (dtype_str , axis = 'multi' )
241
238
242
- row_idx = _ensure_int64 (row_idx )
243
- col_idx = _ensure_int64 (col_idx )
244
-
245
239
out_shape = len (row_idx ), len (col_idx )
246
240
247
241
if dtype_str in ('int32' , 'int64' , 'bool' ):
@@ -254,11 +248,14 @@ def take_2d_multi(arr, row_idx, col_idx, fill_value=np.nan):
254
248
fill_value = fill_value )
255
249
else :
256
250
out = np .empty (out_shape , dtype = arr .dtype )
257
- take_f (arr , row_idx , col_idx , out = out , fill_value = fill_value )
251
+ take_f (arr , _ensure_int64 (row_idx ),
252
+ _ensure_int64 (col_idx ), out = out ,
253
+ fill_value = fill_value )
258
254
return out
259
255
elif dtype_str in ('float64' , 'object' , 'datetime64[us]' ):
260
256
out = np .empty (out_shape , dtype = arr .dtype )
261
- take_f (arr , row_idx , col_idx , out = out , fill_value = fill_value )
257
+ take_f (arr , _ensure_int64 (row_idx ), _ensure_int64 (col_idx ), out = out ,
258
+ fill_value = fill_value )
262
259
return out
263
260
else :
264
261
return take_2d (take_2d (arr , row_idx , axis = 0 , fill_value = fill_value ),
@@ -277,10 +274,7 @@ def take_2d(arr, indexer, out=None, mask=None, needs_masking=None, axis=0,
277
274
out_shape = tuple (out_shape )
278
275
279
276
if not isinstance (indexer , np .ndarray ):
280
- # Cython methods expects 32-bit integers
281
- indexer = np .array (indexer , dtype = np .int32 )
282
-
283
- indexer = _ensure_int64 (indexer )
277
+ indexer = np .array (indexer , dtype = np .int64 )
284
278
285
279
if dtype_str in ('int32' , 'int64' , 'bool' ):
286
280
if mask is None :
@@ -289,7 +283,7 @@ def take_2d(arr, indexer, out=None, mask=None, needs_masking=None, axis=0,
289
283
290
284
if needs_masking :
291
285
# upcasting may be required
292
- result = arr . take ( indexer , axis = axis , out = out )
286
+ result = _ndtake ( arr , indexer , axis = axis , out = out )
293
287
result = _maybe_mask (result , mask , needs_masking , axis = axis ,
294
288
out_passed = out is not None ,
295
289
fill_value = fill_value )
@@ -298,13 +292,13 @@ def take_2d(arr, indexer, out=None, mask=None, needs_masking=None, axis=0,
298
292
if out is None :
299
293
out = np .empty (out_shape , dtype = arr .dtype )
300
294
take_f = _get_take2d_function (dtype_str , axis = axis )
301
- take_f (arr , indexer , out = out , fill_value = fill_value )
295
+ take_f (arr , _ensure_int64 ( indexer ) , out = out , fill_value = fill_value )
302
296
return out
303
297
elif dtype_str in ('float64' , 'object' , 'datetime64[us]' ):
304
298
if out is None :
305
299
out = np .empty (out_shape , dtype = arr .dtype )
306
300
take_f = _get_take2d_function (dtype_str , axis = axis )
307
- take_f (arr , indexer , out = out , fill_value = fill_value )
301
+ take_f (arr , _ensure_int64 ( indexer ) , out = out , fill_value = fill_value )
308
302
return out
309
303
else :
310
304
if mask is None :
@@ -315,12 +309,15 @@ def take_2d(arr, indexer, out=None, mask=None, needs_masking=None, axis=0,
315
309
if out is not None and arr .dtype != out .dtype :
316
310
arr = arr .astype (out .dtype )
317
311
318
- result = arr . take ( indexer , axis = axis , out = out )
312
+ result = _ndtake ( arr , indexer , axis = axis , out = out )
319
313
result = _maybe_mask (result , mask , needs_masking , axis = axis ,
320
314
out_passed = out is not None ,
321
315
fill_value = fill_value )
322
316
return result
323
317
318
+ def _ndtake (arr , indexer , axis = 0 , out = None ):
319
+ return arr .take (_ensure_platform_int (indexer ), axis = axis , out = out )
320
+
324
321
def mask_out_axis (arr , mask , axis , fill_value = np .nan ):
325
322
indexer = [slice (None )] * arr .ndim
326
323
indexer [axis ] = mask
@@ -334,7 +331,7 @@ def take_fast(arr, indexer, mask, needs_masking, axis=0, out=None,
334
331
needs_masking = needs_masking ,
335
332
axis = axis , fill_value = fill_value )
336
333
indexer = _ensure_platform_int (indexer )
337
- result = arr . take ( indexer , axis = axis , out = out )
334
+ result = _ndtake ( arr , indexer , axis = axis , out = out )
338
335
result = _maybe_mask (result , mask , needs_masking , axis = axis ,
339
336
out_passed = out is not None , fill_value = fill_value )
340
337
return result
@@ -727,9 +724,12 @@ def _ensure_int64(arr):
727
724
return np .array (arr , dtype = np .int64 )
728
725
729
726
def _ensure_platform_int (labels ):
730
- if labels .dtype != np .int_ : # pragma: no cover
731
- labels = labels .astype (np .int_ )
732
- return labels
727
+ try :
728
+ if labels .dtype != np .int_ : # pragma: no cover
729
+ labels = labels .astype (np .int_ )
730
+ return labels
731
+ except AttributeError :
732
+ return np .array (labels , dtype = np .int_ )
733
733
734
734
def _ensure_int32 (arr ):
735
735
if arr .dtype != np .int32 :
0 commit comments