From 77eb2d2ea3d701d1d619c3fab6ebf94eded6a203 Mon Sep 17 00:00:00 2001 From: Brock Date: Wed, 21 Apr 2021 12:38:05 -0700 Subject: [PATCH] REF: hide ArrayManager implementation details from GroupBy --- pandas/core/groupby/generic.py | 15 +-------------- pandas/core/internals/array_manager.py | 8 ++++++++ 2 files changed, 9 insertions(+), 14 deletions(-) diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index 4a721ae0d4bf6..1de5dff321400 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -97,7 +97,6 @@ all_indexes_same, ) import pandas.core.indexes.base as ibase -from pandas.core.internals import ArrayManager from pandas.core.series import Series from pandas.core.util.numba_ import maybe_use_numba @@ -1100,8 +1099,6 @@ def _cython_agg_manager( if numeric_only: data = data.get_numeric_data(copy=False) - using_array_manager = isinstance(data, ArrayManager) - def cast_agg_result( result: ArrayLike, values: ArrayLike, how: str ) -> ArrayLike: @@ -1114,11 +1111,7 @@ def cast_agg_result( result = type(values)._from_sequence(result.ravel(), dtype=values.dtype) # Note this will have result.dtype == dtype from above - elif ( - not using_array_manager - and isinstance(result, np.ndarray) - and result.ndim == 1 - ): + elif isinstance(result, np.ndarray) and result.ndim == 1: # We went through a SeriesGroupByPath and need to reshape # GH#32223 includes case with IntegerArray values # We only get here with values.dtype == object @@ -1822,8 +1815,6 @@ def count(self) -> DataFrame: ids, _, ngroups = self.grouper.group_info mask = ids != -1 - using_array_manager = isinstance(data, ArrayManager) - def hfunc(bvalues: ArrayLike) -> ArrayLike: # TODO(2DEA): reshape would not be necessary with 2D EAs if bvalues.ndim == 1: @@ -1833,10 +1824,6 @@ def hfunc(bvalues: ArrayLike) -> ArrayLike: masked = mask & ~isna(bvalues) counted = lib.count_level_2d(masked, labels=ids, max_bin=ngroups, axis=1) - if using_array_manager: - # count_level_2d return (1, N) array for single column - # -> extract 1D array - counted = counted[0, :] return counted new_mgr = data.grouped_reduce(hfunc) diff --git a/pandas/core/internals/array_manager.py b/pandas/core/internals/array_manager.py index 8c9902d330eee..f0ee1b56ec446 100644 --- a/pandas/core/internals/array_manager.py +++ b/pandas/core/internals/array_manager.py @@ -219,12 +219,20 @@ def grouped_reduce(self: T, func: Callable, ignore_failures: bool = False) -> T: result_indices: list[int] = [] for i, arr in enumerate(self.arrays): + # grouped_reduce functions all expect 2D arrays + arr = ensure_block_shape(arr, ndim=2) try: res = func(arr) except (TypeError, NotImplementedError): if not ignore_failures: raise continue + + if res.ndim == 2: + # reverse of ensure_block_shape + assert res.shape[0] == 1 + res = res[0] + result_arrays.append(res) result_indices.append(i)