Skip to content

Commit 8b0af58

Browse files
authored
ENH: Add support for min_count keyword for Resample and Groupby functions (#37870)
1 parent 92586ba commit 8b0af58

File tree

6 files changed

+40
-16
lines changed

6 files changed

+40
-16
lines changed

doc/source/whatsnew/v1.2.0.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,7 @@ Other enhancements
252252
- :class:`Window` now supports all Scipy window types in ``win_type`` with flexible keyword argument support (:issue:`34556`)
253253
- :meth:`testing.assert_index_equal` now has a ``check_order`` parameter that allows indexes to be checked in an order-insensitive manner (:issue:`37478`)
254254
- :func:`read_csv` supports memory-mapping for compressed files (:issue:`37621`)
255+
- Add support for ``min_count`` keyword for :meth:`DataFrame.groupby` and :meth:`DataFrame.resample` for functions ``min``, ``max``, ``first`` and ``last`` (:issue:`37821`, :issue:`37768`)
255256
- Improve error reporting for :meth:`DataFrame.merge` when invalid merge column definitions were given (:issue:`16228`)
256257
- Improve numerical stability for :meth:`.Rolling.skew`, :meth:`.Rolling.kurt`, :meth:`Expanding.skew` and :meth:`Expanding.kurt` through implementation of Kahan summation (:issue:`6929`)
257258
- Improved error reporting for subsetting columns of a :class:`.DataFrameGroupBy` with ``axis=1`` (:issue:`37725`)

pandas/_libs/groupby.pyx

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -903,13 +903,12 @@ def group_last(rank_t[:, :] out,
903903
ndarray[int64_t, ndim=2] nobs
904904
bint runtime_error = False
905905

906-
assert min_count == -1, "'min_count' only used in add and prod"
907-
908906
# TODO(cython 3.0):
909907
# Instead of `labels.shape[0]` use `len(labels)`
910908
if not len(values) == labels.shape[0]:
911909
raise AssertionError("len(index) != len(labels)")
912910

911+
min_count = max(min_count, 1)
913912
nobs = np.zeros((<object>out).shape, dtype=np.int64)
914913
if rank_t is object:
915914
resx = np.empty((<object>out).shape, dtype=object)
@@ -939,7 +938,7 @@ def group_last(rank_t[:, :] out,
939938

940939
for i in range(ncounts):
941940
for j in range(K):
942-
if nobs[i, j] == 0:
941+
if nobs[i, j] < min_count:
943942
out[i, j] = NAN
944943
else:
945944
out[i, j] = resx[i, j]
@@ -961,7 +960,7 @@ def group_last(rank_t[:, :] out,
961960

962961
for i in range(ncounts):
963962
for j in range(K):
964-
if nobs[i, j] == 0:
963+
if nobs[i, j] < min_count:
965964
if rank_t is int64_t:
966965
out[i, j] = NPY_NAT
967966
elif rank_t is uint64_t:
@@ -986,7 +985,8 @@ def group_last(rank_t[:, :] out,
986985
def group_nth(rank_t[:, :] out,
987986
int64_t[:] counts,
988987
ndarray[rank_t, ndim=2] values,
989-
const int64_t[:] labels, int64_t rank=1
988+
const int64_t[:] labels,
989+
int64_t min_count=-1, int64_t rank=1
990990
):
991991
"""
992992
Only aggregates on axis=0
@@ -1003,6 +1003,7 @@ def group_nth(rank_t[:, :] out,
10031003
if not len(values) == labels.shape[0]:
10041004
raise AssertionError("len(index) != len(labels)")
10051005

1006+
min_count = max(min_count, 1)
10061007
nobs = np.zeros((<object>out).shape, dtype=np.int64)
10071008
if rank_t is object:
10081009
resx = np.empty((<object>out).shape, dtype=object)
@@ -1033,7 +1034,7 @@ def group_nth(rank_t[:, :] out,
10331034

10341035
for i in range(ncounts):
10351036
for j in range(K):
1036-
if nobs[i, j] == 0:
1037+
if nobs[i, j] < min_count:
10371038
out[i, j] = NAN
10381039
else:
10391040
out[i, j] = resx[i, j]
@@ -1057,7 +1058,7 @@ def group_nth(rank_t[:, :] out,
10571058

10581059
for i in range(ncounts):
10591060
for j in range(K):
1060-
if nobs[i, j] == 0:
1061+
if nobs[i, j] < min_count:
10611062
if rank_t is int64_t:
10621063
out[i, j] = NPY_NAT
10631064
elif rank_t is uint64_t:
@@ -1294,13 +1295,12 @@ def group_max(groupby_t[:, :] out,
12941295
bint runtime_error = False
12951296
int64_t[:, :] nobs
12961297

1297-
assert min_count == -1, "'min_count' only used in add and prod"
1298-
12991298
# TODO(cython 3.0):
13001299
# Instead of `labels.shape[0]` use `len(labels)`
13011300
if not len(values) == labels.shape[0]:
13021301
raise AssertionError("len(index) != len(labels)")
13031302

1303+
min_count = max(min_count, 1)
13041304
nobs = np.zeros((<object>out).shape, dtype=np.int64)
13051305

13061306
maxx = np.empty_like(out)
@@ -1337,11 +1337,12 @@ def group_max(groupby_t[:, :] out,
13371337

13381338
for i in range(ncounts):
13391339
for j in range(K):
1340-
if nobs[i, j] == 0:
1340+
if nobs[i, j] < min_count:
13411341
if groupby_t is uint64_t:
13421342
runtime_error = True
13431343
break
13441344
else:
1345+
13451346
out[i, j] = nan_val
13461347
else:
13471348
out[i, j] = maxx[i, j]
@@ -1369,13 +1370,12 @@ def group_min(groupby_t[:, :] out,
13691370
bint runtime_error = False
13701371
int64_t[:, :] nobs
13711372

1372-
assert min_count == -1, "'min_count' only used in add and prod"
1373-
13741373
# TODO(cython 3.0):
13751374
# Instead of `labels.shape[0]` use `len(labels)`
13761375
if not len(values) == labels.shape[0]:
13771376
raise AssertionError("len(index) != len(labels)")
13781377

1378+
min_count = max(min_count, 1)
13791379
nobs = np.zeros((<object>out).shape, dtype=np.int64)
13801380

13811381
minx = np.empty_like(out)
@@ -1411,7 +1411,7 @@ def group_min(groupby_t[:, :] out,
14111411

14121412
for i in range(ncounts):
14131413
for j in range(K):
1414-
if nobs[i, j] == 0:
1414+
if nobs[i, j] < min_count:
14151415
if groupby_t is uint64_t:
14161416
runtime_error = True
14171417
break

pandas/core/groupby/ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -603,7 +603,7 @@ def _aggregate(
603603
):
604604
if agg_func is libgroupby.group_nth:
605605
# different signature from the others
606-
agg_func(result, counts, values, comp_ids, rank=1)
606+
agg_func(result, counts, values, comp_ids, min_count, rank=1)
607607
else:
608608
agg_func(result, counts, values, comp_ids, min_count)
609609

pandas/core/resample.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -950,7 +950,7 @@ def quantile(self, q=0.5, **kwargs):
950950

951951

952952
# downsample methods
953-
for method in ["sum", "prod"]:
953+
for method in ["sum", "prod", "min", "max", "first", "last"]:
954954

955955
def f(self, _method=method, min_count=0, *args, **kwargs):
956956
nv.validate_resampler_func(_method, args, kwargs)
@@ -961,7 +961,7 @@ def f(self, _method=method, min_count=0, *args, **kwargs):
961961

962962

963963
# downsample methods
964-
for method in ["min", "max", "first", "last", "mean", "sem", "median", "ohlc"]:
964+
for method in ["mean", "sem", "median", "ohlc"]:
965965

966966
def g(self, _method=method, *args, **kwargs):
967967
nv.validate_resampler_func(_method, args, kwargs)

pandas/tests/groupby/test_missing.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,3 +116,13 @@ def test_ffill_handles_nan_groups(dropna, method, has_nan_group):
116116
expected = df_without_nan_rows.reindex(ridx).reset_index(drop=True)
117117

118118
tm.assert_frame_equal(result, expected)
119+
120+
121+
@pytest.mark.parametrize("min_count, value", [(2, np.nan), (-1, 1.0)])
122+
@pytest.mark.parametrize("func", ["first", "last", "max", "min"])
123+
def test_min_count(func, min_count, value):
124+
# GH#37821
125+
df = DataFrame({"a": [1] * 3, "b": [1, np.nan, np.nan], "c": [np.nan] * 3})
126+
result = getattr(df.groupby("a"), func)(min_count=min_count)
127+
expected = DataFrame({"b": [value], "c": [np.nan]}, index=Index([1], name="a"))
128+
tm.assert_frame_equal(result, expected)

pandas/tests/resample/test_datetime_index.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1785,3 +1785,16 @@ def test_resample_calendar_day_with_dst(
17851785
1.0, pd.date_range(first, exp_last, freq=freq_out, tz="Europe/Amsterdam")
17861786
)
17871787
tm.assert_series_equal(result, expected)
1788+
1789+
1790+
@pytest.mark.parametrize("func", ["min", "max", "first", "last"])
1791+
def test_resample_aggregate_functions_min_count(func):
1792+
# GH#37768
1793+
index = date_range(start="2020", freq="M", periods=3)
1794+
ser = Series([1, np.nan, np.nan], index)
1795+
result = getattr(ser.resample("Q"), func)(min_count=2)
1796+
expected = Series(
1797+
[np.nan],
1798+
index=DatetimeIndex(["2020-03-31"], dtype="datetime64[ns]", freq="Q-DEC"),
1799+
)
1800+
tm.assert_series_equal(result, expected)

0 commit comments

Comments
 (0)