Skip to content

BUG: Bug in groupby transform with a datetime-like grouper (GH5712) #5713

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 16, 2013
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/source/release.rst
Original file line number Diff line number Diff line change
Expand Up @@ -823,6 +823,7 @@ Bug Fixes
- Work around regression in numpy 1.7.0 which erroneously raises IndexError from ``ndarray.item`` (:issue:`5666`)
- Bug in repeated indexing of object with resultant non-unique index (:issue:`5678`)
- Bug in fillna with Series and a passed series/dict (:issue:`5703`)
- Bug in groupby transform with a datetime-like grouper (:issue:`5712`)

pandas 0.12.0
-------------
Expand Down
8 changes: 8 additions & 0 deletions pandas/core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2009,6 +2009,14 @@ def needs_i8_conversion(arr_or_dtype):
is_timedelta64_dtype(arr_or_dtype))


def is_numeric_dtype(arr_or_dtype):
    """
    Check whether an array, dtype or scalar type is numeric.

    Parameters
    ----------
    arr_or_dtype : np.dtype, type, or object exposing a ``dtype`` attribute
        A dtype instance is used directly, a scalar type (``np.float64``,
        ``float``, ...) is resolved through ``np.dtype``, and anything else
        is inspected via its ``dtype`` attribute.

    Returns
    -------
    bool
        True when the scalar type derives from ``np.number`` or ``np.bool_``
        and is neither ``np.datetime64`` nor ``np.timedelta64``.
    """
    if isinstance(arr_or_dtype, np.dtype):
        tipo = arr_or_dtype.type
    elif isinstance(arr_or_dtype, type):
        # a bare scalar type has no usable ``.dtype`` attribute; resolve it
        # through np.dtype instead of failing with AttributeError
        tipo = np.dtype(arr_or_dtype).type
    else:
        tipo = arr_or_dtype.dtype.type
    # date-like scalar types are excluded explicitly so that they are never
    # reported as numeric
    return (issubclass(tipo, (np.number, np.bool_))
            and not issubclass(tipo, (np.datetime64, np.timedelta64)))

def is_float_dtype(arr_or_dtype):
if isinstance(arr_or_dtype, np.dtype):
tipo = arr_or_dtype.type
Expand Down
45 changes: 30 additions & 15 deletions pandas/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@
import pandas.core.algorithms as algos
import pandas.core.common as com
from pandas.core.common import(_possibly_downcast_to_dtype, isnull,
notnull, _DATELIKE_DTYPES)
notnull, _DATELIKE_DTYPES, is_numeric_dtype,
is_timedelta64_dtype, is_datetime64_dtype)

import pandas.lib as lib
from pandas.lib import Timestamp
import pandas.algos as _algos
import pandas.hashtable as _hash

Expand Down Expand Up @@ -257,6 +259,16 @@ def indices(self):
""" dict {group name -> group indices} """
return self.grouper.indices

def _get_index(self, name):
    """
    Return the indices stored for the group labelled ``name``.

    The label is looked up in ``self.indices``.  If that lookup misses and
    ``name`` is a ``Timestamp``, the lookup is retried with ``name.value``
    (presumably because date-like groupers may be keyed by their underlying
    integer value -- confirm against the grouper implementation).

    Raises
    ------
    KeyError
        If neither the label nor, for a ``Timestamp``, its ``value`` is a
        key of ``self.indices``.
    """
    try:
        return self.indices[name]
    except KeyError:
        # only a missing key warrants the fallback; anything else
        # (KeyboardInterrupt, unrelated errors) must propagate untouched
        if isinstance(name, Timestamp):
            return self.indices[name.value]
        raise

@property
def name(self):
if self._selection is None:
Expand Down Expand Up @@ -350,7 +362,7 @@ def get_group(self, name, obj=None):
if obj is None:
obj = self.obj

inds = self.indices[name]
inds = self._get_index(name)
return obj.take(inds, axis=self.axis, convert=False)

def __iter__(self):
Expand Down Expand Up @@ -676,7 +688,7 @@ def _try_cast(self, result, obj):
def _cython_agg_general(self, how, numeric_only=True):
output = {}
for name, obj in self._iterate_slices():
is_numeric = _is_numeric_dtype(obj.dtype)
is_numeric = is_numeric_dtype(obj.dtype)
if numeric_only and not is_numeric:
continue

Expand Down Expand Up @@ -714,7 +726,7 @@ def _python_agg_general(self, func, *args, **kwargs):

# since we are masking, make sure that we have a float object
values = result
if _is_numeric_dtype(values.dtype):
if is_numeric_dtype(values.dtype):
values = com.ensure_float(values)

output[name] = self._try_cast(values[mask], result)
Expand Down Expand Up @@ -1080,7 +1092,7 @@ def aggregate(self, values, how, axis=0):
raise NotImplementedError
out_shape = (self.ngroups,) + values.shape[1:]

if _is_numeric_dtype(values.dtype):
if is_numeric_dtype(values.dtype):
values = com.ensure_float(values)
is_numeric = True
else:
Expand Down Expand Up @@ -1474,6 +1486,15 @@ def __init__(self, index, grouper=None, name=None, level=None,
self.grouper = None # Try for sanity
raise AssertionError(errmsg)

# if we have a date/time-like grouper, make sure that its values are Timestamp-like
if getattr(self.grouper,'dtype',None) is not None:
if is_datetime64_dtype(self.grouper):
from pandas import to_datetime
self.grouper = to_datetime(self.grouper)
elif is_timedelta64_dtype(self.grouper):
from pandas import to_timedelta
self.grouper = to_timedelta(self.grouper)

def __repr__(self):
return 'Grouping(%s)' % self.name

Expand Down Expand Up @@ -1821,7 +1842,7 @@ def transform(self, func, *args, **kwargs):
# need to do a safe put here, as the dtype may be different
# this needs to be an ndarray
result = Series(result)
result.iloc[self.indices[name]] = res
result.iloc[self._get_index(name)] = res
result = result.values

# downcast if we can (and need)
Expand Down Expand Up @@ -1860,7 +1881,7 @@ def true_and_notnull(x, *args, **kwargs):
return b and notnull(b)

try:
indices = [self.indices[name] if true_and_notnull(group) else []
indices = [self._get_index(name) if true_and_notnull(group) else []
for name, group in self]
except ValueError:
raise TypeError("the filter must return a boolean result")
Expand Down Expand Up @@ -1921,7 +1942,7 @@ def _cython_agg_blocks(self, how, numeric_only=True):
for block in data.blocks:
values = block.values

is_numeric = _is_numeric_dtype(values.dtype)
is_numeric = is_numeric_dtype(values.dtype)

if numeric_only and not is_numeric:
continue
Expand Down Expand Up @@ -2412,7 +2433,7 @@ def filter(self, func, dropna=True, *args, **kwargs):
res = path(group)

def add_indices():
indices.append(self.indices[name])
indices.append(self._get_index(name))

# interpret the result of the filter
if isinstance(res, (bool, np.bool_)):
Expand Down Expand Up @@ -2973,12 +2994,6 @@ def _reorder_by_uniques(uniques, labels):
}


def _is_numeric_dtype(dt):
    """
    Tell whether the dtype ``dt`` is numeric.

    ``dt.type`` must derive from ``np.number`` or ``np.bool_``; the
    ``np.datetime64`` and ``np.timedelta64`` scalar types never count.
    """
    scalar_type = dt.type
    # rule out date-like scalar types before the numeric check
    if issubclass(scalar_type, (np.datetime64, np.timedelta64)):
        return False
    return issubclass(scalar_type, (np.number, np.bool_))


def _intercept_function(func):
return _func_table.get(func, func)

Expand Down
8 changes: 8 additions & 0 deletions pandas/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,14 @@ def test_transform_broadcast(self):
for idx in gp.index:
assert_fp_equal(res.xs(idx), agged[idx])

def test_transform_bug(self):
# GH 5712
# Regression test: groupby/transform where the grouping column holds
# datetimes. Column A is built from a single Timestamp and B is 0..4.
df = DataFrame(dict(A = Timestamp('20130101'), B = np.arange(5)))
# rank B in descending order within each group of A
result = df.groupby('A')['B'].transform(lambda x: x.rank(ascending=False))
# expected values are 5, 4, 3, 2, 1 under the name 'B'
expected = Series(np.arange(5,0,step=-1),name='B')
assert_series_equal(result,expected)

def test_transform_multiple(self):
grouped = self.ts.groupby([lambda x: x.year, lambda x: x.month])

Expand Down