Skip to content

Commit 849fac4

Browse files
committed
try using cache arguments to keep vectorization
1 parent 25efd37 commit 849fac4

File tree

2 files changed

+50
-15
lines changed

2 files changed

+50
-15
lines changed

pandas/core/groupby/groupby.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3935,10 +3935,38 @@ def _apply_to_column_groupbys(self, func):
39353935
return func(self)
39363936

39373937
def pct_change(self, periods=1, fill_method='pad', limit=None, freq=None):
3938-
"""Calcuate pct_change of each value to previous entry in group"""
3939-
return self.apply(lambda x: x.pct_change(periods=periods,
3940-
fill_method=fill_method,
3941-
limit=limit, freq=freq))
3938+
"""Calculate pct_change of each value to previous entry in group"""
3939+
grouper = self.grouper
3940+
cache_exist = getattr(grouper, '_cache', False)
3941+
if cache_exist:
3942+
in_cache = True if 'is_monotonic' in cache_exist.keys() else False
3943+
else:
3944+
in_cache = False
3945+
m = grouper.is_monotonic if in_cache else False
3946+
if not m or fill_method is None:
3947+
return self.apply(lambda x: x.pct_change(periods=periods,
3948+
fill_method=fill_method,
3949+
limit=limit, freq=freq))
3950+
3951+
def get_invalid_index(x):
3952+
if periods == 0:
3953+
return x
3954+
elif periods > 0:
3955+
ax = Index(np.arange(min(x), min(x) + periods))
3956+
return ax
3957+
elif periods < 0:
3958+
ax = Index(np.arange(max(x), max(x) + periods, -1))
3959+
return ax
3960+
3961+
filled = getattr(self, fill_method)(limit=limit)
3962+
shifted = filled.shift(periods=periods, freq=freq)
3963+
pct_change = (filled / shifted) - 1
3964+
3965+
invalid_index = Index([])
3966+
for i in [get_invalid_index(v) for k, v in self.indices.items()]:
3967+
invalid_index = invalid_index.union(i)
3968+
pct_change.iloc[invalid_index] = np.nan
3969+
return pct_change
39423970

39433971

39443972
class NDFrameGroupBy(GroupBy):

pandas/tests/groupby/test_transform.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -723,30 +723,37 @@ def interweave(list_obj):
723723

724724
@pytest.mark.parametrize("test_series", [True, False])
725725
@pytest.mark.parametrize("shuffle", [True, False])
726+
@pytest.mark.parametrize("activate_cache", [True, False])
726727
@pytest.mark.parametrize("periods,fill_method,limit", [
727728
(1, 'ffill', None), (1, 'ffill', 1),
728729
(1, 'bfill', None), (1, 'bfill', 1),
729730
(-1, 'ffill', None), (-1, 'ffill', 1),
730-
(-1, 'bfill', None), (-1, 'bfill', 1)])
731-
def test_pct_change(test_series, shuffle, periods, fill_method, limit):
732-
vals = [3, np.nan, 1, 2, 4, 10, np.nan, np.nan]
731+
(-1, 'bfill', None), (-1, 'bfill', 1),
732+
(-1, None, None), (-1, None, 1),
733+
(-1, None, None), (-1, None, 1)
734+
])
735+
def test_pct_change(test_series, shuffle, activate_cache, periods, fill_method, limit):
736+
vals = [3, np.nan, 1, 2, 4, 10, np.nan, 9]
733737
keys = ['a', 'b']
734-
key_v = [k for j in list(map(lambda x: [x] * len(vals), keys)) for k in j]
738+
key_v = np.repeat(keys, len(vals))
735739
df = DataFrame({'key': key_v, 'vals': vals * 2})
736740
if shuffle:
737741
order = np.random.RandomState(seed=42).permutation(len(df))
738742
df = df.reindex(order).reset_index(drop=True)
739743

740744
manual_apply = []
741745
for k in keys:
742-
subgroup = Series(df.loc[df.key == k, 'vals'].values)
743-
manual_apply.append(subgroup.pct_change(periods=periods,
744-
fill_method=fill_method,
745-
limit=limit))
746-
exp_vals = pd.concat(manual_apply).reset_index(drop=True)
747-
exp = pd.DataFrame(exp_vals, columns=['A'])
746+
ind = df.loc[df.key == k, 'vals']
747+
manual_apply.append(ind.pct_change(periods=periods,
748+
fill_method=fill_method,
749+
limit=limit))
750+
exp_vals = pd.concat(manual_apply, ignore_index=True)
751+
exp = pd.DataFrame(exp_vals.values, columns=['A'])
748752
grp = df.groupby('key')
749753

754+
if activate_cache:
755+
grp.grouper.is_monotonic
756+
750757
def get_result(grp_obj):
751758
return grp_obj.pct_change(periods=periods,
752759
fill_method=fill_method,
@@ -763,7 +770,7 @@ def get_result(grp_obj):
763770
tm.assert_series_equal(result, exp)
764771
else:
765772
result = get_result(grp)
766-
result.reset_index(drop=True, inplace=True)
773+
result = result.reset_index(drop=True)
767774
result.columns = ['A']
768775
tm.assert_frame_equal(result, exp)
769776

0 commit comments

Comments
 (0)