diff --git a/pandas/tests/plotting/common.py b/pandas/tests/plotting/common.py index 1f94e18d8e622..d31f57426a721 100644 --- a/pandas/tests/plotting/common.py +++ b/pandas/tests/plotting/common.py @@ -626,7 +626,8 @@ def _gen_two_subplots(f, fig, **kwargs): """ Create plot on two subplots forcefully created. """ - kwargs.get("ax", fig.add_subplot(211)) + if "ax" not in kwargs: + fig.add_subplot(211) yield f(**kwargs) if f is pd.plotting.bootstrap_plot: diff --git a/pandas/tests/plotting/test_common.py b/pandas/tests/plotting/test_common.py index 2664dc8e1b090..4674fc1bb2c18 100644 --- a/pandas/tests/plotting/test_common.py +++ b/pandas/tests/plotting/test_common.py @@ -3,7 +3,11 @@ import pandas.util._test_decorators as td from pandas import DataFrame -from pandas.tests.plotting.common import TestPlotBase, _check_plot_works +from pandas.tests.plotting.common import ( + TestPlotBase, + _check_plot_works, + _gen_two_subplots, +) pytestmark = pytest.mark.slow @@ -24,3 +28,15 @@ def test__check_ticks_props(self): self._check_ticks_props(ax, yrot=0) with pytest.raises(AssertionError, match=msg): self._check_ticks_props(ax, ylabelsize=0) + + def test__gen_two_subplots_with_ax(self): + fig = self.plt.gcf() + gen = _gen_two_subplots(f=lambda **kwargs: None, fig=fig, ax="test") + # On the first yield, no subplot should be added since ax was passed + next(gen) + assert fig.get_axes() == [] + # On the second, the one axis should match fig.subplot(2, 1, 2) + next(gen) + axes = fig.get_axes() + assert len(axes) == 1 + assert axes[0].get_geometry() == (2, 1, 2)