diff --git a/doc/source/whatsnew/v1.4.0.rst b/doc/source/whatsnew/v1.4.0.rst index 24f307f23f435..75ef2b820290b 100644 --- a/doc/source/whatsnew/v1.4.0.rst +++ b/doc/source/whatsnew/v1.4.0.rst @@ -29,6 +29,7 @@ enhancement2 Other enhancements ^^^^^^^^^^^^^^^^^^ +- Added HTML representation for grouped DataFrame and Series (:issue:`34926`) - - diff --git a/pandas/core/config_init.py b/pandas/core/config_init.py index 27b898782fbef..992dbae21f0fd 100644 --- a/pandas/core/config_init.py +++ b/pandas/core/config_init.py @@ -110,6 +110,12 @@ def use_numba_cb(key): correct auto-detection. """ +pc_max_groups_doc = """ +: int + If max_groups is exceeded, switch to truncate groupby view. 'None' value + means unlimited. +""" + pc_min_rows_doc = """ : int The numbers of rows to show in a truncated view (when `max_rows` is @@ -355,6 +361,9 @@ def is_terminal() -> bool: validator=is_instance_factory((int, type(None))), ) cf.register_option("max_rows", 60, pc_max_rows_doc, validator=is_nonnegative_int) + cf.register_option( + "max_groups", 10, pc_max_groups_doc, validator=is_nonnegative_int + ) cf.register_option( "min_rows", 10, diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 8fb50db2e33f2..b3bc308f5d3e1 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -113,6 +113,8 @@ class providing the base-class of operations. 
def repr_html_groupby(group_obj) -> str:
    """
    Create an HTML representation for a grouped DataFrame or Series.

    Every displayed group is rendered as a "Group Key: <key>" heading followed
    by the ``to_html`` output of that group. When there are more groups than
    ``display.max_groups``, only the first and last groups are rendered and a
    "..." marker is placed between them. The ``display.max_rows`` budget is
    divided evenly between the displayed groups (at least one row per group).

    Parameters
    ----------
    group_obj : [DataFrameGroupBy, SeriesGroupBy]
        Object to make HTML representation of.

    Returns
    -------
    str :
        HTML representation of the input object. An empty string is returned
        when there are no groups.
    """
    # Local import keeps this block self-contained; group keys are user data
    # and must not be injected into the markup unescaped.
    from html import escape

    max_groups = get_option("display.max_groups")
    max_rows = get_option("display.max_rows")

    group_names = list(group_obj.groups.keys())
    n_groups = len(group_names)

    # ``None`` means "unlimited" for both options, so it never truncates.
    truncated = max_groups is not None and max_groups < n_groups
    n_start = 0
    if truncated:
        n_start = (max_groups + 1) // 2
        n_end = max_groups - n_start
        # Slice the tail from an explicit start position: ``names[-0:]`` would
        # return *all* names when ``n_end`` is 0 (e.g. ``max_groups == 1``).
        group_names = group_names[:n_start] + group_names[n_groups - n_end :]

    if max_rows is not None:
        # Guard the divisor so ``max_groups == 0`` or an empty groupby does not
        # raise ZeroDivisionError.
        max_rows = max(1, max_rows // max(1, len(group_names)))

    repr_html_list = []
    for group_name in group_names:
        # Keys coming from ``.groups`` are already in the form ``get_group``
        # expects (scalar for one grouping key, tuple for several).
        group = group_obj.get_group(group_name)
        if not hasattr(group, "to_html"):
            # Only a frame can be rendered, so convert anything that lacks
            # ``to_html`` (a Series group) first.
            group = group.to_frame()
        # NOTE(review): the wrapper tags were lost in the reviewed copy of this
        # patch; <div><p>...</p></div> is a reconstruction - confirm against
        # the original commit.
        repr_html_list.append(
            f"<div><p>Group Key: {escape(str(group_name))}</p></div>"
            f"\n{group.to_html(max_rows=max_rows)}"
        )

    if truncated:
        # The marker belongs right after the head groups; ``max_groups // 2``
        # would be one position too early for odd ``max_groups``.
        repr_html_list.insert(n_start, "<div><p>...</p></div>")
    return "\n".join(repr_html_list)