Skip to content

Commit 5f49373

Browse files
committed
Make plotf functions stateless
1 parent 983c579 commit 5f49373

File tree

1 file changed

+43
-31
lines changed

1 file changed

+43
-31
lines changed

pandas/tools/plotting.py

Lines changed: 43 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1594,7 +1594,6 @@ def _is_ts_plot(self):
15941594
return not self.x_compat and self.use_index and self._use_dynamic_x()
15951595

15961596
def _make_plot(self):
1597-
self._initialize_prior(len(self.data))
15981597

15991598
if self._is_ts_plot():
16001599
data = self._maybe_convert_index(self.data)
@@ -1626,12 +1625,13 @@ def _make_plot(self):
16261625
left, right = _get_xlim(lines)
16271626
ax.set_xlim(left, right)
16281627

1629-
def _get_stacked_values(self, y, label):
1630-
if self.stacked:
1628+
@classmethod
1629+
def _get_stacked_values(cls, ax, y, label, stacked):
1630+
if stacked:
16311631
if (y >= 0).all():
1632-
return self._pos_prior + y
1632+
return ax._pos_prior + y
16331633
elif (y <= 0).all():
1634-
return self._neg_prior + y
1634+
return ax._neg_prior + y
16351635
else:
16361636
raise ValueError('When stacked is True, each column must be either all positive or negative.'
16371637
'{0} contains both positive and negative values'.format(label))
@@ -1640,13 +1640,15 @@ def _get_stacked_values(self, y, label):
16401640

16411641
def _get_plot_function(self):
16421642
f = MPLPlot._get_plot_function(self)
1643+
stacked = self.stacked
1644+
subplots = self.subplots
16431645
def plotf(ax, x, y, style=None, column_num=None, **kwds):
16441646
# column_num is used to get the target column from protf in line and area plots
1645-
if column_num == 0:
1646-
self._initialize_prior(len(self.data))
1647-
y_values = self._get_stacked_values(y, kwds['label'])
1647+
if not hasattr(ax, '_pos_prior') or column_num == 0:
1648+
LinePlot._initialize_prior(ax, len(y))
1649+
y_values = LinePlot._get_stacked_values(ax, y, kwds['label'], stacked)
16481650
lines = f(ax, x, y_values, style=style, **kwds)
1649-
self._update_prior(y)
1651+
LinePlot._update_prior(ax, y, stacked, subplots)
16501652
return lines
16511653
return plotf
16521654

@@ -1660,19 +1662,21 @@ def _plot(ax, x, data, style=None, **kwds):
16601662
return lines
16611663
return _plot
16621664

1663-
def _initialize_prior(self, n):
1664-
self._pos_prior = np.zeros(n)
1665-
self._neg_prior = np.zeros(n)
1665+
@classmethod
1666+
def _initialize_prior(cls, ax, n):
1667+
ax._pos_prior = np.zeros(n)
1668+
ax._neg_prior = np.zeros(n)
16661669

1667-
def _update_prior(self, y):
1668-
if self.stacked and not self.subplots:
1670+
@classmethod
1671+
def _update_prior(cls, ax, y, stacked, subplots):
1672+
if stacked and not subplots:
16691673
# tsplot resample may changedata length
1670-
if len(self._pos_prior) != len(y):
1671-
self._initialize_prior(len(y))
1674+
if len(ax._pos_prior) != len(y):
1675+
cls._initialize_prior(ax, len(y))
16721676
if (y >= 0).all():
1673-
self._pos_prior += y
1677+
ax._pos_prior += y
16741678
elif (y <= 0).all():
1675-
self._neg_prior += y
1679+
ax._neg_prior += y
16761680

16771681
def _maybe_convert_index(self, data):
16781682
# tsplot converts automatically, but don't want to convert index
@@ -1735,31 +1739,34 @@ def __init__(self, data, **kwargs):
17351739
self.kwds.setdefault('alpha', 0.5)
17361740

17371741
def _get_plot_function(self):
1742+
import matplotlib.pyplot as plt
17381743
if self.logy or self.loglog:
17391744
raise ValueError("Log-y scales are not supported in area plot")
17401745
else:
17411746
f = MPLPlot._get_plot_function(self)
1747+
stacked = self.stacked
1748+
subplots = self.subplots
17421749
def plotf(ax, x, y, style=None, column_num=None, **kwds):
1743-
if column_num == 0:
1744-
self._initialize_prior(len(self.data))
1745-
y_values = self._get_stacked_values(y, kwds['label'])
1750+
if not hasattr(ax, '_pos_prior') or column_num == 0:
1751+
LinePlot._initialize_prior(ax, len(y))
1752+
y_values = LinePlot._get_stacked_values(ax, y, kwds['label'], stacked)
17461753
lines = f(ax, x, y_values, style=style, **kwds)
17471754

17481755
# get data from the line to get coordinates for fill_between
17491756
xdata, y_values = lines[0].get_data(orig=False)
17501757

17511758
if (y >= 0).all():
1752-
start = self._pos_prior
1759+
start = ax._pos_prior
17531760
elif (y <= 0).all():
1754-
start = self._neg_prior
1761+
start = ax._neg_prior
17551762
else:
17561763
start = np.zeros(len(y))
17571764

17581765
if not 'color' in kwds:
17591766
kwds['color'] = lines[0].get_color()
17601767

1761-
self.plt.Axes.fill_between(ax, xdata, start, y_values, **kwds)
1762-
self._update_prior(y)
1768+
plt.Axes.fill_between(ax, xdata, start, y_values, **kwds)
1769+
LinePlot._update_prior(ax, y, stacked, subplots)
17631770
return lines
17641771

17651772
return plotf
@@ -1950,15 +1957,20 @@ def _args_adjust(self):
19501957
self.bottom = np.array(self.bottom)
19511958

19521959
def _get_plot_function(self):
1960+
import matplotlib.pyplot as plt
1961+
bins = self.bins
1962+
bottom = self.bottom
1963+
stacked = self.stacked
1964+
subplots = self.subplots
19531965
def plotf(ax, y, style=None, column_num=None, **kwds):
1954-
if column_num == 0:
1955-
self._initialize_prior(len(self.bins) - 1)
1966+
if not hasattr(ax, '_pos_prior') or column_num == 0:
1967+
LinePlot._initialize_prior(ax, len(self.bins) - 1)
19561968
y = y[~com.isnull(y)]
1957-
bottom = self._pos_prior + self.bottom
1969+
new_bottom = ax._pos_prior + bottom
19581970
# ignore style
1959-
n, bins, patches = self.plt.Axes.hist(ax, y, bins=self.bins,
1960-
bottom=bottom, **kwds)
1961-
self._update_prior(n)
1971+
n, new_bins, patches = plt.Axes.hist(ax, y, bins=bins,
1972+
bottom=new_bottom, **kwds)
1973+
LinePlot._update_prior(ax, n, stacked, subplots)
19621974
return patches
19631975
return plotf
19641976

0 commit comments

Comments
 (0)