Skip to content

Commit 863f1cf

Browse files
committed
BLD: more platform int fixes per #855
1 parent ff45745 commit 863f1cf

File tree

1 file changed

+25
-25
lines changed

1 file changed

+25
-25
lines changed

pandas/core/common.py

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -191,26 +191,23 @@ def take_1d(arr, indexer, out=None, fill_value=np.nan):
191191

192192
n = len(indexer)
193193

194-
if not isinstance(indexer, np.ndarray):
195-
# Cython methods expects 32-bit integers
196-
indexer = np.array(indexer, dtype=np.int32)
197-
198194
indexer = _ensure_int64(indexer)
195+
199196
out_passed = out is not None
200197
take_f = _take1d_dict.get(dtype_str)
201198

202199
if dtype_str in ('int32', 'int64', 'bool'):
203200
try:
204201
if out is None:
205202
out = np.empty(n, dtype=arr.dtype)
206-
take_f(arr, indexer, out=out, fill_value=fill_value)
203+
take_f(arr, _ensure_int64(indexer), out=out, fill_value=fill_value)
207204
except ValueError:
208205
mask = indexer == -1
209206
if len(arr) == 0:
210207
if not out_passed:
211208
out = np.empty(n, dtype=arr.dtype)
212209
else:
213-
out = arr.take(indexer, out=out)
210+
out = _ndtake(arr, indexer, out=out)
214211
if mask.any():
215212
if out_passed:
216213
raise Exception('out with dtype %s does not support NA' %
@@ -220,9 +217,9 @@ def take_1d(arr, indexer, out=None, fill_value=np.nan):
220217
elif dtype_str in ('float64', 'object', 'datetime64[us]'):
221218
if out is None:
222219
out = np.empty(n, dtype=arr.dtype)
223-
take_f(arr, indexer, out=out, fill_value=fill_value)
220+
take_f(arr, _ensure_int64(indexer), out=out, fill_value=fill_value)
224221
else:
225-
out = arr.take(indexer, out=out)
222+
out = _ndtake(arr, indexer, out=out)
226223
mask = indexer == -1
227224
if mask.any():
228225
if out_passed:
@@ -239,9 +236,6 @@ def take_2d_multi(arr, row_idx, col_idx, fill_value=np.nan):
239236

240237
take_f = _get_take2d_function(dtype_str, axis='multi')
241238

242-
row_idx = _ensure_int64(row_idx)
243-
col_idx = _ensure_int64(col_idx)
244-
245239
out_shape = len(row_idx), len(col_idx)
246240

247241
if dtype_str in ('int32', 'int64', 'bool'):
@@ -254,11 +248,14 @@ def take_2d_multi(arr, row_idx, col_idx, fill_value=np.nan):
254248
fill_value=fill_value)
255249
else:
256250
out = np.empty(out_shape, dtype=arr.dtype)
257-
take_f(arr, row_idx, col_idx, out=out, fill_value=fill_value)
251+
take_f(arr, _ensure_int64(row_idx),
252+
_ensure_int64(col_idx), out=out,
253+
fill_value=fill_value)
258254
return out
259255
elif dtype_str in ('float64', 'object', 'datetime64[us]'):
260256
out = np.empty(out_shape, dtype=arr.dtype)
261-
take_f(arr, row_idx, col_idx, out=out, fill_value=fill_value)
257+
take_f(arr, _ensure_int64(row_idx), _ensure_int64(col_idx), out=out,
258+
fill_value=fill_value)
262259
return out
263260
else:
264261
return take_2d(take_2d(arr, row_idx, axis=0, fill_value=fill_value),
@@ -277,10 +274,7 @@ def take_2d(arr, indexer, out=None, mask=None, needs_masking=None, axis=0,
277274
out_shape = tuple(out_shape)
278275

279276
if not isinstance(indexer, np.ndarray):
280-
# Cython methods expects 32-bit integers
281-
indexer = np.array(indexer, dtype=np.int32)
282-
283-
indexer = _ensure_int64(indexer)
277+
indexer = np.array(indexer, dtype=np.int64)
284278

285279
if dtype_str in ('int32', 'int64', 'bool'):
286280
if mask is None:
@@ -289,7 +283,7 @@ def take_2d(arr, indexer, out=None, mask=None, needs_masking=None, axis=0,
289283

290284
if needs_masking:
291285
# upcasting may be required
292-
result = arr.take(indexer, axis=axis, out=out)
286+
result = _ndtake(arr, indexer, axis=axis, out=out)
293287
result = _maybe_mask(result, mask, needs_masking, axis=axis,
294288
out_passed=out is not None,
295289
fill_value=fill_value)
@@ -298,13 +292,13 @@ def take_2d(arr, indexer, out=None, mask=None, needs_masking=None, axis=0,
298292
if out is None:
299293
out = np.empty(out_shape, dtype=arr.dtype)
300294
take_f = _get_take2d_function(dtype_str, axis=axis)
301-
take_f(arr, indexer, out=out, fill_value=fill_value)
295+
take_f(arr, _ensure_int64(indexer), out=out, fill_value=fill_value)
302296
return out
303297
elif dtype_str in ('float64', 'object', 'datetime64[us]'):
304298
if out is None:
305299
out = np.empty(out_shape, dtype=arr.dtype)
306300
take_f = _get_take2d_function(dtype_str, axis=axis)
307-
take_f(arr, indexer, out=out, fill_value=fill_value)
301+
take_f(arr, _ensure_int64(indexer), out=out, fill_value=fill_value)
308302
return out
309303
else:
310304
if mask is None:
@@ -315,12 +309,15 @@ def take_2d(arr, indexer, out=None, mask=None, needs_masking=None, axis=0,
315309
if out is not None and arr.dtype != out.dtype:
316310
arr = arr.astype(out.dtype)
317311

318-
result = arr.take(indexer, axis=axis, out=out)
312+
result = _ndtake(arr, indexer, axis=axis, out=out)
319313
result = _maybe_mask(result, mask, needs_masking, axis=axis,
320314
out_passed=out is not None,
321315
fill_value=fill_value)
322316
return result
323317

318+
def _ndtake(arr, indexer, axis=0, out=None):
319+
return arr.take(_ensure_platform_int(indexer), axis=axis, out=out)
320+
324321
def mask_out_axis(arr, mask, axis, fill_value=np.nan):
325322
indexer = [slice(None)] * arr.ndim
326323
indexer[axis] = mask
@@ -334,7 +331,7 @@ def take_fast(arr, indexer, mask, needs_masking, axis=0, out=None,
334331
needs_masking=needs_masking,
335332
axis=axis, fill_value=fill_value)
336333
indexer = _ensure_platform_int(indexer)
337-
result = arr.take(indexer, axis=axis, out=out)
334+
result = _ndtake(arr, indexer, axis=axis, out=out)
338335
result = _maybe_mask(result, mask, needs_masking, axis=axis,
339336
out_passed=out is not None, fill_value=fill_value)
340337
return result
@@ -727,9 +724,12 @@ def _ensure_int64(arr):
727724
return np.array(arr, dtype=np.int64)
728725

729726
def _ensure_platform_int(labels):
730-
if labels.dtype != np.int_: # pragma: no cover
731-
labels = labels.astype(np.int_)
732-
return labels
727+
try:
728+
if labels.dtype != np.int_: # pragma: no cover
729+
labels = labels.astype(np.int_)
730+
return labels
731+
except AttributeError:
732+
return np.array(labels, dtype=np.int_)
733733

734734
def _ensure_int32(arr):
735735
if arr.dtype != np.int32:

0 commit comments

Comments
 (0)