Skip to content

Commit f6fc567

Browse files
authored
REF: consolidate casting in groupby agg_series (#41273)
1 parent d8996ad commit f6fc567

File tree

4 files changed

+21
-19
lines changed

4 files changed

+21
-19
lines changed

pandas/_libs/reduction.pyx

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,7 @@ from pandas._libs.util cimport (
2121
set_array_not_contiguous,
2222
)
2323

24-
from pandas._libs.lib import (
25-
is_scalar,
26-
maybe_convert_objects,
27-
)
24+
from pandas._libs.lib import is_scalar
2825

2926

3027
cpdef check_result_array(object obj):
@@ -187,7 +184,6 @@ cdef class SeriesBinGrouper(_BaseGrouper):
187184
islider.reset()
188185
vslider.reset()
189186

190-
result = maybe_convert_objects(result)
191187
return result, counts
192188

193189

@@ -292,8 +288,6 @@ cdef class SeriesGrouper(_BaseGrouper):
292288
# have result initialized by this point.
293289
assert initialized, "`result` has not been initialized."
294290

295-
result = maybe_convert_objects(result)
296-
297291
return result, counts
298292

299293

pandas/core/groupby/ops.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -970,22 +970,33 @@ def agg_series(self, obj: Series, func: F) -> tuple[ArrayLike, np.ndarray]:
970970
# Caller is responsible for checking ngroups != 0
971971
assert self.ngroups != 0
972972

973+
cast_back = True
973974
if len(obj) == 0:
974975
# SeriesGrouper would raise if we were to call _aggregate_series_fast
975-
return self._aggregate_series_pure_python(obj, func)
976+
result, counts = self._aggregate_series_pure_python(obj, func)
976977

977978
elif is_extension_array_dtype(obj.dtype):
978979
# _aggregate_series_fast would raise TypeError when
979980
# calling libreduction.Slider
980981
# In the datetime64tz case it would incorrectly cast to tz-naive
981982
# TODO: can we get a performant workaround for EAs backed by ndarray?
982-
return self._aggregate_series_pure_python(obj, func)
983+
result, counts = self._aggregate_series_pure_python(obj, func)
983984

984985
elif obj.index._has_complex_internals:
985986
# Preempt TypeError in _aggregate_series_fast
986-
return self._aggregate_series_pure_python(obj, func)
987+
result, counts = self._aggregate_series_pure_python(obj, func)
987988

988-
return self._aggregate_series_fast(obj, func)
989+
else:
990+
result, counts = self._aggregate_series_fast(obj, func)
991+
cast_back = False
992+
993+
npvalues = lib.maybe_convert_objects(result, try_float=False)
994+
if cast_back:
995+
# TODO: Is there a documented reason why we dont always cast_back?
996+
out = maybe_cast_pointwise_result(npvalues, obj.dtype, numeric_only=True)
997+
else:
998+
out = npvalues
999+
return out, counts
9891000

9901001
def _aggregate_series_fast(
9911002
self, obj: Series, func: F
@@ -1033,10 +1044,7 @@ def _aggregate_series_pure_python(self, obj: Series, func: F):
10331044
counts[i] = group.shape[0]
10341045
result[i] = res
10351046

1036-
npvalues = lib.maybe_convert_objects(result, try_float=False)
1037-
out = maybe_cast_pointwise_result(npvalues, obj.dtype, numeric_only=True)
1038-
1039-
return out, counts
1047+
return result, counts
10401048

10411049

10421050
class BinGrouper(BaseGrouper):

pandas/core/series.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3086,7 +3086,7 @@ def combine(self, other, func, fill_value=None) -> Series:
30863086
new_values[:] = [func(lv, other) for lv in self._values]
30873087
new_name = self.name
30883088

3089-
# try_float=False is to match _aggregate_series_pure_python
3089+
# try_float=False is to match agg_series
30903090
npvalues = lib.maybe_convert_objects(new_values, try_float=False)
30913091
res_values = maybe_cast_pointwise_result(npvalues, self.dtype, same_dtype=False)
30923092
return self._constructor(res_values, index=new_index, name=new_name)

pandas/tests/groupby/test_bin_groupby.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def test_series_grouper():
2020
grouper = libreduction.SeriesGrouper(obj, np.mean, labels, 2)
2121
result, counts = grouper.get_result()
2222

23-
expected = np.array([obj[3:6].mean(), obj[6:].mean()])
23+
expected = np.array([obj[3:6].mean(), obj[6:].mean()], dtype=object)
2424
tm.assert_almost_equal(result, expected)
2525

2626
exp_counts = np.array([3, 4], dtype=np.int64)
@@ -36,7 +36,7 @@ def test_series_grouper_result_length_difference():
3636
grouper = libreduction.SeriesGrouper(obj, lambda x: all(x > 0), labels, 2)
3737
result, counts = grouper.get_result()
3838

39-
expected = np.array([all(obj[3:6] > 0), all(obj[6:] > 0)])
39+
expected = np.array([all(obj[3:6] > 0), all(obj[6:] > 0)], dtype=object)
4040
tm.assert_equal(result, expected)
4141

4242
exp_counts = np.array([3, 4], dtype=np.int64)
@@ -61,7 +61,7 @@ def test_series_bin_grouper():
6161
grouper = libreduction.SeriesBinGrouper(obj, np.mean, bins)
6262
result, counts = grouper.get_result()
6363

64-
expected = np.array([obj[:3].mean(), obj[3:6].mean(), obj[6:].mean()])
64+
expected = np.array([obj[:3].mean(), obj[3:6].mean(), obj[6:].mean()], dtype=object)
6565
tm.assert_almost_equal(result, expected)
6666

6767
exp_counts = np.array([3, 3, 4], dtype=np.int64)

0 commit comments

Comments
 (0)