Skip to content

Commit 033863f

Browse files
committed
Merge pull request #5713 from jreback/transform_bug
BUG: Bug in groupby transform with a datetime-like grouper (GH5712)
2 parents 986bda2 + 76f98b6 commit 033863f

File tree

4 files changed

+47
-15
lines changed

4 files changed

+47
-15
lines changed

doc/source/release.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -823,6 +823,7 @@ Bug Fixes
823823
- Work around regression in numpy 1.7.0 which erroneously raises IndexError from ``ndarray.item`` (:issue:`5666`)
824824
- Bug in repeated indexing of object with resultant non-unique index (:issue:`5678`)
825825
- Bug in fillna with Series and a passed series/dict (:issue:`5703`)
826+
- Bug in groupby transform with a datetime-like grouper (:issue:`5712`)
826827

827828
pandas 0.12.0
828829
-------------

pandas/core/common.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2009,6 +2009,14 @@ def needs_i8_conversion(arr_or_dtype):
20092009
is_timedelta64_dtype(arr_or_dtype))
20102010

20112011

2012+
def is_numeric_dtype(arr_or_dtype):
    """
    Return True if the argument has a numeric or boolean dtype.

    Datetime-like and timedelta-like dtypes are never considered numeric.

    Parameters
    ----------
    arr_or_dtype : np.dtype, numpy scalar type, or object with a ``dtype``
        Either a dtype instance, a numpy scalar type class (for example
        ``np.float64``), or anything exposing a ``dtype`` attribute such as
        an ndarray.

    Returns
    -------
    bool
    """
    if isinstance(arr_or_dtype, np.dtype):
        tipo = arr_or_dtype.type
    elif isinstance(arr_or_dtype, type) and issubclass(arr_or_dtype,
                                                       np.generic):
        # a numpy scalar type class was passed directly; it has no usable
        # ``dtype`` attribute, so use the class itself
        tipo = arr_or_dtype
    else:
        tipo = arr_or_dtype.dtype.type

    # np.timedelta64 derives from np.signedinteger (hence np.number), so the
    # datetime-like types have to be excluded explicitly
    return (issubclass(tipo, (np.number, np.bool_))
            and not issubclass(tipo, (np.datetime64, np.timedelta64)))
2019+
20122020
def is_float_dtype(arr_or_dtype):
20132021
if isinstance(arr_or_dtype, np.dtype):
20142022
tipo = arr_or_dtype.type

pandas/core/groupby.py

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,11 @@
1919
import pandas.core.algorithms as algos
2020
import pandas.core.common as com
2121
from pandas.core.common import(_possibly_downcast_to_dtype, isnull,
22-
notnull, _DATELIKE_DTYPES)
22+
notnull, _DATELIKE_DTYPES, is_numeric_dtype,
23+
is_timedelta64_dtype, is_datetime64_dtype)
2324

2425
import pandas.lib as lib
26+
from pandas.lib import Timestamp
2527
import pandas.algos as _algos
2628
import pandas.hashtable as _hash
2729

@@ -257,6 +259,16 @@ def indices(self):
257259
""" dict {group name -> group indices} """
258260
return self.grouper.indices
259261

262+
def _get_index(self, name):
263+
""" safe get index """
264+
try:
265+
return self.indices[name]
266+
except:
267+
if isinstance(name, Timestamp):
268+
name = name.value
269+
return self.indices[name]
270+
raise
271+
260272
@property
261273
def name(self):
262274
if self._selection is None:
@@ -350,7 +362,7 @@ def get_group(self, name, obj=None):
350362
if obj is None:
351363
obj = self.obj
352364

353-
inds = self.indices[name]
365+
inds = self._get_index(name)
354366
return obj.take(inds, axis=self.axis, convert=False)
355367

356368
def __iter__(self):
@@ -676,7 +688,7 @@ def _try_cast(self, result, obj):
676688
def _cython_agg_general(self, how, numeric_only=True):
677689
output = {}
678690
for name, obj in self._iterate_slices():
679-
is_numeric = _is_numeric_dtype(obj.dtype)
691+
is_numeric = is_numeric_dtype(obj.dtype)
680692
if numeric_only and not is_numeric:
681693
continue
682694

@@ -714,7 +726,7 @@ def _python_agg_general(self, func, *args, **kwargs):
714726

715727
# since we are masking, make sure that we have a float object
716728
values = result
717-
if _is_numeric_dtype(values.dtype):
729+
if is_numeric_dtype(values.dtype):
718730
values = com.ensure_float(values)
719731

720732
output[name] = self._try_cast(values[mask], result)
@@ -1080,7 +1092,7 @@ def aggregate(self, values, how, axis=0):
10801092
raise NotImplementedError
10811093
out_shape = (self.ngroups,) + values.shape[1:]
10821094

1083-
if _is_numeric_dtype(values.dtype):
1095+
if is_numeric_dtype(values.dtype):
10841096
values = com.ensure_float(values)
10851097
is_numeric = True
10861098
else:
@@ -1474,6 +1486,15 @@ def __init__(self, index, grouper=None, name=None, level=None,
14741486
self.grouper = None # Try for sanity
14751487
raise AssertionError(errmsg)
14761488

1489+
# if we have a datetime-like or timedelta-like grouper, convert it with to_datetime / to_timedelta
1490+
if getattr(self.grouper,'dtype',None) is not None:
1491+
if is_datetime64_dtype(self.grouper):
1492+
from pandas import to_datetime
1493+
self.grouper = to_datetime(self.grouper)
1494+
elif is_timedelta64_dtype(self.grouper):
1495+
from pandas import to_timedelta
1496+
self.grouper = to_timedelta(self.grouper)
1497+
14771498
def __repr__(self):
14781499
return 'Grouping(%s)' % self.name
14791500

@@ -1821,7 +1842,7 @@ def transform(self, func, *args, **kwargs):
18211842
# need to do a safe put here, as the dtype may be different
18221843
# this needs to be an ndarray
18231844
result = Series(result)
1824-
result.iloc[self.indices[name]] = res
1845+
result.iloc[self._get_index(name)] = res
18251846
result = result.values
18261847

18271848
# downcast if we can (and need)
@@ -1860,7 +1881,7 @@ def true_and_notnull(x, *args, **kwargs):
18601881
return b and notnull(b)
18611882

18621883
try:
1863-
indices = [self.indices[name] if true_and_notnull(group) else []
1884+
indices = [self._get_index(name) if true_and_notnull(group) else []
18641885
for name, group in self]
18651886
except ValueError:
18661887
raise TypeError("the filter must return a boolean result")
@@ -1921,7 +1942,7 @@ def _cython_agg_blocks(self, how, numeric_only=True):
19211942
for block in data.blocks:
19221943
values = block.values
19231944

1924-
is_numeric = _is_numeric_dtype(values.dtype)
1945+
is_numeric = is_numeric_dtype(values.dtype)
19251946

19261947
if numeric_only and not is_numeric:
19271948
continue
@@ -2412,7 +2433,7 @@ def filter(self, func, dropna=True, *args, **kwargs):
24122433
res = path(group)
24132434

24142435
def add_indices():
2415-
indices.append(self.indices[name])
2436+
indices.append(self._get_index(name))
24162437

24172438
# interpret the result of the filter
24182439
if isinstance(res, (bool, np.bool_)):
@@ -2973,12 +2994,6 @@ def _reorder_by_uniques(uniques, labels):
29732994
}
29742995

29752996

2976-
def _is_numeric_dtype(dt):
2977-
typ = dt.type
2978-
return (issubclass(typ, (np.number, np.bool_))
2979-
and not issubclass(typ, (np.datetime64, np.timedelta64)))
2980-
2981-
29822997
def _intercept_function(func):
29832998
return _func_table.get(func, func)
29842999

pandas/tests/test_groupby.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -627,6 +627,14 @@ def test_transform_broadcast(self):
627627
for idx in gp.index:
628628
assert_fp_equal(res.xs(idx), agged[idx])
629629

630+
def test_transform_bug(self):
    """Regression test for GH 5712: transform grouped on a datetime column."""
    # every row carries the same Timestamp in 'A', so there is one group
    frame = DataFrame(dict(A=Timestamp('20130101'), B=np.arange(5)))

    def descending_rank(values):
        return values.rank(ascending=False)

    ranked = frame.groupby('A')['B'].transform(descending_rank)

    # B is 0..4, so the descending ranks are 5, 4, 3, 2, 1
    wanted = Series(np.arange(5, 0, step=-1), name='B')
    assert_series_equal(ranked, wanted)
637+
630638
def test_transform_multiple(self):
631639
grouped = self.ts.groupby([lambda x: x.year, lambda x: x.month])
632640

0 commit comments

Comments
 (0)