Skip to content

Commit 64097cc

Browse files
committed
BUG,TST: Remove case where vectorization fails in pct_change groupby method. Incorporate CR suggestion.
1 parent eaede34 commit 64097cc

File tree

2 files changed

+32
-9
lines changed

2 files changed

+32
-9
lines changed

pandas/core/groupby/groupby.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2070,10 +2070,24 @@ def shift(self, periods=1, freq=None, axis=0):
20702070
def pct_change(self, periods=1, fill_method='pad', limit=None, freq=None,
20712071
axis=0):
20722072
"""Calcuate pct_change of each value to previous entry in group"""
2073-
return self.apply(lambda x: x.pct_change(periods=periods,
2074-
fill_method=fill_method,
2075-
limit=limit, freq=freq,
2076-
axis=axis))
2073+
if freq is not None or axis != 0:
2074+
return self.apply(lambda x: x.pct_change(periods=periods,
2075+
fill_method=fill_method,
2076+
limit=limit, freq=freq,
2077+
axis=axis))
2078+
if fill_method:
2079+
new = DataFrameGroupBy(self._obj_with_exclusions,
2080+
grouper=self.grouper)
2081+
new.obj = getattr(new, fill_method)(limit=limit)
2082+
new._reset_cache()
2083+
else:
2084+
new = self
2085+
2086+
obj = new.obj.drop(self.grouper.names, axis=1)
2087+
shifted = new.shift(periods=periods, freq=freq).\
2088+
drop(self.grouper.names, axis=1)
2089+
return (obj / shifted) - 1
2090+
20772091

20782092
@Substitution(name='groupby')
20792093
@Appender(_doc_template)
@@ -3936,9 +3950,15 @@ def _apply_to_column_groupbys(self, func):
39363950

39373951
def pct_change(self, periods=1, fill_method='pad', limit=None, freq=None):
39383952
"""Calcuate pct_change of each value to previous entry in group"""
3939-
return self.apply(lambda x: x.pct_change(periods=periods,
3940-
fill_method=fill_method,
3941-
limit=limit, freq=freq))
3953+
if fill_method:
3954+
new = SeriesGroupBy(self.obj, grouper=self.grouper)
3955+
new.obj = getattr(new, fill_method)(limit=limit)
3956+
new._reset_cache()
3957+
else:
3958+
new = self
3959+
3960+
shifted = new.shift(periods=periods, freq=freq)
3961+
return (new.obj / shifted) - 1
39423962

39433963

39443964
class NDFrameGroupBy(GroupBy):

pandas/tests/groupby/test_transform.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -730,7 +730,7 @@ def interweave(list_obj):
730730
(-1, 'bfill', None), (-1, 'bfill', 1),
731731
])
732732
def test_pct_change(test_series, shuffle, periods, fill_method, limit):
733-
vals = [3, np.nan, 1, 2, 4, 10, np.nan, 9]
733+
vals = [3, np.nan, 1, 2, 4, 10, np.nan, 4]
734734
keys = ['a', 'b']
735735
key_v = np.repeat(keys, len(vals))
736736
df = DataFrame({'key': key_v, 'vals': vals * 2})
@@ -765,7 +765,10 @@ def get_result(grp_obj):
765765
else:
766766
result = get_result(grp)
767767
result = result.reset_index(drop=True)
768-
result.columns = ['A']
768+
df.insert(0, 'A', result)
769+
result = df.sort_values(by='key')
770+
result = result.loc[:, ['A']]
771+
result = result.reset_index(drop=True)
769772
tm.assert_frame_equal(result, exp)
770773

771774

0 commit comments

Comments
 (0)