diff --git a/pandas/tests/groupby/test_function.py b/pandas/tests/groupby/test_function.py index c402ca194648f..83080aa98648f 100644 --- a/pandas/tests/groupby/test_function.py +++ b/pandas/tests/groupby/test_function.py @@ -1,7 +1,6 @@ import builtins import datetime as dt from io import StringIO -from itertools import product from string import ascii_lowercase import numpy as np @@ -1296,36 +1295,32 @@ def __eq__(self, other): # -------------------------------- -def test_size(df): - grouped = df.groupby(["A", "B"]) +@pytest.mark.parametrize("by", ["A", "B", ["A", "B"]]) +def test_size(df, by): + grouped = df.groupby(by=by) result = grouped.size() for key, group in grouped: assert result[key] == len(group) - grouped = df.groupby("A") - result = grouped.size() - for key, group in grouped: - assert result[key] == len(group) - grouped = df.groupby("B") - result = grouped.size() - for key, group in grouped: - assert result[key] == len(group) +@pytest.mark.parametrize("by", ["A", "B", ["A", "B"]]) +@pytest.mark.parametrize("sort", [True, False]) +def test_size_sort(df, sort, by): + df = DataFrame(np.random.choice(20, (1000, 3)), columns=list("ABC")) + left = df.groupby(by=by, sort=sort).size() + right = df.groupby(by=by, sort=sort)["C"].apply(lambda a: a.shape[0]) + tm.assert_series_equal(left, right, check_names=False) - df = DataFrame(np.random.choice(20, (1000, 3)), columns=list("abc")) - for sort, key in product((False, True), ("a", "b", ["a", "b"])): - left = df.groupby(key, sort=sort).size() - right = df.groupby(key, sort=sort)["c"].apply(lambda a: a.shape[0]) - tm.assert_series_equal(left, right, check_names=False) - # GH11699 +def test_size_series_dataframe(): + # https://github.com/pandas-dev/pandas/issues/11699 df = DataFrame(columns=["A", "B"]) out = Series(dtype="int64", index=Index([], name="A")) tm.assert_series_equal(df.groupby("A").size(), out) def test_size_groupby_all_null(): - # GH23050 + # https://github.com/pandas-dev/pandas/issues/23050 # Assert no 'Value Error : Length of passed values is 2, index implies 0' df = DataFrame({"A": [None, None]}) # all-null groups result = df.groupby("A").size() @@ -1335,6 +1330,8 @@ def test_size_groupby_all_null(): # quantile # -------------------------------- + + @pytest.mark.parametrize( "interpolation", ["linear", "lower", "higher", "nearest", "midpoint"] )