From f4f7117f8f182f22a32e3595f06d1cc498fea1ff Mon Sep 17 00:00:00 2001 From: Brock Date: Wed, 21 Apr 2021 17:57:14 -0700 Subject: [PATCH] REF: simplify ohlc --- pandas/core/groupby/generic.py | 13 +------ pandas/core/groupby/groupby.py | 38 ++++++++++++-------- pandas/tests/resample/test_datetime_index.py | 8 +++-- 3 files changed, 29 insertions(+), 30 deletions(-) diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index 4a721ae0d4bf6..1c1bde1f43c3b 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -899,10 +899,6 @@ def count(self) -> Series: ) return self._reindex_output(result, fill_value=0) - def _apply_to_column_groupbys(self, func): - """ return a pass thru """ - return func(self) - def pct_change(self, periods=1, fill_method="pad", limit=None, freq=None): """Calculate pct_change of each value to previous entry in group""" # TODO: Remove this conditional when #23918 is fixed @@ -1094,6 +1090,7 @@ def _cython_agg_general( def _cython_agg_manager( self, how: str, alt=None, numeric_only: bool = True, min_count: int = -1 ) -> Manager2D: + # Note: we never get here with how="ohlc"; that goes through SeriesGroupBy data: Manager2D = self._get_data_to_aggregate() @@ -1184,13 +1181,6 @@ def array_func(values: ArrayLike) -> ArrayLike: # generally if we have numeric_only=False # and non-applicable functions # try to python agg - - if alt is None: - # we cannot perform the operation - # in an alternate way, exclude the block - assert how == "ohlc" - raise - result = py_fallback(values) return cast_agg_result(result, values, how) @@ -1198,7 +1188,6 @@ def array_func(values: ArrayLike) -> ArrayLike: # TypeError -> we may have an exception in trying to aggregate # continue and exclude the block - # NotImplementedError -> "ohlc" with wrong dtype new_mgr = data.grouped_reduce(array_func, ignore_failures=True) if not len(new_mgr): diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index f2fffe4c3741c..9d805f37c74d8 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -1109,20 +1109,10 @@ def _cython_agg_general( result = self.grouper._cython_operation( "aggregate", obj._values, how, axis=0, min_count=min_count ) - - if how == "ohlc": - # e.g. ohlc - agg_names = ["open", "high", "low", "close"] - assert len(agg_names) == result.shape[1] - for result_column, result_name in zip(result.T, agg_names): - key = base.OutputKey(label=result_name, position=idx) - output[key] = result_column - idx += 1 - else: - assert result.ndim == 1 - key = base.OutputKey(label=name, position=idx) - output[key] = result - idx += 1 + assert result.ndim == 1 + key = base.OutputKey(label=name, position=idx) + output[key] = result + idx += 1 if not output: raise DataError("No numeric types to aggregate") @@ -1804,7 +1794,25 @@ def ohlc(self) -> DataFrame: DataFrame Open, high, low and close values within each group. """ - return self._apply_to_column_groupbys(lambda x: x._cython_agg_general("ohlc")) + if self.obj.ndim == 1: + # self._iterate_slices() yields only self._selected_obj + obj = self._selected_obj + + is_numeric = is_numeric_dtype(obj.dtype) + if not is_numeric: + raise DataError("No numeric types to aggregate") + + res_values = self.grouper._cython_operation( + "aggregate", obj._values, "ohlc", axis=0, min_count=-1 + ) + + agg_names = ["open", "high", "low", "close"] + result = self.obj._constructor_expanddim( + res_values, index=self.grouper.result_index, columns=agg_names + ) + return self._reindex_output(result) + + return self._apply_to_column_groupbys(lambda x: x.ohlc()) @final @doc(DataFrame.describe) diff --git a/pandas/tests/resample/test_datetime_index.py b/pandas/tests/resample/test_datetime_index.py index 71e6aa38d60e5..bbe9ac6fa8094 100644 --- a/pandas/tests/resample/test_datetime_index.py +++ b/pandas/tests/resample/test_datetime_index.py @@ -58,14 +58,16 @@ def test_custom_grouper(index): g = s.groupby(b) # check all cython functions work - funcs = ["add", "mean", "prod", "ohlc", "min", "max", "var"] + g.ohlc() # doesn't use _cython_agg_general + funcs = ["add", "mean", "prod", "min", "max", "var"] for f in funcs: g._cython_agg_general(f) b = Grouper(freq=Minute(5), closed="right", label="right") g = s.groupby(b) # check all cython functions work - funcs = ["add", "mean", "prod", "ohlc", "min", "max", "var"] + g.ohlc() # doesn't use _cython_agg_general + funcs = ["add", "mean", "prod", "min", "max", "var"] for f in funcs: g._cython_agg_general(f) @@ -79,7 +81,7 @@ def test_custom_grouper(index): idx = DatetimeIndex(idx, freq="5T") expect = Series(arr, index=idx) - # GH2763 - return in put dtype if we can + # GH2763 - return input dtype if we can result = g.agg(np.sum) tm.assert_series_equal(result, expect)