|
13 | 13 | import numpy as np
|
14 | 14 |
|
15 | 15 | from pandas._libs import lib, tslibs
|
16 |
| -import pandas.compat as compat |
17 |
| -from pandas.compat import PY36, iteritems |
18 |
| - |
19 | 16 | from pandas.core.dtypes.cast import construct_1d_object_array_from_listlike
|
20 |
| -from pandas.core.dtypes.common import ( |
21 |
| - is_array_like, is_bool_dtype, is_extension_array_dtype, is_integer) |
22 |
| -from pandas.core.dtypes.generic import ABCIndex, ABCIndexClass, ABCSeries |
| 17 | +from pandas import compat |
| 18 | +from pandas.compat import iteritems, PY2, PY36, OrderedDict |
| 19 | +from pandas.core.dtypes.generic import ABCSeries, ABCIndex, ABCIndexClass |
| 20 | +from pandas.core.dtypes.common import (is_integer, is_integer_dtype, |
| 21 | + is_bool_dtype, is_extension_array_dtype, |
| 22 | + is_array_like, |
| 23 | + is_float_dtype, is_object_dtype, |
| 24 | + is_categorical_dtype, is_numeric_dtype, |
| 25 | + is_scalar, ensure_platform_int) |
23 | 26 | from pandas.core.dtypes.inference import _iterable_not_string
|
24 | 27 | from pandas.core.dtypes.missing import isna, isnull, notnull # noqa
|
25 | 28 |
|
@@ -482,3 +485,83 @@ def f(x):
|
482 | 485 | f = mapper
|
483 | 486 |
|
484 | 487 | return f
|
| 488 | + |
| 489 | + |
| 490 | +def ensure_integer_dtype(arr, value): |
| 491 | + """ |
| 492 | + Ensure optimal dtype for :func:`searchsorted_integer` is returned. |
| 493 | +
|
| 494 | + Parameters |
| 495 | + ---------- |
| 496 | + arr : a numpy integer array |
| 497 | + value : a number or array of numbers |
| 498 | +
|
| 499 | + Returns |
| 500 | + ------- |
| 501 | + dtype : an numpy integer dtype |
| 502 | +
|
| 503 | + Raises |
| 504 | + ------ |
| 505 | + TypeError : if value is not a number |
| 506 | + """ |
| 507 | + value_arr = np.array([value]) if is_scalar(value) else np.array(value) |
| 508 | + |
| 509 | + if PY2 and not is_numeric_dtype(value_arr): |
| 510 | + # python 2 allows "a" < 1, avoid such nonsense |
| 511 | + msg = "value must be numeric, was type {}" |
| 512 | + raise TypeError(msg.format(value)) |
| 513 | + |
| 514 | + iinfo = np.iinfo(arr.dtype) |
| 515 | + if not ((value_arr < iinfo.min).any() or (value_arr > iinfo.max).any()): |
| 516 | + return arr.dtype |
| 517 | + else: |
| 518 | + return value_arr.dtype |
| 519 | + |
| 520 | + |
| 521 | +def searchsorted_integer(arr, value, side="left", sorter=None): |
| 522 | + """ |
| 523 | + searchsorted implementation, but only for integer arrays. |
| 524 | +
|
| 525 | + We get a speedup if the dtype of arr and value is the same. |
| 526 | +
|
| 527 | + See :func:`searchsorted` for a more general searchsorted implementation. |
| 528 | + """ |
| 529 | + if sorter is not None: |
| 530 | + sorter = ensure_platform_int(sorter) |
| 531 | + |
| 532 | + dtype = ensure_integer_dtype(arr, value) |
| 533 | + |
| 534 | + if is_integer(value) or is_integer_dtype(value): |
| 535 | + value = np.asarray(value, dtype=dtype) |
| 536 | + elif hasattr(value, 'is_integer') and value.is_integer(): |
| 537 | + # float 2.0 can be converted to int 2 for better speed, |
| 538 | + # but float 2.2 should *not* be converted to int 2 |
| 539 | + value = np.asarray(value, dtype=dtype) |
| 540 | + |
| 541 | + return np.searchsorted(arr, value, side=side, sorter=sorter) |
| 542 | + |
| 543 | + |
| 544 | +def searchsorted(arr, value, side="left", sorter=None): |
| 545 | + """ |
| 546 | + Find indices where elements should be inserted to maintain order. |
| 547 | +
|
| 548 | + Find the indices into a sorted array-like `arr` such that, if the |
| 549 | + corresponding elements in `value` were inserted before the indices, |
| 550 | + the order of `arr` would be preserved. |
| 551 | +
|
| 552 | + See :class:`IndexOpsMixin.searchsorted` for more details and examples. |
| 553 | + """ |
| 554 | + if sorter is not None: |
| 555 | + sorter = ensure_platform_int(sorter) |
| 556 | + |
| 557 | + if is_integer_dtype(arr): |
| 558 | + return searchsorted_integer(arr, value, side=side, sorter=sorter) |
| 559 | + elif (is_object_dtype(arr) or is_float_dtype(arr) or |
| 560 | + is_categorical_dtype(arr)): |
| 561 | + return arr.searchsorted(value, side=side, sorter=sorter) |
| 562 | + else: |
| 563 | + # fallback solution. E.g. arr is an array with dtype='datetime64[ns]' |
| 564 | + # and value is a pd.Timestamp, need to convert value |
| 565 | + from pandas.core.series import Series |
| 566 | + value = Series(value)._values |
| 567 | + return arr.searchsorted(value, side=side, sorter=sorter) |
0 commit comments