diff --git a/pandas/core/common.py b/pandas/core/common.py
index f8f5928ca7d51..171ce9462452f 100644
--- a/pandas/core/common.py
+++ b/pandas/core/common.py
@@ -1271,14 +1271,23 @@ def _possibly_downcast_to_dtype(result, dtype):
         dtype = np.dtype(dtype)
 
     try:
 
         # don't allow upcasts here (except if empty)
         if dtype.kind == result.dtype.kind:
             if result.dtype.itemsize <= dtype.itemsize and np.prod(result.shape):
                 return result
 
         if issubclass(dtype.type, np.floating):
             return result.astype(dtype)
+
+        # a datetimelike
+        elif ((dtype.kind == 'M' and result.dtype.kind == 'i') or
+              dtype.kind == 'm'):
+            try:
+                result = result.astype(dtype)
+            except:
+                pass
+
         elif dtype == np.bool_ or issubclass(dtype.type, np.integer):
 
             # if we don't have any elements, just astype it
@@ -1309,13 +1318,6 @@ def _possibly_downcast_to_dtype(result, dtype):
                 if (new_result == result).all():
                     return new_result
 
-        # a datetimelike
-        elif dtype.kind in ['M','m'] and result.dtype.kind in ['i']:
-            try:
-                result = result.astype(dtype)
-            except:
-                pass
-
     except:
         pass
 
diff --git a/pandas/core/groupby.py b/pandas/core/groupby.py
index cb5dedc887bca..4b55b8cced559 100644
--- a/pandas/core/groupby.py
+++ b/pandas/core/groupby.py
@@ -1083,12 +1083,24 @@ def _try_cast(self, result, obj):
     def _cython_agg_general(self, how, numeric_only=True):
         output = {}
         for name, obj in self._iterate_slices():
-            is_numeric = is_numeric_dtype(obj.dtype)
+            if is_numeric_dtype(obj.dtype):
+                obj = com.ensure_float(obj)
+                is_numeric = True
+                out_dtype = 'f%d' % obj.dtype.itemsize
+                values = obj.values
+            else:
+                is_numeric = issubclass(obj.dtype.type, (np.datetime64,
+                                                         np.timedelta64))
+                if is_numeric:
+                    values = obj.view('int64')
+                else:
+                    values = obj.astype(object)
+
             if numeric_only and not is_numeric:
                 continue
 
             try:
-                result, names = self.grouper.aggregate(obj.values, how)
+                result, names = self.grouper.aggregate(values, how)
             except AssertionError as e:
                 raise GroupByError(str(e))
             output[name] = self._try_cast(result, obj)
@@ -2567,9 +2579,10 @@ def _cython_agg_blocks(self, how, numeric_only=True):
             data = data.get_numeric_data(copy=False)
 
         for block in data.blocks:
+            # TODO DAN
             values = block._try_operate(block.values)
 
             if block.is_numeric:
                 values = _algos.ensure_float64(values)
diff --git a/pandas/tests/test_groupby.py b/pandas/tests/test_groupby.py
index 4077f468d8b1f..734287baaa50d 100644
--- a/pandas/tests/test_groupby.py
+++ b/pandas/tests/test_groupby.py
@@ -41,6 +41,7 @@ def _skip_if_mpl_not_installed():
     except ImportError:
         raise nose.SkipTest("matplotlib not installed")
 
+
 def commonSetUp(self):
     self.dateRange = bdate_range('1/1/2005', periods=250)
     self.stringIndex = Index([rands(8).upper() for x in range(250)])
@@ -603,6 +604,26 @@ def f(grp):
         e.name = None
         assert_series_equal(result,e)
 
+        # ...and with timedeltas
+        df1 = df.copy()
+        df1['D'] = pd.to_timedelta(['00:00:01', '00:00:02', '00:00:03',
+                                    '00:00:04', '00:00:05', '00:00:06',
+                                    '00:00:07'])
+        result = df1.groupby('A').apply(f)[['D']]
+        e = df1.groupby('A').first()[['D']]
+        e.loc['Pony'] = np.nan
+        assert_frame_equal(result, e)
+
+        def f(grp):
+            if grp.name == 'Pony':
+                return None
+            return grp.iloc[0].loc['D']
+        result = df1.groupby('A').apply(f)['D']
+        e = df1.groupby('A').first()['D'].copy()
+        e.loc['Pony'] = np.nan
+        e.name = None
+        assert_series_equal(result, e)
+
     def test_agg_api(self):
 
         # GH 6337
@@ -4365,6 +4386,19 @@ def test_index_label_overlaps_location(self):
         expected = ser.take([1, 3, 4])
         assert_series_equal(actual, expected)
 
+    def test_groupby_methods_on_timedelta64(self):
+        df = self.df.copy().iloc[:4]
+        df['E'] = pd.to_timedelta(['00:00:01', '00:00:02', '00:00:03', '00:00:04'])
+        # DataFrameGroupBy
+        actual = df.groupby('A').mean()['E']
+        expected = pd.to_timedelta(Series(['00:00:03', '00:00:02'], index=['bar', 'foo'], name='E'))
+        assert_series_equal(actual, expected)
+
+        ser = df['E']
+        # SeriesGroupBy
+        actual = ser.groupby(df['A']).mean()
+        assert_series_equal(actual, expected)
+
     def test_groupby_selection_with_methods(self):
         # some methods which require DatetimeIndex
         rng = pd.date_range('2014', periods=len(self.df))
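
Note: the snippet below is a minimal, self-contained sketch of the behaviour the new test exercises, namely that a Cython-backed groupby aggregation such as mean works on a timedelta64[ns] column (internally the values are viewed as int64 for the aggregation and cast back to timedelta64 on the way out). The frame and the 'A'/'E' column names are illustrative only, and the exact repr of the result may differ between pandas versions.

import pandas as pd

df = pd.DataFrame({
    'A': ['foo', 'bar', 'foo', 'bar'],
    'E': pd.to_timedelta(['00:00:01', '00:00:02', '00:00:03', '00:00:04']),
})

# DataFrameGroupBy: mean of a timedelta64[ns] column per group
print(df.groupby('A').mean()['E'])

# SeriesGroupBy: the same aggregation on the Series directly
print(df['E'].groupby(df['A']).mean())

# Both give bar -> 3 seconds, foo -> 2 seconds, dtype timedelta64[ns].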