Skip to content

REF: consolidate casting in groupby agg_series #41273

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 1 addition & 7 deletions pandas/_libs/reduction.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,7 @@ from pandas._libs.util cimport (
set_array_not_contiguous,
)

from pandas._libs.lib import (
is_scalar,
maybe_convert_objects,
)
from pandas._libs.lib import is_scalar


cpdef check_result_array(object obj):
Expand Down Expand Up @@ -185,7 +182,6 @@ cdef class SeriesBinGrouper(_BaseGrouper):
islider.reset()
vslider.reset()

result = maybe_convert_objects(result)
return result, counts


Expand Down Expand Up @@ -288,8 +284,6 @@ cdef class SeriesGrouper(_BaseGrouper):
# have result initialized by this point.
assert initialized, "`result` has not been initialized."

result = maybe_convert_objects(result)

return result, counts


Expand Down
24 changes: 16 additions & 8 deletions pandas/core/groupby/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -970,22 +970,33 @@ def agg_series(self, obj: Series, func: F) -> tuple[ArrayLike, np.ndarray]:
# Caller is responsible for checking ngroups != 0
assert self.ngroups != 0

cast_back = True
if len(obj) == 0:
# SeriesGrouper would raise if we were to call _aggregate_series_fast
return self._aggregate_series_pure_python(obj, func)
result, counts = self._aggregate_series_pure_python(obj, func)

elif is_extension_array_dtype(obj.dtype):
# _aggregate_series_fast would raise TypeError when
# calling libreduction.Slider
# In the datetime64tz case it would incorrectly cast to tz-naive
# TODO: can we get a performant workaround for EAs backed by ndarray?
return self._aggregate_series_pure_python(obj, func)
result, counts = self._aggregate_series_pure_python(obj, func)

elif obj.index._has_complex_internals:
# Preempt TypeError in _aggregate_series_fast
return self._aggregate_series_pure_python(obj, func)
result, counts = self._aggregate_series_pure_python(obj, func)

return self._aggregate_series_fast(obj, func)
else:
result, counts = self._aggregate_series_fast(obj, func)
cast_back = False

npvalues = lib.maybe_convert_objects(result, try_float=False)
if cast_back:
# TODO: Is there a documented reason why we don't always cast_back?
out = maybe_cast_pointwise_result(npvalues, obj.dtype, numeric_only=True)
else:
out = npvalues
return out, counts

def _aggregate_series_fast(
self, obj: Series, func: F
Expand Down Expand Up @@ -1033,10 +1044,7 @@ def _aggregate_series_pure_python(self, obj: Series, func: F):
counts[i] = group.shape[0]
result[i] = res

npvalues = lib.maybe_convert_objects(result, try_float=False)
out = maybe_cast_pointwise_result(npvalues, obj.dtype, numeric_only=True)

return out, counts
return result, counts


class BinGrouper(BaseGrouper):
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -3086,7 +3086,7 @@ def combine(self, other, func, fill_value=None) -> Series:
new_values[:] = [func(lv, other) for lv in self._values]
new_name = self.name

# try_float=False is to match _aggregate_series_pure_python
# try_float=False is to match agg_series
npvalues = lib.maybe_convert_objects(new_values, try_float=False)
res_values = maybe_cast_pointwise_result(npvalues, self.dtype, same_dtype=False)
return self._constructor(res_values, index=new_index, name=new_name)
Expand Down
6 changes: 3 additions & 3 deletions pandas/tests/groupby/test_bin_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def test_series_grouper():
grouper = libreduction.SeriesGrouper(obj, np.mean, labels, 2)
result, counts = grouper.get_result()

expected = np.array([obj[3:6].mean(), obj[6:].mean()])
expected = np.array([obj[3:6].mean(), obj[6:].mean()], dtype=object)
tm.assert_almost_equal(result, expected)

exp_counts = np.array([3, 4], dtype=np.int64)
Expand All @@ -36,7 +36,7 @@ def test_series_grouper_result_length_difference():
grouper = libreduction.SeriesGrouper(obj, lambda x: all(x > 0), labels, 2)
result, counts = grouper.get_result()

expected = np.array([all(obj[3:6] > 0), all(obj[6:] > 0)])
expected = np.array([all(obj[3:6] > 0), all(obj[6:] > 0)], dtype=object)
tm.assert_equal(result, expected)

exp_counts = np.array([3, 4], dtype=np.int64)
Expand All @@ -61,7 +61,7 @@ def test_series_bin_grouper():
grouper = libreduction.SeriesBinGrouper(obj, np.mean, bins)
result, counts = grouper.get_result()

expected = np.array([obj[:3].mean(), obj[3:6].mean(), obj[6:].mean()])
expected = np.array([obj[:3].mean(), obj[3:6].mean(), obj[6:].mean()], dtype=object)
tm.assert_almost_equal(result, expected)

exp_counts = np.array([3, 3, 4], dtype=np.int64)
Expand Down