diff --git a/doc/source/whatsnew/v1.4.0.rst b/doc/source/whatsnew/v1.4.0.rst index d114f26788f00..26dac44f0d15f 100644 --- a/doc/source/whatsnew/v1.4.0.rst +++ b/doc/source/whatsnew/v1.4.0.rst @@ -29,6 +29,7 @@ enhancement2 Other enhancements ^^^^^^^^^^^^^^^^^^ +- Add support for assigning values to ``by`` argument in :meth:`DataFrame.plot.hist` and :meth:`DataFrame.plot.box` (:issue:`15079`) - :meth:`Series.sample`, :meth:`DataFrame.sample`, and :meth:`.GroupBy.sample` now accept a ``np.random.Generator`` as input to ``random_state``. A generator will be more performant, especially with ``replace=False`` (:issue:`38100`) - Additional options added to :meth:`.Styler.bar` to control alignment and display (:issue:`26070`) - :meth:`Series.ewm`, :meth:`DataFrame.ewm`, now support a ``method`` argument with a ``'table'`` option that performs the windowing operation over an entire :class:`DataFrame`. See :ref:`Window Overview ` for performance and functional benefits (:issue:`42273`) diff --git a/pandas/plotting/_core.py b/pandas/plotting/_core.py index 764868c247472..990ccbc2a015b 100644 --- a/pandas/plotting/_core.py +++ b/pandas/plotting/_core.py @@ -1237,6 +1237,11 @@ def box(self, by=None, **kwargs): ---------- by : str or sequence Column in the DataFrame to group by. + + .. versionchanged:: 1.4.0 + + Previously, `by` is silently ignore and makes no groupings + **kwargs Additional keywords are documented in :meth:`DataFrame.plot`. @@ -1278,6 +1283,11 @@ def hist(self, by=None, bins=10, **kwargs): ---------- by : str or sequence, optional Column in the DataFrame to group by. + + .. versionchanged:: 1.4.0 + + Previously, `by` is silently ignore and makes no groupings + bins : int, default 10 Number of histogram bins to be used. **kwargs @@ -1309,6 +1319,16 @@ def hist(self, by=None, bins=10, **kwargs): ... columns = ['one']) >>> df['two'] = df['one'] + np.random.randint(1, 7, 6000) >>> ax = df.plot.hist(bins=12, alpha=0.5) + + A grouped histogram can be generated by providing the parameter `by` (which + can be a column name, or a list of column names): + + .. plot:: + :context: close-figs + + >>> age_list = [8, 10, 12, 14, 72, 74, 76, 78, 20, 25, 30, 35, 60, 85] + >>> df = pd.DataFrame({"gender": list("MMMMMMMMFFFFFF"), "age": age_list}) + >>> ax = df.plot.hist(column=["age"], by="gender", figsize=(10, 8)) """ return self(kind="hist", by=by, bins=bins, **kwargs) diff --git a/pandas/plotting/_matplotlib/boxplot.py b/pandas/plotting/_matplotlib/boxplot.py index 21f30c1311e17..8b4cf158ac827 100644 --- a/pandas/plotting/_matplotlib/boxplot.py +++ b/pandas/plotting/_matplotlib/boxplot.py @@ -18,6 +18,7 @@ LinePlot, MPLPlot, ) +from pandas.plotting._matplotlib.groupby import create_iter_data_given_by from pandas.plotting._matplotlib.style import get_standard_colors from pandas.plotting._matplotlib.tools import ( create_subplots, @@ -135,18 +136,37 @@ def _make_plot(self): if self.subplots: self._return_obj = pd.Series(dtype=object) - for i, (label, y) in enumerate(self._iter_data()): + # Re-create iterated data if `by` is assigned by users + data = ( + create_iter_data_given_by(self.data, self._kind) + if self.by is not None + else self.data + ) + + for i, (label, y) in enumerate(self._iter_data(data=data)): ax = self._get_ax(i) kwds = self.kwds.copy() + # When by is applied, show title for subplots to know which group it is + # just like df.boxplot, and need to apply T on y to provide right input + if self.by is not None: + y = y.T + ax.set_title(pprint_thing(label)) + + # When `by` is assigned, the ticklabels will become unique grouped + # values, instead of label which is used as subtitle in this case. + ticklabels = [ + pprint_thing(col) for col in self.data.columns.levels[0] + ] + else: + ticklabels = [pprint_thing(label)] + ret, bp = self._plot( ax, y, column_num=i, return_type=self.return_type, **kwds ) self.maybe_color_bp(bp) self._return_obj[label] = ret - - label = [pprint_thing(label)] - self._set_ticklabels(ax, label) + self._set_ticklabels(ax, ticklabels) else: y = self.data.values.T ax = self._get_ax(0) diff --git a/pandas/plotting/_matplotlib/core.py b/pandas/plotting/_matplotlib/core.py index 690e39de2ddb2..ff76bd771d1c0 100644 --- a/pandas/plotting/_matplotlib/core.py +++ b/pandas/plotting/_matplotlib/core.py @@ -9,6 +9,7 @@ from matplotlib.artist import Artist import numpy as np +from pandas._typing import IndexLabel from pandas.errors import AbstractMethodError from pandas.util._decorators import cache_readonly @@ -38,10 +39,12 @@ ) import pandas.core.common as com +from pandas.core.frame import DataFrame from pandas.io.formats.printing import pprint_thing from pandas.plotting._matplotlib.compat import mpl_ge_3_0_0 from pandas.plotting._matplotlib.converter import register_pandas_matplotlib_converters +from pandas.plotting._matplotlib.groupby import reconstruct_data_with_by from pandas.plotting._matplotlib.style import get_standard_colors from pandas.plotting._matplotlib.timeseries import ( decorate_axes, @@ -99,7 +102,7 @@ def __init__( self, data, kind=None, - by=None, + by: IndexLabel | None = None, subplots=False, sharex=None, sharey=False, @@ -124,13 +127,42 @@ def __init__( table=False, layout=None, include_bool=False, + column: IndexLabel | None = None, **kwds, ): import matplotlib.pyplot as plt self.data = data - self.by = by + + # if users assign an empty list or tuple, raise `ValueError` + # similar to current `df.box` and `df.hist` APIs. + if by in ([], ()): + raise ValueError("No group keys passed!") + self.by = com.maybe_make_list(by) + + # Assign the rest of columns into self.columns if by is explicitly defined + # while column is not, only need `columns` in hist/box plot when it's DF + # TODO: Might deprecate `column` argument in future PR (#28373) + if isinstance(data, DataFrame): + if column: + self.columns = com.maybe_make_list(column) + else: + if self.by is None: + self.columns = [ + col for col in data.columns if is_numeric_dtype(data[col]) + ] + else: + self.columns = [ + col + for col in data.columns + if col not in self.by and is_numeric_dtype(data[col]) + ] + + # For `hist` plot, need to get grouped original data before `self.data` is + # updated later + if self.by is not None and self._kind == "hist": + self._grouped = data.groupby(self.by) self.kind = kind @@ -139,7 +171,9 @@ def __init__( self.subplots = subplots if sharex is None: - if ax is None: + + # if by is defined, subplots are used and sharex should be False + if ax is None and by is None: self.sharex = True else: # if we get an axis, the users should do the visibility @@ -273,8 +307,15 @@ def _iter_data(self, data=None, keep_index=False, fillna=None): @property def nseries(self) -> int: + + # When `by` is explicitly assigned, grouped data size will be defined, and + # this will determine number of subplots to have, aka `self.nseries` if self.data.ndim == 1: return 1 + elif self.by is not None and self._kind == "hist": + return len(self._grouped) + elif self.by is not None and self._kind == "box": + return len(self.columns) else: return self.data.shape[1] @@ -420,6 +461,14 @@ def _compute_plot_data(self): if label is None and data.name is None: label = "None" data = data.to_frame(name=label) + elif self._kind in ("hist", "box"): + cols = self.columns if self.by is None else self.columns + self.by + data = data.loc[:, cols] + + # GH15079 reconstruct data if by is defined + if self.by is not None: + self.subplots = True + data = reconstruct_data_with_by(self.data, by=self.by, cols=self.columns) # GH16953, _convert is needed as fallback, for ``Series`` # with ``dtype == object`` diff --git a/pandas/plotting/_matplotlib/groupby.py b/pandas/plotting/_matplotlib/groupby.py new file mode 100644 index 0000000000000..37cc3186fe097 --- /dev/null +++ b/pandas/plotting/_matplotlib/groupby.py @@ -0,0 +1,127 @@ +from __future__ import annotations + +import numpy as np + +from pandas._typing import ( + Dict, + IndexLabel, +) + +from pandas.core.dtypes.missing import remove_na_arraylike + +from pandas import ( + DataFrame, + MultiIndex, + Series, + concat, +) + + +def create_iter_data_given_by( + data: DataFrame, kind: str = "hist" +) -> Dict[str, DataFrame | Series]: + """ + Create data for iteration given `by` is assigned or not, and it is only + used in both hist and boxplot. + + If `by` is assigned, return a dictionary of DataFrames in which the key of + dictionary is the values in groups. + If `by` is not assigned, return input as is, and this preserves current + status of iter_data. + + Parameters + ---------- + data : reformatted grouped data from `_compute_plot_data` method. + kind : str, plot kind. This function is only used for `hist` and `box` plots. + + Returns + ------- + iter_data : DataFrame or Dictionary of DataFrames + + Examples + -------- + If `by` is assigned: + + >>> import numpy as np + >>> tuples = [('h1', 'a'), ('h1', 'b'), ('h2', 'a'), ('h2', 'b')] + >>> mi = MultiIndex.from_tuples(tuples) + >>> value = [[1, 3, np.nan, np.nan], + ... [3, 4, np.nan, np.nan], [np.nan, np.nan, 5, 6]] + >>> data = DataFrame(value, columns=mi) + >>> create_iter_data_given_by(data) + {'h1': DataFrame({'a': [1, 3, np.nan], 'b': [3, 4, np.nan]}), + 'h2': DataFrame({'a': [np.nan, np.nan, 5], 'b': [np.nan, np.nan, 6]})} + """ + + # For `hist` plot, before transformation, the values in level 0 are values + # in groups and subplot titles, and later used for column subselection and + # iteration; For `box` plot, values in level 1 are column names to show, + # and are used for iteration and as subplots titles. + if kind == "hist": + level = 0 + else: + level = 1 + + # Select sub-columns based on the value of level of MI, and if `by` is + # assigned, data must be a MI DataFrame + assert isinstance(data.columns, MultiIndex) + return { + col: data.loc[:, data.columns.get_level_values(level) == col] + for col in data.columns.levels[level] + } + + +def reconstruct_data_with_by( + data: DataFrame, by: IndexLabel, cols: IndexLabel +) -> DataFrame: + """ + Internal function to group data, and reassign multiindex column names onto the + result in order to let grouped data be used in _compute_plot_data method. + + Parameters + ---------- + data : Original DataFrame to plot + by : grouped `by` parameter selected by users + cols : columns of data set (excluding columns used in `by`) + + Returns + ------- + Output is the reconstructed DataFrame with MultiIndex columns. The first level + of MI is unique values of groups, and second level of MI is the columns + selected by users. + + Examples + -------- + >>> d = {'h': ['h1', 'h1', 'h2'], 'a': [1, 3, 5], 'b': [3, 4, 6]} + >>> df = DataFrame(d) + >>> reconstruct_data_with_by(df, by='h', cols=['a', 'b']) + h1 h2 + a b a b + 0 1 3 NaN NaN + 1 3 4 NaN NaN + 2 NaN NaN 5 6 + """ + grouped = data.groupby(by) + + data_list = [] + for key, group in grouped: + columns = MultiIndex.from_product([[key], cols]) + sub_group = group[cols] + sub_group.columns = columns + data_list.append(sub_group) + + data = concat(data_list, axis=1) + return data + + +def reformat_hist_y_given_by( + y: Series | np.ndarray, by: IndexLabel | None +) -> Series | np.ndarray: + """Internal function to reformat y given `by` is applied or not for hist plot. + + If by is None, input y is 1-d with NaN removed; and if by is not None, groupby + will take place and input y is multi-dimensional array. + """ + if by is not None and len(y.shape) > 1: + return np.array([remove_na_arraylike(col) for col in y.T]).T + return remove_na_arraylike(y) diff --git a/pandas/plotting/_matplotlib/hist.py b/pandas/plotting/_matplotlib/hist.py index a02d9a2b9dc8d..08cffbf475db0 100644 --- a/pandas/plotting/_matplotlib/hist.py +++ b/pandas/plotting/_matplotlib/hist.py @@ -17,11 +17,17 @@ remove_na_arraylike, ) +from pandas.core.frame import DataFrame + from pandas.io.formats.printing import pprint_thing from pandas.plotting._matplotlib.core import ( LinePlot, MPLPlot, ) +from pandas.plotting._matplotlib.groupby import ( + create_iter_data_given_by, + reformat_hist_y_given_by, +) from pandas.plotting._matplotlib.tools import ( create_subplots, flatten_axes, @@ -43,19 +49,30 @@ def __init__(self, data, bins=10, bottom=0, **kwargs): MPLPlot.__init__(self, data, **kwargs) def _args_adjust(self): - if is_integer(self.bins): - # create common bin edge - values = self.data._convert(datetime=True)._get_numeric_data() - values = np.ravel(values) - values = values[~isna(values)] - _, self.bins = np.histogram( - values, bins=self.bins, range=self.kwds.get("range", None) - ) + # calculate bin number separately in different subplots + # where subplots are created based on by argument + if is_integer(self.bins): + if self.by is not None: + grouped = self.data.groupby(self.by)[self.columns] + self.bins = [self._calculate_bins(group) for key, group in grouped] + else: + self.bins = self._calculate_bins(self.data) if is_list_like(self.bottom): self.bottom = np.array(self.bottom) + def _calculate_bins(self, data: DataFrame) -> np.ndarray: + """Calculate bins given data""" + values = data._convert(datetime=True)._get_numeric_data() + values = np.ravel(values) + values = values[~isna(values)] + + hist, bins = np.histogram( + values, bins=self.bins, range=self.kwds.get("range", None) + ) + return bins + @classmethod def _plot( cls, @@ -70,7 +87,6 @@ def _plot( ): if column_num == 0: cls._initialize_stacker(ax, stacking_id, len(bins) - 1) - y = y[~isna(y)] base = np.zeros(len(bins) - 1) bottom = bottom + cls._get_stacked_values(ax, stacking_id, base, kwds["label"]) @@ -83,7 +99,14 @@ def _make_plot(self): colors = self._get_colors() stacking_id = self._get_stacking_id() - for i, (label, y) in enumerate(self._iter_data()): + # Re-create iterated data if `by` is assigned by users + data = ( + create_iter_data_given_by(self.data, self._kind) + if self.by is not None + else self.data + ) + + for i, (label, y) in enumerate(self._iter_data(data=data)): ax = self._get_ax(i) kwds = self.kwds.copy() @@ -98,6 +121,15 @@ def _make_plot(self): kwds = self._make_plot_keywords(kwds, y) + # the bins is multi-dimension array now and each plot need only 1-d and + # when by is applied, label should be columns that are grouped + if self.by is not None: + kwds["bins"] = kwds["bins"][i] + kwds["label"] = self.columns + kwds.pop("color") + + y = reformat_hist_y_given_by(y, self.by) + # We allow weights to be a multi-dimensional array, e.g. a (10, 2) array, # and each sub-array (10,) will be called in each iteration. If users only # provide 1D array, we assume the same weights is used for all iterations @@ -106,6 +138,11 @@ def _make_plot(self): kwds["weights"] = weights[:, i] artists = self._plot(ax, y, column_num=i, stacking_id=stacking_id, **kwds) + + # when by is applied, show title for subplots to know which group it is + if self.by is not None: + ax.set_title(pprint_thing(label)) + self._append_legend_handles_labels(artists[0], label) def _make_plot_keywords(self, kwds, y): diff --git a/pandas/tests/plotting/frame/test_hist_box_by.py b/pandas/tests/plotting/frame/test_hist_box_by.py new file mode 100644 index 0000000000000..ba6d232733762 --- /dev/null +++ b/pandas/tests/plotting/frame/test_hist_box_by.py @@ -0,0 +1,389 @@ +import re + +import numpy as np +import pytest + +import pandas.util._test_decorators as td + +from pandas import DataFrame +import pandas._testing as tm +from pandas.tests.plotting.common import ( + TestPlotBase, + _check_plot_works, +) + + +def _create_hist_box_with_by_df(): + np.random.seed(0) + df = DataFrame(np.random.randn(30, 2), columns=["A", "B"]) + df["C"] = np.random.choice(["a", "b", "c"], 30) + df["D"] = np.random.choice(["a", "b", "c"], 30) + return df + + +@td.skip_if_no_mpl +class TestHistWithBy(TestPlotBase): + def setup_method(self, method): + TestPlotBase.setup_method(self, method) + import matplotlib as mpl + + mpl.rcdefaults() + self.hist_df = _create_hist_box_with_by_df() + + @pytest.mark.parametrize( + "by, column, titles, legends", + [ + ("C", "A", ["a", "b", "c"], [["A"]] * 3), + ("C", ["A", "B"], ["a", "b", "c"], [["A", "B"]] * 3), + ("C", None, ["a", "b", "c"], [["A", "B"]] * 3), + ( + ["C", "D"], + "A", + [ + "(a, a)", + "(a, b)", + "(a, c)", + "(b, a)", + "(b, b)", + "(b, c)", + "(c, a)", + "(c, b)", + "(c, c)", + ], + [["A"]] * 9, + ), + ( + ["C", "D"], + ["A", "B"], + [ + "(a, a)", + "(a, b)", + "(a, c)", + "(b, a)", + "(b, b)", + "(b, c)", + "(c, a)", + "(c, b)", + "(c, c)", + ], + [["A", "B"]] * 9, + ), + ( + ["C", "D"], + None, + [ + "(a, a)", + "(a, b)", + "(a, c)", + "(b, a)", + "(b, b)", + "(b, c)", + "(c, a)", + "(c, b)", + "(c, c)", + ], + [["A", "B"]] * 9, + ), + ], + ) + def test_hist_plot_by_argument(self, by, column, titles, legends): + # GH 15079 + axes = _check_plot_works(self.hist_df.plot.hist, column=column, by=by) + result_titles = [ax.get_title() for ax in axes] + result_legends = [ + [legend.get_text() for legend in ax.get_legend().texts] for ax in axes + ] + + assert result_legends == legends + assert result_titles == titles + + @pytest.mark.parametrize( + "by, column, titles, legends", + [ + (0, "A", ["a", "b", "c"], [["A"]] * 3), + (0, None, ["a", "b", "c"], [["A", "B"]] * 3), + ( + [0, "D"], + "A", + [ + "(a, a)", + "(a, b)", + "(a, c)", + "(b, a)", + "(b, b)", + "(b, c)", + "(c, a)", + "(c, b)", + "(c, c)", + ], + [["A"]] * 9, + ), + ], + ) + def test_hist_plot_by_0(self, by, column, titles, legends): + # GH 15079 + df = self.hist_df.copy() + df = df.rename(columns={"C": 0}) + + axes = _check_plot_works(df.plot.hist, column=column, by=by) + result_titles = [ax.get_title() for ax in axes] + result_legends = [ + [legend.get_text() for legend in ax.get_legend().texts] for ax in axes + ] + + assert result_legends == legends + assert result_titles == titles + + @pytest.mark.parametrize( + "by, column", + [ + ([], ["A"]), + ([], ["A", "B"]), + ((), None), + ((), ["A", "B"]), + ], + ) + def test_hist_plot_empty_list_string_tuple_by(self, by, column): + # GH 15079 + msg = "No group keys passed" + with pytest.raises(ValueError, match=msg): + _check_plot_works(self.hist_df.plot.hist, column=column, by=by) + + @pytest.mark.slow + @pytest.mark.parametrize( + "by, column, layout, axes_num", + [ + (["C"], "A", (2, 2), 3), + ("C", "A", (2, 2), 3), + (["C"], ["A"], (1, 3), 3), + ("C", None, (3, 1), 3), + ("C", ["A", "B"], (3, 1), 3), + (["C", "D"], "A", (9, 1), 9), + (["C", "D"], "A", (3, 3), 9), + (["C", "D"], ["A"], (5, 2), 9), + (["C", "D"], ["A", "B"], (9, 1), 9), + (["C", "D"], None, (9, 1), 9), + (["C", "D"], ["A", "B"], (5, 2), 9), + ], + ) + def test_hist_plot_layout_with_by(self, by, column, layout, axes_num): + # GH 15079 + # _check_plot_works adds an ax so catch warning. see GH #13188 + with tm.assert_produces_warning(UserWarning): + axes = _check_plot_works( + self.hist_df.plot.hist, column=column, by=by, layout=layout + ) + self._check_axes_shape(axes, axes_num=axes_num, layout=layout) + + @pytest.mark.parametrize( + "msg, by, layout", + [ + ("larger than required size", ["C", "D"], (1, 1)), + (re.escape("Layout must be a tuple of (rows, columns)"), "C", (1,)), + ("At least one dimension of layout must be positive", "C", (-1, -1)), + ], + ) + def test_hist_plot_invalid_layout_with_by_raises(self, msg, by, layout): + # GH 15079, test if error is raised when invalid layout is given + + with pytest.raises(ValueError, match=msg): + self.hist_df.plot.hist(column=["A", "B"], by=by, layout=layout) + + @pytest.mark.slow + def test_axis_share_x_with_by(self): + # GH 15079 + ax1, ax2, ax3 = self.hist_df.plot.hist(column="A", by="C", sharex=True) + + # share x + assert ax1._shared_x_axes.joined(ax1, ax2) + assert ax2._shared_x_axes.joined(ax1, ax2) + assert ax3._shared_x_axes.joined(ax1, ax3) + assert ax3._shared_x_axes.joined(ax2, ax3) + + # don't share y + assert not ax1._shared_y_axes.joined(ax1, ax2) + assert not ax2._shared_y_axes.joined(ax1, ax2) + assert not ax3._shared_y_axes.joined(ax1, ax3) + assert not ax3._shared_y_axes.joined(ax2, ax3) + + @pytest.mark.slow + def test_axis_share_y_with_by(self): + # GH 15079 + ax1, ax2, ax3 = self.hist_df.plot.hist(column="A", by="C", sharey=True) + + # share y + assert ax1._shared_y_axes.joined(ax1, ax2) + assert ax2._shared_y_axes.joined(ax1, ax2) + assert ax3._shared_y_axes.joined(ax1, ax3) + assert ax3._shared_y_axes.joined(ax2, ax3) + + # don't share x + assert not ax1._shared_x_axes.joined(ax1, ax2) + assert not ax2._shared_x_axes.joined(ax1, ax2) + assert not ax3._shared_x_axes.joined(ax1, ax3) + assert not ax3._shared_x_axes.joined(ax2, ax3) + + @pytest.mark.parametrize("figsize", [(12, 8), (20, 10)]) + def test_figure_shape_hist_with_by(self, figsize): + # GH 15079 + axes = self.hist_df.plot.hist(column="A", by="C", figsize=figsize) + self._check_axes_shape(axes, axes_num=3, figsize=figsize) + + +@td.skip_if_no_mpl +class TestBoxWithBy(TestPlotBase): + def setup_method(self, method): + TestPlotBase.setup_method(self, method) + import matplotlib as mpl + + mpl.rcdefaults() + self.box_df = _create_hist_box_with_by_df() + + @pytest.mark.parametrize( + "by, column, titles, xticklabels", + [ + ("C", "A", ["A"], [["a", "b", "c"]]), + ( + ["C", "D"], + "A", + ["A"], + [ + [ + "(a, a)", + "(a, b)", + "(a, c)", + "(b, a)", + "(b, b)", + "(b, c)", + "(c, a)", + "(c, b)", + "(c, c)", + ] + ], + ), + ("C", ["A", "B"], ["A", "B"], [["a", "b", "c"]] * 2), + ( + ["C", "D"], + ["A", "B"], + ["A", "B"], + [ + [ + "(a, a)", + "(a, b)", + "(a, c)", + "(b, a)", + "(b, b)", + "(b, c)", + "(c, a)", + "(c, b)", + "(c, c)", + ] + ] + * 2, + ), + (["C"], None, ["A", "B"], [["a", "b", "c"]] * 2), + ], + ) + def test_box_plot_by_argument(self, by, column, titles, xticklabels): + # GH 15079 + axes = _check_plot_works(self.box_df.plot.box, column=column, by=by) + result_titles = [ax.get_title() for ax in axes] + result_xticklabels = [ + [label.get_text() for label in ax.get_xticklabels()] for ax in axes + ] + + assert result_xticklabels == xticklabels + assert result_titles == titles + + @pytest.mark.parametrize( + "by, column, titles, xticklabels", + [ + (0, "A", ["A"], [["a", "b", "c"]]), + ( + [0, "D"], + "A", + ["A"], + [ + [ + "(a, a)", + "(a, b)", + "(a, c)", + "(b, a)", + "(b, b)", + "(b, c)", + "(c, a)", + "(c, b)", + "(c, c)", + ] + ], + ), + (0, None, ["A", "B"], [["a", "b", "c"]] * 2), + ], + ) + def test_box_plot_by_0(self, by, column, titles, xticklabels): + # GH 15079 + df = self.box_df.copy() + df = df.rename(columns={"C": 0}) + + axes = _check_plot_works(df.plot.box, column=column, by=by) + result_titles = [ax.get_title() for ax in axes] + result_xticklabels = [ + [label.get_text() for label in ax.get_xticklabels()] for ax in axes + ] + + assert result_xticklabels == xticklabels + assert result_titles == titles + + @pytest.mark.parametrize( + "by, column", + [ + ([], ["A"]), + ((), "A"), + ([], None), + ((), ["A", "B"]), + ], + ) + def test_box_plot_with_none_empty_list_by(self, by, column): + # GH 15079 + msg = "No group keys passed" + with pytest.raises(ValueError, match=msg): + _check_plot_works(self.box_df.plot.box, column=column, by=by) + + @pytest.mark.slow + @pytest.mark.parametrize( + "by, column, layout, axes_num", + [ + (["C"], "A", (1, 1), 1), + ("C", "A", (1, 1), 1), + ("C", None, (2, 1), 2), + ("C", ["A", "B"], (1, 2), 2), + (["C", "D"], "A", (1, 1), 1), + (["C", "D"], None, (1, 2), 2), + ], + ) + def test_box_plot_layout_with_by(self, by, column, layout, axes_num): + # GH 15079 + axes = _check_plot_works( + self.box_df.plot.box, column=column, by=by, layout=layout + ) + self._check_axes_shape(axes, axes_num=axes_num, layout=layout) + + @pytest.mark.parametrize( + "msg, by, layout", + [ + ("larger than required size", ["C", "D"], (1, 1)), + (re.escape("Layout must be a tuple of (rows, columns)"), "C", (1,)), + ("At least one dimension of layout must be positive", "C", (-1, -1)), + ], + ) + def test_box_plot_invalid_layout_with_by_raises(self, msg, by, layout): + # GH 15079, test if error is raised when invalid layout is given + + with pytest.raises(ValueError, match=msg): + self.box_df.plot.box(column=["A", "B"], by=by, layout=layout) + + @pytest.mark.parametrize("figsize", [(12, 8), (20, 10)]) + def test_figure_shape_hist_with_by(self, figsize): + # GH 15079 + axes = self.box_df.plot.box(column="A", by="C", figsize=figsize) + self._check_axes_shape(axes, axes_num=1, figsize=figsize)