From 6b9ee217f199a51b3d06dbd3a78a7588b20f98c2 Mon Sep 17 00:00:00 2001 From: Richard Shadrach Date: Thu, 4 Aug 2022 17:19:26 -0400 Subject: [PATCH 1/2] TST/CLN: Consolidate creation of groupby method args --- pandas/tests/groupby/__init__.py | 10 +++++++ pandas/tests/groupby/test_apply.py | 3 +- pandas/tests/groupby/test_categorical.py | 9 +++--- pandas/tests/groupby/test_function.py | 27 +++-------------- pandas/tests/groupby/test_groupby.py | 11 +++---- pandas/tests/groupby/test_groupby_subclass.py | 9 ++---- .../tests/groupby/transform/test_transform.py | 29 +++++-------------- 7 files changed, 35 insertions(+), 63 deletions(-) diff --git a/pandas/tests/groupby/__init__.py b/pandas/tests/groupby/__init__.py index e69de29bb2d1d..56bf90329a84a 100644 --- a/pandas/tests/groupby/__init__.py +++ b/pandas/tests/groupby/__init__.py @@ -0,0 +1,10 @@ +def get_method_args(name, obj): + if name in ("nth", "fillna", "take"): + return (0,) + if name == "quantile": + return (0.5,) + if name == "corrwith": + return (obj,) + if name == "tshift": + return (0, 0) + return () diff --git a/pandas/tests/groupby/test_apply.py b/pandas/tests/groupby/test_apply.py index 4cfc3ea41543b..91af60f1061e3 100644 --- a/pandas/tests/groupby/test_apply.py +++ b/pandas/tests/groupby/test_apply.py @@ -17,6 +17,7 @@ ) import pandas._testing as tm from pandas.core.api import Int64Index +from pandas.tests.groupby import get_method_args def test_apply_issues(): @@ -1069,7 +1070,7 @@ def test_apply_is_unchanged_when_other_methods_are_called_first(reduction_func): # Check output when another method is called before .apply() grp = df.groupby(by="a") - args = {"nth": [0], "corrwith": [df]}.get(reduction_func, []) + args = get_method_args(reduction_func, df) with tm.assert_produces_warning(warn, match="The 'mad' method is deprecated"): _ = getattr(grp, reduction_func)(*args) result = grp.apply(sum) diff --git a/pandas/tests/groupby/test_categorical.py b/pandas/tests/groupby/test_categorical.py index 004e55f4d161f..39d6169c39a2c 100644 --- a/pandas/tests/groupby/test_categorical.py +++ b/pandas/tests/groupby/test_categorical.py @@ -14,6 +14,7 @@ qcut, ) import pandas._testing as tm +from pandas.tests.groupby import get_method_args def cartesian_product_for_groupers(result, args, names, fill_value=np.NaN): @@ -1373,7 +1374,7 @@ def test_series_groupby_on_2_categoricals_unobserved(reduction_func, observed, r "value": [0.1] * 4, } ) - args = {"nth": [0]}.get(reduction_func, []) + args = get_method_args(reduction_func, df) expected_length = 4 if observed else 16 @@ -1409,7 +1410,7 @@ def test_series_groupby_on_2_categoricals_unobserved_zeroes_or_nans( } ) unobserved = [tuple("AC"), tuple("BC"), tuple("CA"), tuple("CB"), tuple("CC")] - args = {"nth": [0]}.get(reduction_func, []) + args = get_method_args(reduction_func, df) series_groupby = df.groupby(["cat_1", "cat_2"], observed=False)["value"] agg = getattr(series_groupby, reduction_func) @@ -1450,7 +1451,7 @@ def test_dataframe_groupby_on_2_categoricals_when_observed_is_true(reduction_fun df_grp = df.groupby(["cat_1", "cat_2"], observed=True) - args = {"nth": [0], "corrwith": [df]}.get(reduction_func, []) + args = get_method_args(reduction_func, df) with tm.assert_produces_warning(warn, match="The 'mad' method is deprecated"): res = getattr(df_grp, reduction_func)(*args) @@ -1482,7 +1483,7 @@ def test_dataframe_groupby_on_2_categoricals_when_observed_is_false( df_grp = df.groupby(["cat_1", "cat_2"], observed=observed) - args = {"nth": [0], "corrwith": [df]}.get(reduction_func, []) + args = get_method_args(reduction_func, df) with tm.assert_produces_warning(warn, match="The 'mad' method is deprecated"): res = getattr(df_grp, reduction_func)(*args) diff --git a/pandas/tests/groupby/test_function.py b/pandas/tests/groupby/test_function.py index dda583e3a1962..fdc3c7ed83dd5 100644 --- a/pandas/tests/groupby/test_function.py +++ b/pandas/tests/groupby/test_function.py @@ -18,6 +18,7 @@ ) import pandas._testing as tm import pandas.core.nanops as nanops +from pandas.tests.groupby import get_method_args from pandas.util import _test_decorators as td @@ -570,7 +571,7 @@ def test_axis1_numeric_only(request, groupby_func, numeric_only): groups = [1, 2, 3, 1, 2, 3, 1, 2, 3, 4] gb = df.groupby(groups) method = getattr(gb, groupby_func) - args = (0,) if groupby_func == "fillna" else () + args = get_method_args(groupby_func, df) kwargs = {"axis": 1} if numeric_only is not None: # when numeric_only is None we don't pass any argument @@ -1366,12 +1367,7 @@ def test_deprecate_numeric_only( # has_arg: Whether the op has a numeric_only arg df = DataFrame({"a1": [1, 1], "a2": [2, 2], "a3": [5, 6], "b": 2 * [object]}) - if kernel == "corrwith": - args = (df,) - elif kernel == "nth" or kernel == "fillna": - args = (0,) - else: - args = () + args = get_method_args(kernel, df) kwargs = {} if numeric_only is lib.no_default else {"numeric_only": numeric_only} gb = df.groupby(keys) @@ -1451,22 +1447,7 @@ def test_deprecate_numeric_only_series(dtype, groupby_func, request): expected_gb = expected_ser.groupby(grouper) expected_method = getattr(expected_gb, groupby_func) - if groupby_func == "corrwith": - args = (ser,) - elif groupby_func == "corr": - args = (ser,) - elif groupby_func == "cov": - args = (ser,) - elif groupby_func == "nth": - args = (0,) - elif groupby_func == "fillna": - args = (True,) - elif groupby_func == "take": - args = ([0],) - elif groupby_func == "quantile": - args = (0.5,) - else: - args = () + args = get_method_args(groupby_func, ser) fails_on_numeric_object = ( "corr", diff --git a/pandas/tests/groupby/test_groupby.py b/pandas/tests/groupby/test_groupby.py index 73aeb17d8c274..0c944b201dd16 100644 --- a/pandas/tests/groupby/test_groupby.py +++ b/pandas/tests/groupby/test_groupby.py @@ -29,6 +29,7 @@ from pandas.core.arrays import BooleanArray import pandas.core.common as com from pandas.core.groupby.base import maybe_normalize_deprecated_kernels +from pandas.tests.groupby import get_method_args def test_repr(): @@ -2366,14 +2367,10 @@ def test_dup_labels_output_shape(groupby_func, idx): df = DataFrame([[1, 1]], columns=idx) grp_by = df.groupby([0]) - args = [] - if groupby_func in {"fillna", "nth"}: - args.append(0) - elif groupby_func == "corrwith": - args.append(df) - elif groupby_func == "tshift": + if groupby_func == "tshift": df.index = [Timestamp("today")] - args.extend([1, "D"]) + # args.extend([1, "D"]) + args = get_method_args(groupby_func, df) with tm.assert_produces_warning(warn, match="is deprecated"): result = getattr(grp_by, groupby_func)(*args) diff --git a/pandas/tests/groupby/test_groupby_subclass.py b/pandas/tests/groupby/test_groupby_subclass.py index b665843728165..bd2737084bca6 100644 --- a/pandas/tests/groupby/test_groupby_subclass.py +++ b/pandas/tests/groupby/test_groupby_subclass.py @@ -10,6 +10,7 @@ ) import pandas._testing as tm from pandas.core.groupby.base import maybe_normalize_deprecated_kernels +from pandas.tests.groupby import get_method_args @pytest.mark.parametrize( @@ -34,13 +35,7 @@ def test_groupby_preserves_subclass(obj, groupby_func): # Groups should preserve subclass type assert isinstance(grouped.get_group(0), type(obj)) - args = [] - if groupby_func in {"fillna", "nth"}: - args.append(0) - elif groupby_func == "corrwith": - args.append(obj) - elif groupby_func == "tshift": - args.extend([0, 0]) + args = get_method_args(groupby_func, obj) with tm.assert_produces_warning(warn, match="is deprecated"): result1 = getattr(grouped, groupby_func)(*args) diff --git a/pandas/tests/groupby/transform/test_transform.py b/pandas/tests/groupby/transform/test_transform.py index 5c64ba3d9e266..7eddf8dd5eea6 100644 --- a/pandas/tests/groupby/transform/test_transform.py +++ b/pandas/tests/groupby/transform/test_transform.py @@ -22,6 +22,7 @@ import pandas._testing as tm from pandas.core.groupby.base import maybe_normalize_deprecated_kernels from pandas.core.groupby.generic import DataFrameGroupBy +from pandas.tests.groupby import get_method_args def assert_fp_equal(a, b): @@ -172,14 +173,10 @@ def test_transform_axis_1(request, transformation_func): msg = "ngroup fails with axis=1: #45986" request.node.add_marker(pytest.mark.xfail(reason=msg)) - warn = None - if transformation_func == "tshift": - warn = FutureWarning - - request.node.add_marker(pytest.mark.xfail(reason="tshift is deprecated")) - args = ("ffill",) if transformation_func == "fillna" else () + warn = FutureWarning if transformation_func == "tshift" else None df = DataFrame({"a": [1, 2], "b": [3, 4], "c": [5, 6]}, index=["x", "y"]) + args = get_method_args(transformation_func, df) with tm.assert_produces_warning(warn): result = df.groupby([0, 0, 1], axis=1).transform(transformation_func, *args) expected = df.T.groupby([0, 0, 1]).transform(transformation_func, *args).T @@ -1168,7 +1165,7 @@ def test_transform_agg_by_name(request, reduction_func, obj): pytest.mark.xfail(reason="TODO: implement SeriesGroupBy.corrwith") ) - args = {"nth": [0], "quantile": [0.5], "corrwith": [obj]}.get(func, []) + args = get_method_args(reduction_func, obj) with tm.assert_produces_warning(warn, match="The 'mad' method is deprecated"): result = g.transform(func, *args) @@ -1370,12 +1367,7 @@ def test_null_group_str_reducer(request, dropna, reduction_func): df = DataFrame({"A": [1, 1, np.nan, np.nan], "B": [1, 2, 2, 3]}, index=index) gb = df.groupby("A", dropna=dropna) - if reduction_func == "corrwith": - args = (df["B"],) - elif reduction_func == "nth": - args = (0,) - else: - args = () + args = get_method_args(reduction_func, df) # Manually handle reducers that don't fit the generic pattern # Set expected with dropna=False, then replace if necessary @@ -1418,8 +1410,8 @@ def test_null_group_str_transformer(request, dropna, transformation_func): if transformation_func == "tshift": msg = "tshift requires timeseries" request.node.add_marker(pytest.mark.xfail(reason=msg)) - args = (0,) if transformation_func == "fillna" else () df = DataFrame({"A": [1, 1, np.nan], "B": [1, 2, 2]}, index=[1, 2, 3]) + args = get_method_args(transformation_func, df) gb = df.groupby("A", dropna=dropna) buffer = [] @@ -1461,12 +1453,7 @@ def test_null_group_str_reducer_series(request, dropna, reduction_func): ser = Series([1, 2, 2, 3], index=index) gb = ser.groupby([1, 1, np.nan, np.nan], dropna=dropna) - if reduction_func == "corrwith": - args = (ser,) - elif reduction_func == "nth": - args = (0,) - else: - args = () + args = get_method_args(reduction_func, ser) # Manually handle reducers that don't fit the generic pattern # Set expected with dropna=False, then replace if necessary @@ -1506,8 +1493,8 @@ def test_null_group_str_transformer_series(request, dropna, transformation_func) if transformation_func == "tshift": msg = "tshift requires timeseries" request.node.add_marker(pytest.mark.xfail(reason=msg)) - args = (0,) if transformation_func == "fillna" else () ser = Series([1, 2, 2], index=[1, 2, 3]) + args = get_method_args(transformation_func, ser) gb = ser.groupby([1, 1, np.nan], dropna=dropna) buffer = [] From e3fee66557c21aec5cbc9de1e0be0691ef9d153a Mon Sep 17 00:00:00 2001 From: asv-bot Date: Sat, 6 Aug 2022 10:41:41 -0400 Subject: [PATCH 2/2] Rename to get_groupby_method_args --- pandas/tests/groupby/__init__.py | 19 ++++++++++++++++++- pandas/tests/groupby/test_apply.py | 4 ++-- pandas/tests/groupby/test_categorical.py | 10 +++++----- pandas/tests/groupby/test_function.py | 8 ++++---- pandas/tests/groupby/test_groupby.py | 4 ++-- pandas/tests/groupby/test_groupby_subclass.py | 4 ++-- .../tests/groupby/transform/test_transform.py | 14 +++++++------- 7 files changed, 40 insertions(+), 23 deletions(-) diff --git a/pandas/tests/groupby/__init__.py b/pandas/tests/groupby/__init__.py index 56bf90329a84a..c63aa568a15dc 100644 --- a/pandas/tests/groupby/__init__.py +++ b/pandas/tests/groupby/__init__.py @@ -1,4 +1,21 @@ -def get_method_args(name, obj): +def get_groupby_method_args(name, obj): + """ + Get required arguments for a groupby method. + + When parametrizing a test over groupby methods (e.g. "sum", "mean", "fillna"), + it is often the case that arguments are required for certain methods. + + Parameters + ---------- + name: str + Name of the method. + obj: Series or DataFrame + pandas object that is being grouped. + + Returns + ------- + A tuple of required arguments for the method. + """ if name in ("nth", "fillna", "take"): return (0,) if name == "quantile": diff --git a/pandas/tests/groupby/test_apply.py b/pandas/tests/groupby/test_apply.py index 91af60f1061e3..b064c12f89c21 100644 --- a/pandas/tests/groupby/test_apply.py +++ b/pandas/tests/groupby/test_apply.py @@ -17,7 +17,7 @@ ) import pandas._testing as tm from pandas.core.api import Int64Index -from pandas.tests.groupby import get_method_args +from pandas.tests.groupby import get_groupby_method_args def test_apply_issues(): @@ -1070,7 +1070,7 @@ def test_apply_is_unchanged_when_other_methods_are_called_first(reduction_func): # Check output when another method is called before .apply() grp = df.groupby(by="a") - args = get_method_args(reduction_func, df) + args = get_groupby_method_args(reduction_func, df) with tm.assert_produces_warning(warn, match="The 'mad' method is deprecated"): _ = getattr(grp, reduction_func)(*args) result = grp.apply(sum) diff --git a/pandas/tests/groupby/test_categorical.py b/pandas/tests/groupby/test_categorical.py index 39d6169c39a2c..6d22c676a3c16 100644 --- a/pandas/tests/groupby/test_categorical.py +++ b/pandas/tests/groupby/test_categorical.py @@ -14,7 +14,7 @@ qcut, ) import pandas._testing as tm -from pandas.tests.groupby import get_method_args +from pandas.tests.groupby import get_groupby_method_args def cartesian_product_for_groupers(result, args, names, fill_value=np.NaN): @@ -1374,7 +1374,7 @@ def test_series_groupby_on_2_categoricals_unobserved(reduction_func, observed, r "value": [0.1] * 4, } ) - args = get_method_args(reduction_func, df) + args = get_groupby_method_args(reduction_func, df) expected_length = 4 if observed else 16 @@ -1410,7 +1410,7 @@ def test_series_groupby_on_2_categoricals_unobserved_zeroes_or_nans( } ) unobserved = [tuple("AC"), tuple("BC"), tuple("CA"), tuple("CB"), tuple("CC")] - args = get_method_args(reduction_func, df) + args = get_groupby_method_args(reduction_func, df) series_groupby = df.groupby(["cat_1", "cat_2"], observed=False)["value"] agg = getattr(series_groupby, reduction_func) @@ -1451,7 +1451,7 @@ def test_dataframe_groupby_on_2_categoricals_when_observed_is_true(reduction_fun df_grp = df.groupby(["cat_1", "cat_2"], observed=True) - args = get_method_args(reduction_func, df) + args = get_groupby_method_args(reduction_func, df) with tm.assert_produces_warning(warn, match="The 'mad' method is deprecated"): res = getattr(df_grp, reduction_func)(*args) @@ -1483,7 +1483,7 @@ def test_dataframe_groupby_on_2_categoricals_when_observed_is_false( df_grp = df.groupby(["cat_1", "cat_2"], observed=observed) - args = get_method_args(reduction_func, df) + args = get_groupby_method_args(reduction_func, df) with tm.assert_produces_warning(warn, match="The 'mad' method is deprecated"): res = getattr(df_grp, reduction_func)(*args) diff --git a/pandas/tests/groupby/test_function.py b/pandas/tests/groupby/test_function.py index fdc3c7ed83dd5..93e9b5bb776ab 100644 --- a/pandas/tests/groupby/test_function.py +++ b/pandas/tests/groupby/test_function.py @@ -18,7 +18,7 @@ ) import pandas._testing as tm import pandas.core.nanops as nanops -from pandas.tests.groupby import get_method_args +from pandas.tests.groupby import get_groupby_method_args from pandas.util import _test_decorators as td @@ -571,7 +571,7 @@ def test_axis1_numeric_only(request, groupby_func, numeric_only): groups = [1, 2, 3, 1, 2, 3, 1, 2, 3, 4] gb = df.groupby(groups) method = getattr(gb, groupby_func) - args = get_method_args(groupby_func, df) + args = get_groupby_method_args(groupby_func, df) kwargs = {"axis": 1} if numeric_only is not None: # when numeric_only is None we don't pass any argument @@ -1367,7 +1367,7 @@ def test_deprecate_numeric_only( # has_arg: Whether the op has a numeric_only arg df = DataFrame({"a1": [1, 1], "a2": [2, 2], "a3": [5, 6], "b": 2 * [object]}) - args = get_method_args(kernel, df) + args = get_groupby_method_args(kernel, df) kwargs = {} if numeric_only is lib.no_default else {"numeric_only": numeric_only} gb = df.groupby(keys) @@ -1447,7 +1447,7 @@ def test_deprecate_numeric_only_series(dtype, groupby_func, request): expected_gb = expected_ser.groupby(grouper) expected_method = getattr(expected_gb, groupby_func) - args = get_method_args(groupby_func, ser) + args = get_groupby_method_args(groupby_func, ser) fails_on_numeric_object = ( "corr", diff --git a/pandas/tests/groupby/test_groupby.py b/pandas/tests/groupby/test_groupby.py index 0c944b201dd16..a6ab13270c4dc 100644 --- a/pandas/tests/groupby/test_groupby.py +++ b/pandas/tests/groupby/test_groupby.py @@ -29,7 +29,7 @@ from pandas.core.arrays import BooleanArray import pandas.core.common as com from pandas.core.groupby.base import maybe_normalize_deprecated_kernels -from pandas.tests.groupby import get_method_args +from pandas.tests.groupby import get_groupby_method_args def test_repr(): @@ -2370,7 +2370,7 @@ def test_dup_labels_output_shape(groupby_func, idx): if groupby_func == "tshift": df.index = [Timestamp("today")] # args.extend([1, "D"]) - args = get_method_args(groupby_func, df) + args = get_groupby_method_args(groupby_func, df) with tm.assert_produces_warning(warn, match="is deprecated"): result = getattr(grp_by, groupby_func)(*args) diff --git a/pandas/tests/groupby/test_groupby_subclass.py b/pandas/tests/groupby/test_groupby_subclass.py index bd2737084bca6..fddf0c86d0ab1 100644 --- a/pandas/tests/groupby/test_groupby_subclass.py +++ b/pandas/tests/groupby/test_groupby_subclass.py @@ -10,7 +10,7 @@ ) import pandas._testing as tm from pandas.core.groupby.base import maybe_normalize_deprecated_kernels -from pandas.tests.groupby import get_method_args +from pandas.tests.groupby import get_groupby_method_args @pytest.mark.parametrize( @@ -35,7 +35,7 @@ def test_groupby_preserves_subclass(obj, groupby_func): # Groups should preserve subclass type assert isinstance(grouped.get_group(0), type(obj)) - args = get_method_args(groupby_func, obj) + args = get_groupby_method_args(groupby_func, obj) with tm.assert_produces_warning(warn, match="is deprecated"): result1 = getattr(grouped, groupby_func)(*args) diff --git a/pandas/tests/groupby/transform/test_transform.py b/pandas/tests/groupby/transform/test_transform.py index 7eddf8dd5eea6..d2928c52c33e2 100644 --- a/pandas/tests/groupby/transform/test_transform.py +++ b/pandas/tests/groupby/transform/test_transform.py @@ -22,7 +22,7 @@ import pandas._testing as tm from pandas.core.groupby.base import maybe_normalize_deprecated_kernels from pandas.core.groupby.generic import DataFrameGroupBy -from pandas.tests.groupby import get_method_args +from pandas.tests.groupby import get_groupby_method_args def assert_fp_equal(a, b): @@ -176,7 +176,7 @@ def test_transform_axis_1(request, transformation_func): warn = FutureWarning if transformation_func == "tshift" else None df = DataFrame({"a": [1, 2], "b": [3, 4], "c": [5, 6]}, index=["x", "y"]) - args = get_method_args(transformation_func, df) + args = get_groupby_method_args(transformation_func, df) with tm.assert_produces_warning(warn): result = df.groupby([0, 0, 1], axis=1).transform(transformation_func, *args) expected = df.T.groupby([0, 0, 1]).transform(transformation_func, *args).T @@ -1165,7 +1165,7 @@ def test_transform_agg_by_name(request, reduction_func, obj): pytest.mark.xfail(reason="TODO: implement SeriesGroupBy.corrwith") ) - args = get_method_args(reduction_func, obj) + args = get_groupby_method_args(reduction_func, obj) with tm.assert_produces_warning(warn, match="The 'mad' method is deprecated"): result = g.transform(func, *args) @@ -1367,7 +1367,7 @@ def test_null_group_str_reducer(request, dropna, reduction_func): df = DataFrame({"A": [1, 1, np.nan, np.nan], "B": [1, 2, 2, 3]}, index=index) gb = df.groupby("A", dropna=dropna) - args = get_method_args(reduction_func, df) + args = get_groupby_method_args(reduction_func, df) # Manually handle reducers that don't fit the generic pattern # Set expected with dropna=False, then replace if necessary @@ -1411,7 +1411,7 @@ def test_null_group_str_transformer(request, dropna, transformation_func): msg = "tshift requires timeseries" request.node.add_marker(pytest.mark.xfail(reason=msg)) df = DataFrame({"A": [1, 1, np.nan], "B": [1, 2, 2]}, index=[1, 2, 3]) - args = get_method_args(transformation_func, df) + args = get_groupby_method_args(transformation_func, df) gb = df.groupby("A", dropna=dropna) buffer = [] @@ -1453,7 +1453,7 @@ def test_null_group_str_reducer_series(request, dropna, reduction_func): ser = Series([1, 2, 2, 3], index=index) gb = ser.groupby([1, 1, np.nan, np.nan], dropna=dropna) - args = get_method_args(reduction_func, ser) + args = get_groupby_method_args(reduction_func, ser) # Manually handle reducers that don't fit the generic pattern # Set expected with dropna=False, then replace if necessary @@ -1494,7 +1494,7 @@ def test_null_group_str_transformer_series(request, dropna, transformation_func) msg = "tshift requires timeseries" request.node.add_marker(pytest.mark.xfail(reason=msg)) ser = Series([1, 2, 2], index=[1, 2, 3]) - args = get_method_args(transformation_func, ser) + args = get_groupby_method_args(transformation_func, ser) gb = ser.groupby([1, 1, np.nan], dropna=dropna) buffer = []