Skip to content

Commit 18f43e8

Browse files
feat: add groupby cumcount (#1798)
1 parent 3edc313 commit 18f43e8

File tree

5 files changed

+106
-24
lines changed

5 files changed

+106
-24
lines changed

bigframes/core/array_value.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -403,8 +403,23 @@ def project_window_op(
403403
never_skip_nulls: will disable null skipping for operators that would otherwise do so
404404
skip_reproject_unsafe: skips the reprojection step, can be used when performing many non-dependent window operations, user responsible for not nesting window expressions, or using outputs as join, filter or aggregation keys before a reprojection
405405
"""
406+
407+
return self.project_window_expr(
408+
ex.UnaryAggregation(op, ex.deref(column_name)),
409+
window_spec,
410+
never_skip_nulls,
411+
skip_reproject_unsafe,
412+
)
413+
414+
def project_window_expr(
415+
self,
416+
expression: ex.Aggregation,
417+
window: WindowSpec,
418+
never_skip_nulls=False,
419+
skip_reproject_unsafe: bool = False,
420+
):
406421
# TODO: Support non-deterministic windowing
407-
if window_spec.is_row_bounded or not op.order_independent:
422+
if window.is_row_bounded or not expression.op.order_independent:
408423
if self.node.order_ambiguous and not self.session._strictly_ordered:
409424
if not self.session._allows_ambiguity:
410425
raise ValueError(
@@ -415,14 +430,13 @@ def project_window_op(
415430
"Window ordering may be ambiguous, this can cause unstable results."
416431
)
417432
warnings.warn(msg, category=bfe.AmbiguousWindowWarning)
418-
419433
output_name = self._gen_namespaced_uid()
420434
return (
421435
ArrayValue(
422436
nodes.WindowOpNode(
423437
child=self.node,
424-
expression=ex.UnaryAggregation(op, ex.deref(column_name)),
425-
window_spec=window_spec,
438+
expression=expression,
439+
window_spec=window,
426440
output_name=ids.ColumnId(output_name),
427441
never_skip_nulls=never_skip_nulls,
428442
skip_reproject_unsafe=skip_reproject_unsafe,

bigframes/core/blocks.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1012,16 +1012,34 @@ def apply_window_op(
10121012
skip_null_groups: bool = False,
10131013
skip_reproject_unsafe: bool = False,
10141014
never_skip_nulls: bool = False,
1015+
) -> typing.Tuple[Block, str]:
1016+
agg_expr = ex.UnaryAggregation(op, ex.deref(column))
1017+
return self.apply_analytic(
1018+
agg_expr,
1019+
window_spec,
1020+
result_label,
1021+
skip_reproject_unsafe=skip_reproject_unsafe,
1022+
never_skip_nulls=never_skip_nulls,
1023+
skip_null_groups=skip_null_groups,
1024+
)
1025+
1026+
def apply_analytic(
1027+
self,
1028+
agg_expr: ex.Aggregation,
1029+
window: windows.WindowSpec,
1030+
result_label: Label,
1031+
*,
1032+
skip_reproject_unsafe: bool = False,
1033+
never_skip_nulls: bool = False,
1034+
skip_null_groups: bool = False,
10151035
) -> typing.Tuple[Block, str]:
10161036
block = self
10171037
if skip_null_groups:
1018-
for key in window_spec.grouping_keys:
1019-
block, not_null_id = block.apply_unary_op(key.id.name, ops.notnull_op)
1020-
block = block.filter_by_id(not_null_id).drop_columns([not_null_id])
1021-
expr, result_id = block._expr.project_window_op(
1022-
column,
1023-
op,
1024-
window_spec,
1038+
for key in window.grouping_keys:
1039+
block = block.filter(ops.notnull_op.as_expr(key.id.name))
1040+
expr, result_id = block._expr.project_window_expr(
1041+
agg_expr,
1042+
window,
10251043
skip_reproject_unsafe=skip_reproject_unsafe,
10261044
never_skip_nulls=never_skip_nulls,
10271045
)

bigframes/core/groupby/dataframe_group_by.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,27 @@ def count(self) -> df.DataFrame:
275275
def nunique(self) -> df.DataFrame:
276276
return self._aggregate_all(agg_ops.nunique_op)
277277

278+
@validations.requires_ordering()
279+
def cumcount(self, ascending: bool = True) -> series.Series:
280+
window_spec = (
281+
window_specs.cumulative_rows(grouping_keys=tuple(self._by_col_ids))
282+
if ascending
283+
else window_specs.inverse_cumulative_rows(
284+
grouping_keys=tuple(self._by_col_ids)
285+
)
286+
)
287+
block, result_id = self._block.apply_analytic(
288+
ex.NullaryAggregation(agg_ops.size_op),
289+
window=window_spec,
290+
result_label=None,
291+
)
292+
result = series.Series(block.select_column(result_id)) - 1
293+
if self._dropna and (len(self._by_col_ids) == 1):
294+
result = result.mask(
295+
series.Series(block.select_column(self._by_col_ids[0])).isna()
296+
)
297+
return result
298+
278299
@validations.requires_ordering()
279300
def cumsum(self, *args, numeric_only: bool = False, **kwargs) -> df.DataFrame:
280301
if not numeric_only:
@@ -546,10 +567,12 @@ def _apply_window_op(
546567
)
547568
columns, _ = self._aggregated_columns(numeric_only=numeric_only)
548569
block, result_ids = self._block.multi_apply_window_op(
549-
columns, op, window_spec=window_spec
570+
columns,
571+
op,
572+
window_spec=window_spec,
550573
)
551-
block = block.select_columns(result_ids)
552-
return df.DataFrame(block)
574+
result = df.DataFrame(block.select_columns(result_ids))
575+
return result
553576

554577
def _resolve_label(self, label: blocks.Label) -> str:
555578
"""Resolve label to column id."""

tests/system/small/test_groupby.py

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -383,14 +383,14 @@ def test_dataframe_groupby_multi_sum(
383383

384384

385385
@pytest.mark.parametrize(
386-
("operator"),
386+
("operator", "dropna"),
387387
[
388-
(lambda x: x.cumsum(numeric_only=True)),
389-
(lambda x: x.cummax(numeric_only=True)),
390-
(lambda x: x.cummin(numeric_only=True)),
388+
(lambda x: x.cumsum(numeric_only=True), True),
389+
(lambda x: x.cummax(numeric_only=True), True),
390+
(lambda x: x.cummin(numeric_only=True), False),
391391
# Pre-pandas 2.2 doesn't always produce float.
392-
(lambda x: x.cumprod().astype("Float64")),
393-
(lambda x: x.shift(periods=2)),
392+
(lambda x: x.cumprod().astype("Float64"), False),
393+
(lambda x: x.shift(periods=2), True),
394394
],
395395
ids=[
396396
"cumsum",
@@ -401,16 +401,44 @@ def test_dataframe_groupby_multi_sum(
401401
],
402402
)
403403
def test_dataframe_groupby_analytic(
404-
scalars_df_index, scalars_pandas_df_index, operator
404+
scalars_df_index,
405+
scalars_pandas_df_index,
406+
operator,
407+
dropna,
405408
):
406409
col_names = ["float64_col", "int64_col", "bool_col", "string_col"]
407-
bf_result = operator(scalars_df_index[col_names].groupby("string_col"))
408-
pd_result = operator(scalars_pandas_df_index[col_names].groupby("string_col"))
410+
bf_result = operator(
411+
scalars_df_index[col_names].groupby("string_col", dropna=dropna)
412+
)
413+
pd_result = operator(
414+
scalars_pandas_df_index[col_names].groupby("string_col", dropna=dropna)
415+
)
409416
bf_result_computed = bf_result.to_pandas()
410417

411418
pd.testing.assert_frame_equal(pd_result, bf_result_computed, check_dtype=False)
412419

413420

421+
@pytest.mark.parametrize(
422+
("ascending", "dropna"),
423+
[
424+
(True, True),
425+
(False, False),
426+
],
427+
)
428+
def test_dataframe_groupby_cumcount(
429+
scalars_df_index, scalars_pandas_df_index, ascending, dropna
430+
):
431+
bf_result = scalars_df_index.groupby("string_col", dropna=dropna).cumcount(
432+
ascending
433+
)
434+
pd_result = scalars_pandas_df_index.groupby("string_col", dropna=dropna).cumcount(
435+
ascending
436+
)
437+
bf_result_computed = bf_result.to_pandas()
438+
439+
pd.testing.assert_series_equal(pd_result, bf_result_computed, check_dtype=False)
440+
441+
414442
def test_dataframe_groupby_size_as_index_false(
415443
scalars_df_index, scalars_pandas_df_index
416444
):

third_party/bigframes_vendored/pandas/core/groupby/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -718,7 +718,6 @@ def max(
718718
def cumcount(self, ascending: bool = True):
719719
"""
720720
Number each item in each group from 0 to the length of that group - 1.
721-
(DataFrameGroupBy functionality is not yet available.)
722721
723722
**Examples:**
724723

0 commit comments

Comments
 (0)