Skip to content

Commit 1d9bc57

Browse files
hayd authored and cpcloud committed
ENH nlargest and nsmallest Series methods
1 parent 1f34b47 commit 1d9bc57

File tree

6 files changed

+234
-36
lines changed

6 files changed

+234
-36
lines changed

doc/source/v0.13.1.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ API changes
128128
import pandas.core.common as com
129129
com.array_equivalent(np.array([0, np.nan]), np.array([0, np.nan]))
130130
np.array_equal(np.array([0, np.nan]), np.array([0, np.nan]))
131+
- Add nsmallest and nlargest Series methods (:issue:`3960`)
131132

132133
- ``DataFrame.apply`` will use the ``reduce`` argument to determine whether a
133134
``Series`` or a ``DataFrame`` should be returned when the ``DataFrame`` is

pandas/algos.pyx

Lines changed: 39 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ from numpy cimport NPY_FLOAT16 as NPY_float16
2121
from numpy cimport NPY_FLOAT32 as NPY_float32
2222
from numpy cimport NPY_FLOAT64 as NPY_float64
2323

24+
from numpy cimport (int8_t, int16_t, int32_t, int64_t, uint8_t, uint16_t,
25+
uint32_t, uint64_t, float16_t, float32_t, float64_t)
26+
2427
int8 = np.dtype(np.int8)
2528
int16 = np.dtype(np.int16)
2629
int32 = np.dtype(np.int32)
@@ -736,16 +739,34 @@ def _check_minp(win, minp, N):
736739
# Physical description: 366 p.
737740
# Series: Prentice-Hall Series in Automatic Computation
738741

739-
def kth_smallest(ndarray[double_t] a, Py_ssize_t k):
740-
cdef:
741-
Py_ssize_t i,j,l,m,n
742-
double_t x, t
742+
ctypedef fused kth_type:
743+
int8_t
744+
int16_t
745+
int32_t
746+
int64_t
743747

744-
n = len(a)
748+
uint8_t
749+
uint16_t
750+
uint32_t
751+
uint64_t
745752

746-
l = 0
747-
m = n-1
748-
while (l<m):
753+
float32_t
754+
float64_t
755+
756+
757+
cdef void swap_kth(kth_type *a, kth_type *b):
    # Exchange the two values that ``a`` and ``b`` point at, in place.
    cdef kth_type tmp = a[0]
    a[0] = b[0]
    b[0] = tmp
762+
763+
764+
cpdef kth_type kth_smallest(kth_type[:] a, Py_ssize_t k):
765+
cdef:
766+
Py_ssize_t i, j, l = 0, n = a.size, m = n - 1
767+
kth_type x
768+
769+
while l < m:
749770
x = a[k]
750771
i = l
751772
j = m
@@ -754,9 +775,7 @@ def kth_smallest(ndarray[double_t] a, Py_ssize_t k):
754775
while a[i] < x: i += 1
755776
while x < a[j]: j -= 1
756777
if i <= j:
757-
t = a[i]
758-
a[i] = a[j]
759-
a[j] = t
778+
swap_kth(&a[i], &a[j])
760779
i += 1; j -= 1
761780

762781
if i > j: break
@@ -765,6 +784,7 @@ def kth_smallest(ndarray[double_t] a, Py_ssize_t k):
765784
if k < i: m = j
766785
return a[k]
767786

787+
768788
cdef inline kth_smallest_c(float64_t* a, Py_ssize_t k, Py_ssize_t n):
769789
cdef:
770790
Py_ssize_t i,j,l,m
@@ -781,9 +801,7 @@ cdef inline kth_smallest_c(float64_t* a, Py_ssize_t k, Py_ssize_t n):
781801
while a[i] < x: i += 1
782802
while x < a[j]: j -= 1
783803
if i <= j:
784-
t = a[i]
785-
a[i] = a[j]
786-
a[j] = t
804+
swap_kth(&a[i], &a[j])
787805
i += 1; j -= 1
788806

789807
if i > j: break
@@ -793,22 +811,22 @@ cdef inline kth_smallest_c(float64_t* a, Py_ssize_t k, Py_ssize_t n):
793811
return a[k]
794812

795813

796-
def median(ndarray arr):
814+
cpdef kth_type median(kth_type[:] arr):
797815
'''
798816
A faster median
799817
'''
800-
cdef int n = len(arr)
818+
cdef Py_ssize_t n = arr.size
801819

802-
if len(arr) == 0:
820+
if n == 0:
803821
return np.NaN
804822

805823
arr = arr.copy()
806824

807825
if n % 2:
808-
return kth_smallest(arr, n / 2)
826+
return kth_smallest(arr, n // 2)
809827
else:
810-
return (kth_smallest(arr, n / 2) +
811-
kth_smallest(arr, n / 2 - 1)) / 2
828+
return (kth_smallest(arr, n // 2) +
829+
kth_smallest(arr, n // 2 - 1)) / 2
812830

813831

814832
# -------------- Min, Max subsequence
@@ -2226,7 +2244,7 @@ cdef inline float64_t _median_linear(float64_t* a, int n):
22262244

22272245

22282246
if n % 2:
2229-
result = kth_smallest_c(a, n / 2, n)
2247+
result = kth_smallest_c( a, n / 2, n)
22302248
else:
22312249
result = (kth_smallest_c(a, n / 2, n) +
22322250
kth_smallest_c(a, n / 2 - 1, n)) / 2

pandas/core/series.py

Lines changed: 63 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
# pylint: disable=E1101,E1103
77
# pylint: disable=W0703,W0622,W0613,W0201
88

9-
import operator
109
import types
1110
import warnings
1211

@@ -15,29 +14,24 @@
1514
import numpy.ma as ma
1615

1716
from pandas.core.common import (isnull, notnull, _is_bool_indexer,
18-
_default_index, _maybe_promote, _maybe_upcast,
19-
_asarray_tuplesafe, is_integer_dtype,
20-
_NS_DTYPE, _TD_DTYPE,
21-
_infer_dtype_from_scalar, is_list_like,
22-
_values_from_object,
17+
_default_index, _maybe_upcast,
18+
_asarray_tuplesafe, _infer_dtype_from_scalar,
19+
is_list_like, _values_from_object,
2320
_possibly_cast_to_datetime, _possibly_castable,
24-
_possibly_convert_platform,
25-
_try_sort,
21+
_possibly_convert_platform, _try_sort,
2622
ABCSparseArray, _maybe_match_name,
2723
_ensure_object, SettingWithCopyError)
2824
from pandas.core.index import (Index, MultiIndex, InvalidIndexError,
2925
_ensure_index)
30-
from pandas.core.indexing import (
31-
_check_bool_indexer,
32-
_is_index_slice, _maybe_convert_indices)
26+
from pandas.core.indexing import _check_bool_indexer, _maybe_convert_indices
3327
from pandas.core import generic, base
3428
from pandas.core.internals import SingleBlockManager
3529
from pandas.core.categorical import Categorical
3630
from pandas.tseries.index import DatetimeIndex
3731
from pandas.tseries.period import PeriodIndex, Period
3832
from pandas import compat
3933
from pandas.util.terminal import get_terminal_size
40-
from pandas.compat import zip, lzip, u, OrderedDict
34+
from pandas.compat import zip, u, OrderedDict
4135

4236
import pandas.core.array as pa
4337
import pandas.core.ops as ops
@@ -46,7 +40,7 @@
4640
import pandas.core.datetools as datetools
4741
import pandas.core.format as fmt
4842
import pandas.core.nanops as nanops
49-
from pandas.util.decorators import Appender, Substitution, cache_readonly
43+
from pandas.util.decorators import Appender, cache_readonly
5044

5145
import pandas.lib as lib
5246
import pandas.tslib as tslib
@@ -1705,7 +1699,17 @@ def _try_kind_sort(arr):
17051699
good = ~bad
17061700
idx = pa.arange(len(self))
17071701

1708-
argsorted = _try_kind_sort(arr[good])
1702+
def _try_kind_sort(arr, kind='mergesort'):
1703+
# easier to ask forgiveness than permission
1704+
try:
1705+
# if kind==mergesort, it can fail for object dtype
1706+
return arr.argsort(kind=kind)
1707+
except TypeError:
1708+
# stable sort not available for object dtype
1709+
# uses the argsort default quicksort
1710+
return arr.argsort(kind='quicksort')
1711+
1712+
argsorted = _try_kind_sort(arr[good], kind=kind)
17091713

17101714
if not ascending:
17111715
argsorted = argsorted[::-1]
@@ -1728,6 +1732,51 @@ def _try_kind_sort(arr):
17281732
else:
17291733
return result.__finalize__(self)
17301734

1735+
def nlargest(self, n=5, take_last=False):
    '''
    Return the n largest non-null values of the Series, largest first.

    May be faster than .order(ascending=False).head(n).

    Parameters
    ----------
    n : int, default 5
        Maximum number of values to return.
    take_last : bool, default False
        Forwarded to the index-finding helper; controls which of several
        tied values is kept.

    Returns
    -------
    Series
        The selected values with their original index labels. Nulls are
        dropped before selection.
    '''
    # TODO remove need for dropna ?
    dropped = self.dropna()

    from pandas.tools.util import nlargest

    if dropped.dtype == object:
        try:
            dropped = dropped.astype(float)
        except (NotImplementedError, TypeError, ValueError):
            # ValueError is what non-numeric strings raise on conversion;
            # such data can only be ranked by a full sort.
            return dropped.order(ascending=False).head(n)

    try:
        inds = nlargest(dropped.values, n, take_last)
    except (NotImplementedError, KeyError):
        # The helper only understands a few dtypes (it signals others with
        # NotImplementedError, or a KeyError from its dtype lookup); fall
        # back to sorting, mirroring what nsmallest does.
        return dropped.order(ascending=False).head(n)

    if len(inds) == 0:
        # TODO remove this special case
        return dropped[[]]
    return dropped.iloc[inds]
1758+
1759+
def nsmallest(self, n=5, take_last=False):
    '''
    Return the n smallest non-null values of the Series, smallest first.

    May be faster than .order().head(n).

    Parameters
    ----------
    n : int, default 5
        Maximum number of values to return.
    take_last : bool, default False
        Forwarded to the index-finding helper; controls which of several
        tied values is kept.

    Returns
    -------
    Series
    '''
    # TODO remove need for dropna ?
    non_null = self.dropna()

    from pandas.tools.util import nsmallest as _find_nsmallest
    try:
        positions = _find_nsmallest(non_null.values, n, take_last)
    except NotImplementedError:
        # dtype unsupported by the helper: sort everything instead
        return non_null.order().head(n)

    # TODO remove this special case
    if len(positions) != 0:
        return non_null.iloc[positions]
    return non_null[[]]
1779+
17311780
def sortlevel(self, level=0, ascending=True, sort_remaining=True):
17321781
"""
17331782
Sort Series with MultiIndex by chosen level. Data will be

pandas/tests/test_series.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3998,6 +3998,39 @@ def test_order(self):
39983998
ordered = ts.order(ascending=False, na_position='first')
39993999
assert_almost_equal(expected, ordered.valid().values)
40004000

4001+
def test_nsmallest_nlargest(self):
    # float, int, datetime64 (use i8), timedelta64 (same),
    # object that are numbers, object that are strings

    # Every series below has the same ordering: the minimum sits at
    # position 2, the maximum at position 4, and positions 1 and 3 tie.
    s_list = [Series([3, 2, 1, 2, 5]),
              Series([3., 2., 1., 2., 5.]),
              Series([3., 2, 1, 2, 5], dtype='object'),
              Series([3., 2, 1, 2, '5'], dtype='object'),
              Series(pd.to_datetime(['2003', '2002', '2001', '2002', '2005']))]

    for s in s_list:

        # ties: the default keeps the first tied position (1), take_last
        # keeps the last one (3)
        assert_series_equal(s.nsmallest(2), s.iloc[[2, 1]])
        assert_series_equal(s.nsmallest(2, take_last=True), s.iloc[[2, 3]])

        assert_series_equal(s.nlargest(3), s.iloc[[4, 0, 1]])
        assert_series_equal(s.nlargest(3, take_last=True), s.iloc[[4, 0, 3]])

        # non-positive n yields an empty result
        empty = s.iloc[0:0]
        assert_series_equal(s.nsmallest(0), empty)
        assert_series_equal(s.nsmallest(-1), empty)
        assert_series_equal(s.nlargest(0), empty)
        assert_series_equal(s.nlargest(-1), empty)

        # n at or beyond the length returns every value, fully ordered
        assert_series_equal(s.nsmallest(len(s)), s.order())
        assert_series_equal(s.nsmallest(len(s) + 1), s.order())
        assert_series_equal(s.nlargest(len(s)), s.iloc[[4, 0, 1, 3, 2]])
        assert_series_equal(s.nlargest(len(s) + 1), s.iloc[[4, 0, 1, 3, 2]])

    # NaN entries are dropped, so only four values come back for default n=5
    s = Series([3., np.nan, 1, 2, 5])
    assert_series_equal(s.nlargest(), s.iloc[[4, 0, 3, 2]])
    assert_series_equal(s.nsmallest(), s.iloc[[2, 3, 0, 4]])
4033+
40014034
def test_rank(self):
40024035
from pandas.compat.scipy import rankdata
40034036

pandas/tools/util.py

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
from pandas.compat import reduce
22
from pandas.core.index import Index
33
import numpy as np
4+
from pandas import algos
5+
import pandas.core.common as com
6+
47

58
def match(needles, haystack):
69
haystack = Index(haystack)
@@ -17,7 +20,7 @@ def cartesian_product(X):
1720
--------
1821
>>> cartesian_product([list('ABC'), [1, 2]])
1922
[array(['A', 'A', 'B', 'B', 'C', 'C'], dtype='|S1'),
20-
array([1, 2, 1, 2, 1, 2])]
23+
array([1, 2, 1, 2, 1, 2])]
2124
2225
'''
2326

@@ -43,3 +46,68 @@ def compose(*funcs):
4346
"""Compose 2 or more callables"""
4447
assert len(funcs) > 1, 'At least 2 callables must be passed to compose'
4548
return reduce(_compose2, funcs)
49+
50+
51+
# Maps the string form of a supported dtype to the dtype the array is viewed
# as before it is handed to ``algos.kth_smallest``.
_dtype_map = {'datetime64[ns]': 'int64', 'int64': 'int64',
              'float64': 'float64'}


def nsmallest(arr, n=5, take_last=False):
    '''
    Find the indices of the n smallest values of a numpy array.

    Note: Fails silently with NaN.

    Parameters
    ----------
    arr : numpy.ndarray
        One-dimensional array of object dtype or a dtype listed in
        ``_dtype_map``.
    n : int, default 5
        Number of indices wanted; clipped to ``len(arr)``.
    take_last : bool, default False
        Among tied values, prefer the positions that occur last in ``arr``
        instead of the ones that occur first.

    Returns
    -------
    numpy.ndarray
        Integer positions into ``arr``, ordered by ascending value. Empty
        when ``n <= 0`` or ``arr`` is empty.

    Raises
    ------
    NotImplementedError
        If ``arr`` has a dtype that is neither object nor in ``_dtype_map``.
    '''
    n = min(n, len(arr))
    if n <= 0:
        # Integer dtype so the result is directly usable as an indexer; this
        # also keeps an empty input away from kth_smallest(..., -1).
        return np.array([], dtype=np.intp)

    is_object = arr.dtype == object
    if not is_object:
        try:
            dtype = _dtype_map[str(arr.dtype)]
        except KeyError:
            # Format the message *before* constructing the exception.
            raise NotImplementedError("Not implemented for %s dtype, "
                                      "perhaps convert to int64 or float64, "
                                      "or use .order().head(n)" % arr.dtype)
        arr = arr.view(dtype)

    if take_last:
        # Work on the reversed array so the stable sort below favours the
        # last occurrence; positions are mapped back before returning.
        arr = arr[::-1]

    if is_object:
        # just sort and take n
        inds = arr.argsort(kind='mergesort')[:n]
    else:
        # Select the n-th smallest value, then stably sort only the
        # candidates that are not larger than it.
        kth_val = algos.kth_smallest(arr.copy(), n - 1)
        ns, = np.nonzero(arr <= kth_val)
        inds = ns[arr[ns].argsort(kind='mergesort')][:n]

    if take_last:
        # reverse indices
        return len(arr) - 1 - inds
    return inds
92+
93+
94+
def nlargest(arr, n=5, take_last=False):
    '''
    Find the indices of the n largest values of a numpy array.

    Note: Fails silently with NaN.

    Parameters
    ----------
    arr : numpy.ndarray
        One-dimensional array. Object arrays are converted to float first.
    n : int, default 5
        Number of indices wanted; clipped to ``len(arr)``.
    take_last : bool, default False
        Forwarded unchanged to ``nsmallest``.

    Returns
    -------
    numpy.ndarray
        Integer positions into ``arr``, ordered by descending value. Empty
        when ``n <= 0``.

    Raises
    ------
    TypeError
        If an object array cannot be converted to float.
    NotImplementedError
        If the dtype is not one of those listed in ``_dtype_map``.
    '''
    if n <= 0:
        # Integer dtype so the result is directly usable as an indexer.
        return np.array([], dtype=np.intp)

    n = min(n, len(arr))

    if arr.dtype == object:
        try:
            arr = arr.astype(float)
        except Exception:
            # Any conversion failure is reported uniformly as TypeError;
            # unlike a bare ``except`` this lets KeyboardInterrupt and
            # SystemExit propagate.
            raise TypeError("An object array must convert to float.")

    try:
        dtype = _dtype_map[str(arr.dtype)]
    except KeyError:
        # Same error type nsmallest uses for an unsupported dtype, rather
        # than leaking a bare KeyError from the lookup.
        raise NotImplementedError("Not implemented for %s dtype, "
                                  "perhaps convert to int64 or float64, "
                                  "or use .order(ascending=False).head(n)"
                                  % arr.dtype)

    # The largest values of arr are the smallest values of its negation.
    arr = -arr.view(dtype)
    return nsmallest(arr, n, take_last=take_last)

0 commit comments

Comments
 (0)