Skip to content

Commit aeef3c2

Browse files
committed
ENH: groupby.apply for Categorical should preserve categories (closes #10138)
1 parent b32f218 commit aeef3c2

File tree

3 files changed

+32
-1
lines changed

3 files changed

+32
-1
lines changed

doc/source/whatsnew/v0.17.0.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ New features
2626
Other enhancements
2727
^^^^^^^^^^^^^^^^^^
2828

29+
- groupby.apply aggregation for Categorical now preserves categories (:issue:`10138`)
30+
2931
.. _whatsnew_0170.api:
3032

3133
Backwards incompatible API changes

pandas/core/groupby.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2940,7 +2940,8 @@ def _wrap_applied_output(self, keys, values, not_indexed_same=False):
29402940
cd = 'coerce'
29412941
else:
29422942
cd = True
2943-
return result.convert_objects(convert_dates=cd)
2943+
result = result.convert_objects(convert_dates=cd)
2944+
return self._reindex_output(result)
29442945

29452946
else:
29462947
# only coerce dates if we find at least 1 datetime

pandas/tests/test_groupby.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2596,6 +2596,34 @@ def get_stats(group):
25962596
result = self.df.groupby(cats).D.apply(get_stats)
25972597
self.assertEqual(result.index.names[0], 'C')
25982598

2599+
def test_apply_categorical_data(self):
2600+
# GH 10138
2601+
dense = Categorical(list('abc'))
2602+
# 'b' is in the categories but not in the list
2603+
missing = Categorical(list('aaa'), categories=['a', 'b'])
2604+
values = np.arange(len(dense))
2605+
df = DataFrame({'missing': missing,
2606+
'dense': dense,
2607+
'values': values})
2608+
grouped = df.groupby(['missing', 'dense'])
2609+
2610+
# missing category 'b' should still exist in the output index
2611+
idx = MultiIndex.from_product([['a', 'b'], ['a', 'b', 'c']],
2612+
names=['missing', 'dense'])
2613+
expected = DataFrame([0, 1, 2, np.nan, np.nan, np.nan],
2614+
index=idx,
2615+
columns=['values'])
2616+
2617+
assert_frame_equal(grouped.apply(lambda x: np.mean(x)), expected)
2618+
assert_frame_equal(grouped.mean(), expected)
2619+
assert_frame_equal(grouped.agg(np.mean), expected)
2620+
2621+
# but for transform we should still get back the original index
2622+
idx = MultiIndex.from_product([['a'], ['a', 'b', 'c']],
2623+
names=['missing', 'dense'])
2624+
expected = Series(1, index=idx)
2625+
assert_series_equal(grouped.apply(lambda x: 1), expected)
2626+
25992627
def test_apply_corner_cases(self):
26002628
# #535, can't use sliding iterator
26012629

0 commit comments

Comments
 (0)