Skip to content

TST/CLN: Consolidate creation of groupby method args #47973

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions pandas/tests/groupby/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
def get_groupby_method_args(name, obj):
    """
    Return the positional arguments needed to call a groupby method.

    Tests that are parametrized over groupby method names (e.g. "sum",
    "mean", "fillna") cannot call every method with no arguments, because
    some of them have required parameters. This helper supplies a valid
    set of positional arguments for those methods and an empty tuple for
    all others.

    Parameters
    ----------
    name : str
        Name of the groupby method.
    obj : Series or DataFrame
        pandas object that is being grouped. It is returned as the sole
        argument when ``name`` is "corrwith" and is otherwise unused.

    Returns
    -------
    tuple
        Positional arguments to unpack into the method call; empty when
        ``name`` is not one of the special-cased methods.
    """
    if name == "corrwith":
        # corrwith is the only method whose argument depends on the
        # object being grouped.
        required = (obj,)
    elif name == "quantile":
        required = (0.5,)
    elif name == "tshift":
        required = (0, 0)
    elif name in ("nth", "fillna", "take"):
        required = (0,)
    else:
        # Every other method is called without positional arguments.
        required = ()
    return required
3 changes: 2 additions & 1 deletion pandas/tests/groupby/test_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
)
import pandas._testing as tm
from pandas.core.api import Int64Index
from pandas.tests.groupby import get_groupby_method_args


def test_apply_issues():
Expand Down Expand Up @@ -1069,7 +1070,7 @@ def test_apply_is_unchanged_when_other_methods_are_called_first(reduction_func):

# Check output when another method is called before .apply()
grp = df.groupby(by="a")
args = {"nth": [0], "corrwith": [df]}.get(reduction_func, [])
args = get_groupby_method_args(reduction_func, df)
with tm.assert_produces_warning(warn, match="The 'mad' method is deprecated"):
_ = getattr(grp, reduction_func)(*args)
result = grp.apply(sum)
Expand Down
9 changes: 5 additions & 4 deletions pandas/tests/groupby/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
qcut,
)
import pandas._testing as tm
from pandas.tests.groupby import get_groupby_method_args


def cartesian_product_for_groupers(result, args, names, fill_value=np.NaN):
Expand Down Expand Up @@ -1373,7 +1374,7 @@ def test_series_groupby_on_2_categoricals_unobserved(reduction_func, observed, r
"value": [0.1] * 4,
}
)
args = {"nth": [0]}.get(reduction_func, [])
args = get_groupby_method_args(reduction_func, df)

expected_length = 4 if observed else 16

Expand Down Expand Up @@ -1409,7 +1410,7 @@ def test_series_groupby_on_2_categoricals_unobserved_zeroes_or_nans(
}
)
unobserved = [tuple("AC"), tuple("BC"), tuple("CA"), tuple("CB"), tuple("CC")]
args = {"nth": [0]}.get(reduction_func, [])
args = get_groupby_method_args(reduction_func, df)

series_groupby = df.groupby(["cat_1", "cat_2"], observed=False)["value"]
agg = getattr(series_groupby, reduction_func)
Expand Down Expand Up @@ -1450,7 +1451,7 @@ def test_dataframe_groupby_on_2_categoricals_when_observed_is_true(reduction_fun

df_grp = df.groupby(["cat_1", "cat_2"], observed=True)

args = {"nth": [0], "corrwith": [df]}.get(reduction_func, [])
args = get_groupby_method_args(reduction_func, df)
with tm.assert_produces_warning(warn, match="The 'mad' method is deprecated"):
res = getattr(df_grp, reduction_func)(*args)

Expand Down Expand Up @@ -1482,7 +1483,7 @@ def test_dataframe_groupby_on_2_categoricals_when_observed_is_false(

df_grp = df.groupby(["cat_1", "cat_2"], observed=observed)

args = {"nth": [0], "corrwith": [df]}.get(reduction_func, [])
args = get_groupby_method_args(reduction_func, df)
with tm.assert_produces_warning(warn, match="The 'mad' method is deprecated"):
res = getattr(df_grp, reduction_func)(*args)

Expand Down
27 changes: 4 additions & 23 deletions pandas/tests/groupby/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
)
import pandas._testing as tm
import pandas.core.nanops as nanops
from pandas.tests.groupby import get_groupby_method_args
from pandas.util import _test_decorators as td


Expand Down Expand Up @@ -570,7 +571,7 @@ def test_axis1_numeric_only(request, groupby_func, numeric_only):
groups = [1, 2, 3, 1, 2, 3, 1, 2, 3, 4]
gb = df.groupby(groups)
method = getattr(gb, groupby_func)
args = (0,) if groupby_func == "fillna" else ()
args = get_groupby_method_args(groupby_func, df)
kwargs = {"axis": 1}
if numeric_only is not None:
# when numeric_only is None we don't pass any argument
Expand Down Expand Up @@ -1366,12 +1367,7 @@ def test_deprecate_numeric_only(
# has_arg: Whether the op has a numeric_only arg
df = DataFrame({"a1": [1, 1], "a2": [2, 2], "a3": [5, 6], "b": 2 * [object]})

if kernel == "corrwith":
args = (df,)
elif kernel == "nth" or kernel == "fillna":
args = (0,)
else:
args = ()
args = get_groupby_method_args(kernel, df)
kwargs = {} if numeric_only is lib.no_default else {"numeric_only": numeric_only}

gb = df.groupby(keys)
Expand Down Expand Up @@ -1451,22 +1447,7 @@ def test_deprecate_numeric_only_series(dtype, groupby_func, request):
expected_gb = expected_ser.groupby(grouper)
expected_method = getattr(expected_gb, groupby_func)

if groupby_func == "corrwith":
args = (ser,)
elif groupby_func == "corr":
args = (ser,)
elif groupby_func == "cov":
args = (ser,)
elif groupby_func == "nth":
args = (0,)
elif groupby_func == "fillna":
args = (True,)
elif groupby_func == "take":
args = ([0],)
elif groupby_func == "quantile":
args = (0.5,)
else:
args = ()
args = get_groupby_method_args(groupby_func, ser)

fails_on_numeric_object = (
"corr",
Expand Down
11 changes: 4 additions & 7 deletions pandas/tests/groupby/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from pandas.core.arrays import BooleanArray
import pandas.core.common as com
from pandas.core.groupby.base import maybe_normalize_deprecated_kernels
from pandas.tests.groupby import get_groupby_method_args


def test_repr():
Expand Down Expand Up @@ -2366,14 +2367,10 @@ def test_dup_labels_output_shape(groupby_func, idx):
df = DataFrame([[1, 1]], columns=idx)
grp_by = df.groupby([0])

args = []
if groupby_func in {"fillna", "nth"}:
args.append(0)
elif groupby_func == "corrwith":
args.append(df)
elif groupby_func == "tshift":
if groupby_func == "tshift":
df.index = [Timestamp("today")]
args.extend([1, "D"])
# args.extend([1, "D"])
args = get_groupby_method_args(groupby_func, df)

with tm.assert_produces_warning(warn, match="is deprecated"):
result = getattr(grp_by, groupby_func)(*args)
Expand Down
9 changes: 2 additions & 7 deletions pandas/tests/groupby/test_groupby_subclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
)
import pandas._testing as tm
from pandas.core.groupby.base import maybe_normalize_deprecated_kernels
from pandas.tests.groupby import get_groupby_method_args


@pytest.mark.parametrize(
Expand All @@ -34,13 +35,7 @@ def test_groupby_preserves_subclass(obj, groupby_func):
# Groups should preserve subclass type
assert isinstance(grouped.get_group(0), type(obj))

args = []
if groupby_func in {"fillna", "nth"}:
args.append(0)
elif groupby_func == "corrwith":
args.append(obj)
elif groupby_func == "tshift":
args.extend([0, 0])
args = get_groupby_method_args(groupby_func, obj)

with tm.assert_produces_warning(warn, match="is deprecated"):
result1 = getattr(grouped, groupby_func)(*args)
Expand Down
29 changes: 8 additions & 21 deletions pandas/tests/groupby/transform/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import pandas._testing as tm
from pandas.core.groupby.base import maybe_normalize_deprecated_kernels
from pandas.core.groupby.generic import DataFrameGroupBy
from pandas.tests.groupby import get_groupby_method_args


def assert_fp_equal(a, b):
Expand Down Expand Up @@ -172,14 +173,10 @@ def test_transform_axis_1(request, transformation_func):
msg = "ngroup fails with axis=1: #45986"
request.node.add_marker(pytest.mark.xfail(reason=msg))

warn = None
if transformation_func == "tshift":
warn = FutureWarning

request.node.add_marker(pytest.mark.xfail(reason="tshift is deprecated"))
args = ("ffill",) if transformation_func == "fillna" else ()
warn = FutureWarning if transformation_func == "tshift" else None

df = DataFrame({"a": [1, 2], "b": [3, 4], "c": [5, 6]}, index=["x", "y"])
args = get_groupby_method_args(transformation_func, df)
with tm.assert_produces_warning(warn):
result = df.groupby([0, 0, 1], axis=1).transform(transformation_func, *args)
expected = df.T.groupby([0, 0, 1]).transform(transformation_func, *args).T
Expand Down Expand Up @@ -1168,7 +1165,7 @@ def test_transform_agg_by_name(request, reduction_func, obj):
pytest.mark.xfail(reason="TODO: implement SeriesGroupBy.corrwith")
)

args = {"nth": [0], "quantile": [0.5], "corrwith": [obj]}.get(func, [])
args = get_groupby_method_args(reduction_func, obj)
with tm.assert_produces_warning(warn, match="The 'mad' method is deprecated"):
result = g.transform(func, *args)

Expand Down Expand Up @@ -1370,12 +1367,7 @@ def test_null_group_str_reducer(request, dropna, reduction_func):
df = DataFrame({"A": [1, 1, np.nan, np.nan], "B": [1, 2, 2, 3]}, index=index)
gb = df.groupby("A", dropna=dropna)

if reduction_func == "corrwith":
args = (df["B"],)
elif reduction_func == "nth":
args = (0,)
else:
args = ()
args = get_groupby_method_args(reduction_func, df)

# Manually handle reducers that don't fit the generic pattern
# Set expected with dropna=False, then replace if necessary
Expand Down Expand Up @@ -1418,8 +1410,8 @@ def test_null_group_str_transformer(request, dropna, transformation_func):
if transformation_func == "tshift":
msg = "tshift requires timeseries"
request.node.add_marker(pytest.mark.xfail(reason=msg))
args = (0,) if transformation_func == "fillna" else ()
df = DataFrame({"A": [1, 1, np.nan], "B": [1, 2, 2]}, index=[1, 2, 3])
args = get_groupby_method_args(transformation_func, df)
gb = df.groupby("A", dropna=dropna)

buffer = []
Expand Down Expand Up @@ -1461,12 +1453,7 @@ def test_null_group_str_reducer_series(request, dropna, reduction_func):
ser = Series([1, 2, 2, 3], index=index)
gb = ser.groupby([1, 1, np.nan, np.nan], dropna=dropna)

if reduction_func == "corrwith":
args = (ser,)
elif reduction_func == "nth":
args = (0,)
else:
args = ()
args = get_groupby_method_args(reduction_func, ser)

# Manually handle reducers that don't fit the generic pattern
# Set expected with dropna=False, then replace if necessary
Expand Down Expand Up @@ -1506,8 +1493,8 @@ def test_null_group_str_transformer_series(request, dropna, transformation_func)
if transformation_func == "tshift":
msg = "tshift requires timeseries"
request.node.add_marker(pytest.mark.xfail(reason=msg))
args = (0,) if transformation_func == "fillna" else ()
ser = Series([1, 2, 2], index=[1, 2, 3])
args = get_groupby_method_args(transformation_func, ser)
gb = ser.groupby([1, 1, np.nan], dropna=dropna)

buffer = []
Expand Down