# Patch summary (leading text before the first "diff --git" header is
# skipped by patch tools).
#
# pandas/core/groupby.py
#   * filter(), first hunk: each group's filter result is now accepted only
#     when it is truthy AND not null, so a result of np.nan is treated as
#     False.  A ValueError or TypeError raised while evaluating the filter
#     over the groups is re-raised as
#     TypeError("the filter must return a boolean result").
#   * Hunk at add_indexer(): when `res` has ndim == 1, its first raveled
#     element must likewise be truthy and not null before add_indexer()
#     is called.
#   * `isnull` is added to the names imported from pandas.core.common.
#     NOTE(review): no hunk shown here uses `isnull` -- confirm it is needed.
#     NOTE(review): the two except clauses raise the identical TypeError and
#     could be a single `except (ValueError, TypeError)`.
#
# pandas/tests/test_groupby.py
#   * The existing raise_if_sum_is_zero test now expects TypeError instead
#     of ValueError.
#   * New tests: test_filter_bad_shapes (non-scalar filter results raise
#     TypeError), test_filter_nan_is_false (np.nan result selects nothing),
#     test_filter_using_len (GH4447: filtering on len(x)).
#
# NOTE(review): this patch was stored with its line breaks collapsed; the
# line structure and indentation below were reconstructed to match the
# hunk line counts -- exact original whitespace should be confirmed.
diff --git a/pandas/core/groupby.py b/pandas/core/groupby.py
index e70c01ffcb12f..538831692fd67 100644
--- a/pandas/core/groupby.py
+++ b/pandas/core/groupby.py
@@ -17,7 +17,7 @@
 from pandas.util.decorators import cache_readonly, Appender
 import pandas.core.algorithms as algos
 import pandas.core.common as com
-from pandas.core.common import _possibly_downcast_to_dtype, notnull
+from pandas.core.common import _possibly_downcast_to_dtype, isnull, notnull
 
 import pandas.lib as lib
 import pandas.algos as _algos
@@ -1605,8 +1605,19 @@ def filter(self, func, dropna=True, *args, **kwargs):
         else:
             wrapper = lambda x: func(x, *args, **kwargs)
 
-        indexers = [self.obj.index.get_indexer(group.index) \
-                    if wrapper(group) else [] for _ , group in self]
+        # Interpret np.nan as False.
+        def true_and_notnull(x, *args, **kwargs):
+            b = wrapper(x, *args, **kwargs)
+            return b and notnull(b)
+
+        try:
+            indexers = [self.obj.index.get_indexer(group.index) \
+                        if true_and_notnull(group) else [] \
+                        for _ , group in self]
+        except ValueError:
+            raise TypeError("the filter must return a boolean result")
+        except TypeError:
+            raise TypeError("the filter must return a boolean result")
 
         if len(indexers) == 0:
             filtered = self.obj.take([]) # because np.concatenate would fail
@@ -2124,7 +2135,8 @@ def add_indexer():
                     add_indexer()
             else:
                 if getattr(res,'ndim',None) == 1:
-                    if res.ravel()[0]:
+                    val = res.ravel()[0]
+                    if val and notnull(val):
                         add_indexer()
                 else:
 
diff --git a/pandas/tests/test_groupby.py b/pandas/tests/test_groupby.py
index fec6460ea31f3..babe72e3ca106 100644
--- a/pandas/tests/test_groupby.py
+++ b/pandas/tests/test_groupby.py
@@ -2642,9 +2642,37 @@ def raise_if_sum_is_zero(x):
         s = pd.Series([-1,0,1,2])
         grouper = s.apply(lambda x: x % 2)
         grouped = s.groupby(grouper)
-        self.assertRaises(ValueError,
+        self.assertRaises(TypeError,
                           lambda: grouped.filter(raise_if_sum_is_zero))
 
+    def test_filter_bad_shapes(self):
+        df = DataFrame({'A': np.arange(8), 'B': list('aabbbbcc'), 'C': np.arange(8)})
+        s = df['B']
+        g_df = df.groupby('B')
+        g_s = s.groupby(s)
+
+        f = lambda x: x
+        self.assertRaises(TypeError, lambda: g_df.filter(f))
+        self.assertRaises(TypeError, lambda: g_s.filter(f))
+
+        f = lambda x: x == 1
+        self.assertRaises(TypeError, lambda: g_df.filter(f))
+        self.assertRaises(TypeError, lambda: g_s.filter(f))
+
+        f = lambda x: np.outer(x, x)
+        self.assertRaises(TypeError, lambda: g_df.filter(f))
+        self.assertRaises(TypeError, lambda: g_s.filter(f))
+
+    def test_filter_nan_is_false(self):
+        df = DataFrame({'A': np.arange(8), 'B': list('aabbbbcc'), 'C': np.arange(8)})
+        s = df['B']
+        g_df = df.groupby(df['B'])
+        g_s = s.groupby(s)
+
+        f = lambda x: np.nan
+        assert_frame_equal(g_df.filter(f), df.loc[[]])
+        assert_series_equal(g_s.filter(f), s[[]])
+
     def test_filter_against_workaround(self):
         np.random.seed(0)
         # Series of ints
@@ -2697,6 +2725,29 @@ def test_filter_against_workaround(self):
         new_way = grouped.filter(lambda x: x['ints'].mean() > N/20)
         assert_frame_equal(new_way.sort_index(), old_way.sort_index())
 
+    def test_filter_using_len(self):
+        # BUG GH4447
+        df = DataFrame({'A': np.arange(8), 'B': list('aabbbbcc'), 'C': np.arange(8)})
+        grouped = df.groupby('B')
+        actual = grouped.filter(lambda x: len(x) > 2)
+        expected = DataFrame({'A': np.arange(2, 6), 'B': list('bbbb'), 'C': np.arange(2, 6)}, index=np.arange(2, 6))
+        assert_frame_equal(actual, expected)
+
+        actual = grouped.filter(lambda x: len(x) > 4)
+        expected = df.ix[[]]
+        assert_frame_equal(actual, expected)
+
+        # Series have always worked properly, but we'll test anyway.
+        s = df['B']
+        grouped = s.groupby(s)
+        actual = grouped.filter(lambda x: len(x) > 2)
+        expected = Series(4*['b'], index=np.arange(2, 6))
+        assert_series_equal(actual, expected)
+
+        actual = grouped.filter(lambda x: len(x) > 4)
+        expected = s[[]]
+        assert_series_equal(actual, expected)
+
     def test_groupby_whitelist(self):
         from string import ascii_lowercase
         letters = np.array(list(ascii_lowercase))