Skip to content

Commit b5c5442

Browse files
committed
Merge pull request #4024 from danbirken/secondary_y_for_bar_plot
BUG: Make secondary_y work properly for bar plots GH3598
2 parents bc038e7 + 77ae62f commit b5c5442

File tree

3 files changed

+49
-17
lines changed

3 files changed

+49
-17
lines changed

doc/source/release.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,7 @@ pandas 0.11.1
232232
is a ``list`` or ``tuple``.
233233
- Fixed bug where a time-series was being selected in preference to an actual column name
234234
in a frame (:issue:`3594`)
235+
- Make secondary_y work properly for bar plots (:issue:`3598`)
235236
- Fix modulo and integer division on Series,DataFrames to act similary to ``float`` dtypes to return
236237
``np.nan`` or ``np.inf`` as appropriate (:issue:`3590`)
237238
- Fix incorrect dtype on groupby with ``as_index=False`` (:issue:`3610`)

pandas/tools/plotting.py

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1084,6 +1084,11 @@ def _maybe_add_color(self, colors, kwds, style, i):
10841084
if has_color and (style is None or re.match('[a-z]+', style) is None):
10851085
kwds['color'] = colors[i % len(colors)]
10861086

1087+
def _get_marked_label(self, label, col_num):
1088+
if self.on_right(col_num):
1089+
return label + ' (right)'
1090+
else:
1091+
return label
10871092

10881093
class KdePlot(MPLPlot):
10891094
def __init__(self, data, **kwargs):
@@ -1214,10 +1219,12 @@ def _make_plot(self):
12141219

12151220
newline = plotf(*args, **kwds)[0]
12161221
lines.append(newline)
1217-
leg_label = label
1218-
if self.mark_right and self.on_right(i):
1219-
leg_label += ' (right)'
1220-
labels.append(leg_label)
1222+
1223+
if self.mark_right:
1224+
labels.append(self._get_marked_label(label, i))
1225+
else:
1226+
labels.append(label)
1227+
12211228
ax.grid(self.grid)
12221229

12231230
if self._is_datetype():
@@ -1235,18 +1242,16 @@ def _make_ts_plot(self, data, **kwargs):
12351242
lines = []
12361243
labels = []
12371244

1238-
def to_leg_label(label, i):
1239-
if self.mark_right and self.on_right(i):
1240-
return label + ' (right)'
1241-
return label
1242-
12431245
def _plot(data, col_num, ax, label, style, **kwds):
12441246
newlines = tsplot(data, plotf, ax=ax, label=label,
12451247
style=style, **kwds)
12461248
ax.grid(self.grid)
12471249
lines.append(newlines[0])
1248-
leg_label = to_leg_label(label, col_num)
1249-
labels.append(leg_label)
1250+
1251+
if self.mark_right:
1252+
labels.append(self._get_marked_label(label, col_num))
1253+
else:
1254+
labels.append(label)
12501255

12511256
if isinstance(data, Series):
12521257
ax = self._get_ax(0) # self.axes[0]
@@ -1356,6 +1361,7 @@ class BarPlot(MPLPlot):
13561361
_default_rot = {'bar': 90, 'barh': 0}
13571362

13581363
def __init__(self, data, **kwargs):
1364+
self.mark_right = kwargs.pop('mark_right', True)
13591365
self.stacked = kwargs.pop('stacked', False)
13601366
self.ax_pos = np.arange(len(data)) + 0.25
13611367
if self.stacked:
@@ -1398,15 +1404,14 @@ def _make_plot(self):
13981404
rects = []
13991405
labels = []
14001406

1401-
ax = self._get_ax(0) # self.axes[0]
1402-
14031407
bar_f = self.bar_f
14041408

14051409
pos_prior = neg_prior = np.zeros(len(self.data))
14061410

14071411
K = self.nseries
14081412

14091413
for i, (label, y) in enumerate(self._iter_data()):
1414+
ax = self._get_ax(i)
14101415
label = com.pprint_thing(label)
14111416
kwds = self.kwds.copy()
14121417
kwds['color'] = colors[i % len(colors)]
@@ -1419,8 +1424,6 @@ def _make_plot(self):
14191424
start = 0 if mpl.__version__ == "1.2.1" else None
14201425

14211426
if self.subplots:
1422-
ax = self._get_ax(i) # self.axes[i]
1423-
14241427
rect = bar_f(ax, self.ax_pos, y, self.bar_width,
14251428
start = start,
14261429
**kwds)
@@ -1437,7 +1440,10 @@ def _make_plot(self):
14371440
start = start,
14381441
label=label, **kwds)
14391442
rects.append(rect)
1440-
labels.append(label)
1443+
if self.mark_right:
1444+
labels.append(self._get_marked_label(label, i))
1445+
else:
1446+
labels.append(label)
14411447

14421448
if self.legend and not self.subplots:
14431449
patches = [r[0] for r in rects]
@@ -1537,7 +1543,10 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True,
15371543
Rotation for ticks
15381544
secondary_y : boolean or sequence, default False
15391545
Whether to plot on the secondary y-axis
1540-
If dict then can select which columns to plot on secondary y-axis
1546+
If a list/tuple, which columns to plot on secondary y-axis
1547+
mark_right: boolean, default True
1548+
When using a secondary_y axis, should the legend label the axis of
1549+
the various columns automatically
15411550
kwds : keywords
15421551
Options to pass to matplotlib plotting method
15431552

pandas/tseries/tests/test_plotting.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -615,6 +615,16 @@ def test_secondary_frame(self):
615615
self.assert_(axes[1].get_yaxis().get_ticks_position() == 'default')
616616
self.assert_(axes[2].get_yaxis().get_ticks_position() == 'right')
617617

618+
@slow
619+
def test_secondary_bar_frame(self):
620+
import matplotlib.pyplot as plt
621+
plt.close('all')
622+
df = DataFrame(np.random.randn(5, 3), columns=['a', 'b', 'c'])
623+
axes = df.plot(kind='bar', secondary_y=['a', 'c'], subplots=True)
624+
self.assert_(axes[0].get_yaxis().get_ticks_position() == 'right')
625+
self.assert_(axes[1].get_yaxis().get_ticks_position() == 'default')
626+
self.assert_(axes[2].get_yaxis().get_ticks_position() == 'right')
627+
618628
@slow
619629
def test_mixed_freq_regular_first(self):
620630
import matplotlib.pyplot as plt
@@ -864,6 +874,18 @@ def test_secondary_legend(self):
864874
self.assert_(leg.get_texts()[2].get_text() == 'C')
865875
self.assert_(leg.get_texts()[3].get_text() == 'D')
866876

877+
plt.clf()
878+
ax = df.plot(kind='bar', secondary_y=['A'])
879+
leg = ax.get_legend()
880+
self.assert_(leg.get_texts()[0].get_text() == 'A (right)')
881+
self.assert_(leg.get_texts()[1].get_text() == 'B')
882+
883+
plt.clf()
884+
ax = df.plot(kind='bar', secondary_y=['A'], mark_right=False)
885+
leg = ax.get_legend()
886+
self.assert_(leg.get_texts()[0].get_text() == 'A')
887+
self.assert_(leg.get_texts()[1].get_text() == 'B')
888+
867889
plt.clf()
868890
ax = fig.add_subplot(211)
869891
df = tm.makeTimeDataFrame()

0 commit comments

Comments
 (0)