From 78cd15b502ff751a9de38304b795daa15b556d21 Mon Sep 17 00:00:00 2001 From: sinhrks Date: Fri, 21 Mar 2014 14:25:28 +0900 Subject: [PATCH] BUG: legend behaves inconsistently when plotting to the same axes repeatedly --- doc/source/release.rst | 1 + doc/source/visualization.rst | 6 +- pandas/tests/test_graphics.py | 115 ++++++++++++++++++---- pandas/tools/plotting.py | 173 +++++++++++++++------------------- 4 files changed, 177 insertions(+), 118 deletions(-) diff --git a/doc/source/release.rst b/doc/source/release.rst index 03b89f9077994..a632d69eef734 100644 --- a/doc/source/release.rst +++ b/doc/source/release.rst @@ -384,6 +384,7 @@ Bug Fixes group match wasn't renamed to the group name - Bug in ``DataFrame.to_csv`` where setting `index` to `False` ignored the `header` kwarg (:issue:`6186`) +- Bug in ``DataFrame.plot`` and ``Series.plot`` where the legend behaves inconsistently when plotting to the same axes repeatedly (:issue:`6678`) pandas 0.13.1 ------------- diff --git a/doc/source/visualization.rst b/doc/source/visualization.rst index f42d2b3b52f55..15d05ff046bb1 100644 --- a/doc/source/visualization.rst +++ b/doc/source/visualization.rst @@ -66,7 +66,7 @@ for controlling the look of the plot: .. ipython:: python @savefig series_plot_basic2.png - plt.figure(); ts.plot(style='k--', label='Series'); plt.legend() + plt.figure(); ts.plot(style='k--', label='Series'); On DataFrame, ``plot`` is a convenience to plot all of the columns with labels: @@ -76,7 +76,7 @@ On DataFrame, ``plot`` is a convenience to plot all of the columns with labels: df = df.cumsum() @savefig frame_plot_basic.png - plt.figure(); df.plot(); plt.legend(loc='best') + plt.figure(); df.plot(); You may set the ``legend`` argument to ``False`` to hide the legend, which is shown by default. @@ -91,7 +91,7 @@ Some other options are available, like plotting each Series on a different axis: .. 
ipython:: python @savefig frame_plot_subplots.png - df.plot(subplots=True, figsize=(6, 6)); plt.legend(loc='best') + df.plot(subplots=True, figsize=(6, 6)); You may pass ``logy`` to get a log-scale Y axis. diff --git a/pandas/tests/test_graphics.py b/pandas/tests/test_graphics.py index 5beb5a05a800d..fceec8cf00e92 100644 --- a/pandas/tests/test_graphics.py +++ b/pandas/tests/test_graphics.py @@ -490,29 +490,34 @@ def test_subplots(self): df = DataFrame(np.random.rand(10, 3), index=list(string.ascii_letters[:10])) - axes = df.plot(subplots=True, sharex=True, legend=True) + for kind in ['bar', 'barh', 'line']: + axes = df.plot(kind=kind, subplots=True, sharex=True, legend=True) - for ax in axes: - self.assertIsNotNone(ax.get_legend()) - - axes = df.plot(subplots=True, sharex=True) - for ax in axes[:-2]: - [self.assert_(not label.get_visible()) - for label in ax.get_xticklabels()] - [self.assert_(label.get_visible()) - for label in ax.get_yticklabels()] + for ax, column in zip(axes, df.columns): + self._check_legend_labels(ax, [column]) - [self.assert_(label.get_visible()) - for label in axes[-1].get_xticklabels()] - [self.assert_(label.get_visible()) - for label in axes[-1].get_yticklabels()] + axes = df.plot(kind=kind, subplots=True, sharex=True) + for ax in axes[:-2]: + [self.assert_(not label.get_visible()) + for label in ax.get_xticklabels()] + [self.assert_(label.get_visible()) + for label in ax.get_yticklabels()] - axes = df.plot(subplots=True, sharex=False) - for ax in axes: [self.assert_(label.get_visible()) - for label in ax.get_xticklabels()] + for label in axes[-1].get_xticklabels()] [self.assert_(label.get_visible()) - for label in ax.get_yticklabels()] + for label in axes[-1].get_yticklabels()] + + axes = df.plot(kind=kind, subplots=True, sharex=False) + for ax in axes: + [self.assert_(label.get_visible()) + for label in ax.get_xticklabels()] + [self.assert_(label.get_visible()) + for label in ax.get_yticklabels()] + + axes = df.plot(kind=kind, 
subplots=True, legend=False) + for ax in axes: + self.assertTrue(ax.get_legend() is None) @slow def test_bar_colors(self): @@ -873,7 +878,7 @@ def test_kde(self): _check_plot_works(df.plot, kind='kde') _check_plot_works(df.plot, kind='kde', subplots=True) ax = df.plot(kind='kde') - self.assertIsNotNone(ax.get_legend()) + self._check_legend_labels(ax, df.columns) axes = df.plot(kind='kde', logy=True, subplots=True) for ax in axes: self.assertEqual(ax.get_yscale(), 'log') @@ -1046,6 +1051,64 @@ def test_plot_int_columns(self): df = DataFrame(randn(100, 4)).cumsum() _check_plot_works(df.plot, legend=True) + def _check_legend_labels(self, ax, labels): + import pandas.core.common as com + labels = [com.pprint_thing(l) for l in labels] + self.assertTrue(ax.get_legend() is not None) + legend_labels = [t.get_text() for t in ax.get_legend().get_texts()] + self.assertEqual(labels, legend_labels) + + @slow + def test_df_legend_labels(self): + kinds = 'line', 'bar', 'barh', 'kde', 'density' + df = DataFrame(randn(3, 3), columns=['a', 'b', 'c']) + df2 = DataFrame(randn(3, 3), columns=['d', 'e', 'f']) + df3 = DataFrame(randn(3, 3), columns=['g', 'h', 'i']) + df4 = DataFrame(randn(3, 3), columns=['j', 'k', 'l']) + + for kind in kinds: + ax = df.plot(kind=kind, legend=True) + self._check_legend_labels(ax, df.columns) + + ax = df2.plot(kind=kind, legend=False, ax=ax) + self._check_legend_labels(ax, df.columns) + + ax = df3.plot(kind=kind, legend=True, ax=ax) + self._check_legend_labels(ax, df.columns + df3.columns) + + ax = df4.plot(kind=kind, legend='reverse', ax=ax) + expected = list(df.columns + df3.columns) + list(reversed(df4.columns)) + self._check_legend_labels(ax, expected) + + # Secondary Y + ax = df.plot(legend=True, secondary_y='b') + self._check_legend_labels(ax, ['a', 'b (right)', 'c']) + ax = df2.plot(legend=False, ax=ax) + self._check_legend_labels(ax, ['a', 'b (right)', 'c']) + ax = df3.plot(kind='bar', legend=True, secondary_y='h', ax=ax) + 
self._check_legend_labels(ax, ['a', 'b (right)', 'c', 'g', 'h (right)', 'i']) + + # Time Series + ind = date_range('1/1/2014', periods=3) + df = DataFrame(randn(3, 3), columns=['a', 'b', 'c'], index=ind) + df2 = DataFrame(randn(3, 3), columns=['d', 'e', 'f'], index=ind) + df3 = DataFrame(randn(3, 3), columns=['g', 'h', 'i'], index=ind) + ax = df.plot(legend=True, secondary_y='b') + self._check_legend_labels(ax, ['a', 'b (right)', 'c']) + ax = df2.plot(legend=False, ax=ax) + self._check_legend_labels(ax, ['a', 'b (right)', 'c']) + ax = df3.plot(legend=True, ax=ax) + self._check_legend_labels(ax, ['a', 'b (right)', 'c', 'g', 'h', 'i']) + + # scatter + ax = df.plot(kind='scatter', x='a', y='b', label='data1') + self._check_legend_labels(ax, ['data1']) + ax = df2.plot(kind='scatter', x='d', y='e', legend=False, + label='data2', ax=ax) + self._check_legend_labels(ax, ['data1']) + ax = df3.plot(kind='scatter', x='g', y='h', label='data3', ax=ax) + self._check_legend_labels(ax, ['data1', 'data3']) + def test_legend_name(self): multi = DataFrame(randn(4, 4), columns=[np.array(['a', 'a', 'b', 'b']), @@ -1056,6 +1119,20 @@ def test_legend_name(self): leg_title = ax.legend_.get_title() self.assertEqual(leg_title.get_text(), 'group,individual') + df = DataFrame(randn(5, 5)) + ax = df.plot(legend=True, ax=ax) + leg_title = ax.legend_.get_title() + self.assertEqual(leg_title.get_text(), 'group,individual') + + df.columns.name = 'new' + ax = df.plot(legend=False, ax=ax) + leg_title = ax.legend_.get_title() + self.assertEqual(leg_title.get_text(), 'group,individual') + + ax = df.plot(legend=True, ax=ax) + leg_title = ax.legend_.get_title() + self.assertEqual(leg_title.get_text(), 'new') + def _check_plot_fails(self, f, *args, **kwargs): with tm.assertRaises(Exception): f(*args, **kwargs) diff --git a/pandas/tools/plotting.py b/pandas/tools/plotting.py index 42135e2186468..7e67c48572f51 100644 --- a/pandas/tools/plotting.py +++ b/pandas/tools/plotting.py @@ -784,8 +784,10 @@ class 
MPLPlot(object): """ _default_rot = 0 - _pop_attributes = ['label', 'style', 'logy', 'logx', 'loglog'] - _attr_defaults = {'logy': False, 'logx': False, 'loglog': False} + _pop_attributes = ['label', 'style', 'logy', 'logx', 'loglog', + 'mark_right'] + _attr_defaults = {'logy': False, 'logx': False, 'loglog': False, + 'mark_right': True} def __init__(self, data, kind=None, by=None, subplots=False, sharex=True, sharey=False, use_index=True, @@ -823,6 +825,8 @@ def __init__(self, data, kind=None, by=None, subplots=False, sharex=True, self.grid = grid self.legend = legend + self.legend_handles = [] + self.legend_labels = [] for attr in self._pop_attributes: value = kwds.pop(attr, self._attr_defaults.get(attr, None)) @@ -919,6 +923,7 @@ def generate(self): self._setup_subplots() self._make_plot() self._add_table() + self._make_legend() self._post_plot_logic() self._adorn_subplots() @@ -1077,6 +1082,57 @@ def legend_title(self): else: return None + def _add_legend_handle(self, handle, label, index=None): + if not label is None: + if self.mark_right and index is not None: + if self.on_right(index): + label = label + ' (right)' + self.legend_handles.append(handle) + self.legend_labels.append(label) + + def _make_legend(self): + ax, leg = self._get_ax_legend(self.axes[0]) + + handles = [] + labels = [] + title = '' + + if not self.subplots: + if not leg is None: + title = leg.get_title().get_text() + handles = leg.legendHandles + labels = [x.get_text() for x in leg.get_texts()] + + if self.legend: + if self.legend == 'reverse': + self.legend_handles = reversed(self.legend_handles) + self.legend_labels = reversed(self.legend_labels) + + handles += self.legend_handles + labels += self.legend_labels + if not self.legend_title is None: + title = self.legend_title + + if len(handles) > 0: + ax.legend(handles, labels, loc='best', title=title) + + elif self.subplots and self.legend: + for ax in self.axes: + ax.legend(loc='best') + + + def _get_ax_legend(self, ax): + leg = 
ax.get_legend() + other_ax = (getattr(ax, 'right_ax', None) or + getattr(ax, 'left_ax', None)) + other_leg = None + if other_ax is not None: + other_leg = other_ax.get_legend() + if leg is None and other_leg is not None: + leg = other_leg + ax = other_ax + return ax, leg + @cache_readonly def plt(self): import matplotlib.pyplot as plt @@ -1205,12 +1261,6 @@ def _maybe_add_color(self, colors, kwds, style, i): if has_color and (style is None or re.match('[a-z]+', style) is None): kwds['color'] = colors[i % len(colors)] - def _get_marked_label(self, label, col_num): - if self.on_right(col_num): - return label + ' (right)' - else: - return label - def _parse_errorbars(self, error_dim='y', **kwds): ''' Look for error keyword arguments and return the actual errorbar data @@ -1330,22 +1380,9 @@ def _make_plot(self): else: args = (ax, ind, y, style) - plotf(*args, **kwds) - ax.grid(self.grid) + newlines = plotf(*args, **kwds) + self._add_legend_handle(newlines[0], label) - def _post_plot_logic(self): - if self.legend: - for ax in self.axes: - ax.legend(loc='best') - leg = self.axes[0].get_legend() - if leg is not None: - lines = leg.get_lines() - labels = [x.get_text() for x in leg.get_texts()] - - if self.legend == 'reverse': - lines = reversed(lines) - labels = reversed(labels) - ax.legend(lines, labels, loc='best', title=self.legend_title) class ScatterPlot(MPLPlot): def __init__(self, data, x, y, **kwargs): @@ -1364,7 +1401,15 @@ def __init__(self, data, x, y, **kwargs): def _make_plot(self): x, y, data = self.x, self.y, self.data ax = self.axes[0] - ax.scatter(data[x].values, data[y].values, **self.kwds) + + if self.legend and hasattr(self, 'label'): + label = self.label + else: + label = None + scatter = ax.scatter(data[x].values, data[y].values, label=label, + **self.kwds) + + self._add_legend_handle(scatter, label) def _post_plot_logic(self): ax = self.axes[0] @@ -1422,7 +1467,6 @@ def _post_plot_logic(self): class LinePlot(MPLPlot): def __init__(self, data, 
**kwargs): - self.mark_right = kwargs.pop('mark_right', True) MPLPlot.__init__(self, data, **kwargs) self.x_compat = plot_params['x_compat'] if 'x_compat' in self.kwds: @@ -1483,7 +1527,6 @@ def _make_plot(self): else: from pandas.core.frame import DataFrame lines = [] - labels = [] x = self._get_xticks(convert_period=True) plotf = self._get_plot_function() @@ -1519,22 +1562,16 @@ def _make_plot(self): else: args = (ax, x, y) - newline = plotf(*args, **kwds)[0] - lines.append(newline) + newlines = plotf(*args, **kwds) - if self.mark_right: - labels.append(self._get_marked_label(label, i)) - else: - labels.append(label) + self._add_legend_handle(newlines[0], label, index=i) - ax.grid(self.grid) + lines.append(newlines[0]) if self._is_datetype(): left, right = _get_xlim(lines) ax.set_xlim(left, right) - self._make_legend(lines, labels) - def _make_ts_plot(self, data, **kwargs): from pandas.tseries.plotting import tsplot from pandas.core.frame import DataFrame @@ -1543,8 +1580,6 @@ def _make_ts_plot(self, data, **kwargs): colors = self._get_colors() plotf = self._get_plot_function() - lines = [] - labels = [] def _plot(data, col_num, ax, label, style, **kwds): @@ -1556,13 +1591,7 @@ def _plot(data, col_num, ax, label, style, **kwds): newlines = tsplot(data, plotf, ax=ax, label=label, **kwds) - ax.grid(self.grid) - lines.append(newlines[0]) - - if self.mark_right: - labels.append(self._get_marked_label(label, col_num)) - else: - labels.append(label) + self._add_legend_handle(newlines[0], label, index=col_num) if isinstance(data, Series): ax = self._get_ax(0) # self.axes[0] @@ -1597,37 +1626,6 @@ def _plot(data, col_num, ax, label, style, **kwds): _plot(data[col], i, ax, label, style, **kwds) - self._make_legend(lines, labels) - - def _make_legend(self, lines, labels): - ax, leg = self._get_ax_legend(self.axes[0]) - - if not self.subplots: - if leg is not None: - ext_lines = leg.get_lines() - ext_labels = [x.get_text() for x in leg.get_texts()] - ext_lines.extend(lines) 
- ext_labels.extend(labels) - ax.legend(ext_lines, ext_labels, loc='best', - title=self.legend_title) - elif self.legend: - if self.legend == 'reverse': - lines = reversed(lines) - labels = reversed(labels) - ax.legend(lines, labels, loc='best', title=self.legend_title) - - def _get_ax_legend(self, ax): - leg = ax.get_legend() - other_ax = (getattr(ax, 'right_ax', None) or - getattr(ax, 'left_ax', None)) - other_leg = None - if other_ax is not None: - other_leg = other_ax.get_legend() - if leg is None and other_leg is not None: - leg = other_leg - ax = other_ax - return ax, leg - def _maybe_convert_index(self, data): # tsplot converts automatically, but don't want to convert index # over and over for DataFrames @@ -1679,16 +1677,12 @@ def _post_plot_logic(self): if index_name is not None: ax.set_xlabel(index_name) - if self.subplots and self.legend: - for ax in self.axes: - ax.legend(loc='best') class BarPlot(MPLPlot): _default_rot = {'bar': 90, 'barh': 0} def __init__(self, data, **kwargs): - self.mark_right = kwargs.pop('mark_right', True) self.stacked = kwargs.pop('stacked', False) self.bar_width = kwargs.pop('width', 0.5) @@ -1739,8 +1733,6 @@ def _make_plot(self): colors = self._get_colors() ncolors = len(colors) - rects = [] - labels = [] bar_f = self.bar_f @@ -1778,8 +1770,8 @@ def _make_plot(self): if self.subplots: w = self.bar_width / 2 - rect = bar_f(ax, self.ax_pos + w, y, self.bar_width, - start=start, **kwds) + rect = bar_f(ax, self.ax_pos + w, y, self.bar_width, + start=start, label=label, **kwds) ax.set_title(label) elif self.stacked: mask = y > 0 @@ -1793,19 +1785,8 @@ def _make_plot(self): w = self.bar_width / K rect = bar_f(ax, self.ax_pos + (i + 1.5) * w, y, w, start=start, label=label, **kwds) - rects.append(rect) - if self.mark_right: - labels.append(self._get_marked_label(label, i)) - else: - labels.append(label) - - if self.legend and not self.subplots: - patches = [r[0] for r in rects] - if self.legend == 'reverse': - patches = 
reversed(patches) - labels = reversed(labels) - self.axes[0].legend(patches, labels, loc='best', - title=self.legend_title) + + self._add_legend_handle(rect, label, index=i) def _post_plot_logic(self): for ax in self.axes: