From bde3ff62e94216e4f116322e99c9f78367ef8956 Mon Sep 17 00:00:00 2001 From: zach powers Date: Thu, 26 Sep 2013 15:06:55 -0400 Subject: [PATCH 1/4] add ScatterPlot klass --- pandas/tools/plotting.py | 30 +++++++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/pandas/tools/plotting.py b/pandas/tools/plotting.py index ce75e755a313f..d5b04b10d7f99 100644 --- a/pandas/tools/plotting.py +++ b/pandas/tools/plotting.py @@ -1190,7 +1190,33 @@ def _post_plot_logic(self): for ax in self.axes: ax.legend(loc='best') - +class ScatterPlot(MPLPlot): + def __init__(self, data, **kwargs): + MPLPlot.__init__(self, data, **kwargs) + #kwargs = self.kwargs + #print kwargs + ## check ot see that x and y are passed as keywords + if not ('x' and'y') in kwargs: + msg ='Scatterplot requires and X and Y column' + raise Exception(msg) + + def _make_plot(self): + plotf = self._get_plot_function() + colors = self._get_colors() + + for i, (label, y) in enumerate(self._iter_data()): + ax = self._get_ax(i) + #kwds = self.kwds.copy() + x, y = self.kwds['x'], self.kwds['y'] + #print x, y + ax = ax.scatter(x, y) + style = self._get_style(i, label) + + def _post_plot_logic(self): + if self.subplots and self.legend: + for ax in self.axes: + ax.legend(loc='best') + class LinePlot(MPLPlot): def __init__(self, data, **kwargs): @@ -1621,6 +1647,8 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True, klass = BarPlot elif kind == 'kde': klass = KdePlot + elif kind== 'scatter': + klass = 'ScatterPlot' else: raise ValueError('Invalid chart type given %s' % kind) From c85158f3d284c6220c226d698918ac4f2beb19a8 Mon Sep 17 00:00:00 2001 From: zach powers Date: Thu, 26 Sep 2013 15:43:03 -0400 Subject: [PATCH 2/4] fix ScatterPlot class so it is called correctly --- pandas/tools/plotting.py | 53 ++++++++++++++++++++++++++-------------- 1 file changed, 34 insertions(+), 19 deletions(-) diff --git a/pandas/tools/plotting.py b/pandas/tools/plotting.py index 
d5b04b10d7f99..0139603eb2fc5 100644 --- a/pandas/tools/plotting.py +++ b/pandas/tools/plotting.py @@ -1577,7 +1577,7 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True, secondary_y=False, **kwds): """ - Make line or bar plot of DataFrame's series with the index on the x-axis + Make line, bar, or scater plots of DataFrame series with the index on the x-axis using matplotlib / pylab. Parameters @@ -1608,10 +1608,11 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True, ax : matplotlib axis object, default None style : list or dict matplotlib line style per column - kind : {'line', 'bar', 'barh', 'kde', 'density'} + kind : {'line', 'bar', 'barh', 'kde', 'density', 'scatter'} bar : vertical bar plot barh : horizontal bar plot kde/density : Kernel Density Estimation plot + scatter: scatter plot logx : boolean, default False For line plots, use log scaling on x axis logy : boolean, default False @@ -1647,8 +1648,8 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True, klass = BarPlot elif kind == 'kde': klass = KdePlot - elif kind== 'scatter': - klass = 'ScatterPlot' + elif kind == 'scatter': + klass = ScatterPlot else: raise ValueError('Invalid chart type given %s' % kind) @@ -1664,21 +1665,35 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True, label = kwds.pop('label', label) ser = frame[y] ser.index.name = label - return plot_series(ser, label=label, kind=kind, - use_index=use_index, - rot=rot, xticks=xticks, yticks=yticks, - xlim=xlim, ylim=ylim, ax=ax, style=style, - grid=grid, logx=logx, logy=logy, - secondary_y=secondary_y, title=title, - figsize=figsize, fontsize=fontsize, **kwds) - - plot_obj = klass(frame, kind=kind, subplots=subplots, rot=rot, - legend=legend, ax=ax, style=style, fontsize=fontsize, - use_index=use_index, sharex=sharex, sharey=sharey, - xticks=xticks, yticks=yticks, xlim=xlim, ylim=ylim, - title=title, grid=grid, figsize=figsize, logx=logx, - logy=logy, 
sort_columns=sort_columns, - secondary_y=secondary_y, **kwds) + if kind != 'scatter': + return plot_series(ser, label=label, kind=kind, + use_index=use_index, + rot=rot, xticks=xticks, yticks=yticks, + xlim=xlim, ylim=ylim, ax=ax, style=style, + grid=grid, logx=logx, logy=logy, + secondary_y=secondary_y, title=title, + figsize=figsize, fontsize=fontsize, **kwds) + if kind == 'scatter': + plot_obj = klass(frame, x=frame.index, y=ser, + kind=kind, subplots=subplots, rot=rot, + legend=legend, ax=ax, style=style, fontsize=fontsize, + use_index=use_index, sharex=sharex, sharey=sharey, + xticks=xticks, yticks=yticks, xlim=xlim, ylim=ylim, + title=title, grid=grid, figsize=figsize, logx=logx, + logy=logy, sort_columns=sort_columns, + secondary_y=secondary_y, **kwds) + + else: + plot_obj = klass(frame, kind=kind, subplots=subplots, rot=rot, + legend=legend, ax=ax, style=style, fontsize=fontsize, + use_index=use_index, sharex=sharex, sharey=sharey, + xticks=xticks, yticks=yticks, xlim=xlim, ylim=ylim, + title=title, grid=grid, figsize=figsize, logx=logx, + logy=logy, sort_columns=sort_columns, + secondary_y=secondary_y, **kwds) + + + plot_obj.generate() plot_obj.draw() if subplots: From e008049c6fee0b34f4ac6e90d62ae523da679be8 Mon Sep 17 00:00:00 2001 From: zach powers Date: Thu, 26 Sep 2013 16:30:18 -0400 Subject: [PATCH 3/4] some small changes --- pandas/tools/plotting.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/pandas/tools/plotting.py b/pandas/tools/plotting.py index 0139603eb2fc5..b07f77234e729 100644 --- a/pandas/tools/plotting.py +++ b/pandas/tools/plotting.py @@ -1193,12 +1193,8 @@ def _post_plot_logic(self): class ScatterPlot(MPLPlot): def __init__(self, data, **kwargs): MPLPlot.__init__(self, data, **kwargs) - #kwargs = self.kwargs - #print kwargs - ## check ot see that x and y are passed as keywords - if not ('x' and'y') in kwargs: - msg ='Scatterplot requires and X and Y column' - raise Exception(msg) + if 'x' not in kwargs and 
'y' not in kwargs: + raise ValueError( 'Scatterplot requires and X and Y column') def _make_plot(self): plotf = self._get_plot_function() From 30da599b27fa5d26236ccc03fce439ead87567a3 Mon Sep 17 00:00:00 2001 From: zach powers Date: Thu, 26 Sep 2013 17:42:55 -0400 Subject: [PATCH 4/4] added a set of scatterplot tests --- pandas/tests/test_graphics.py | 52 ++++++++++++++--------------------- 1 file changed, 21 insertions(+), 31 deletions(-) diff --git a/pandas/tests/test_graphics.py b/pandas/tests/test_graphics.py index 49dc31514da7a..2d2218c60119f 100644 --- a/pandas/tests/test_graphics.py +++ b/pandas/tests/test_graphics.py @@ -419,37 +419,6 @@ def test_explicit_label(self): ax = df.plot(x='a', y='b', label='LABEL') self.assertEqual(ax.xaxis.get_label().get_text(), 'LABEL') - @slow - def test_plot_xy(self): - import matplotlib.pyplot as plt - # columns.inferred_type == 'string' - df = tm.makeTimeDataFrame() - self._check_data(df.plot(x=0, y=1), - df.set_index('A')['B'].plot()) - self._check_data(df.plot(x=0), df.set_index('A').plot()) - self._check_data(df.plot(y=0), df.B.plot()) - self._check_data(df.plot(x='A', y='B'), - df.set_index('A').B.plot()) - self._check_data(df.plot(x='A'), df.set_index('A').plot()) - self._check_data(df.plot(y='B'), df.B.plot()) - - # columns.inferred_type == 'integer' - df.columns = lrange(1, len(df.columns) + 1) - self._check_data(df.plot(x=1, y=2), - df.set_index(1)[2].plot()) - self._check_data(df.plot(x=1), df.set_index(1).plot()) - self._check_data(df.plot(y=1), df[1].plot()) - - # figsize and title - ax = df.plot(x=1, y=2, title='Test', figsize=(16, 8)) - - self.assertEqual(ax.title.get_text(), 'Test') - assert_array_equal(np.round(ax.figure.get_size_inches()), - np.array((16., 8.))) - - # columns.inferred_type == 'mixed' - # TODO add MultiIndex test - @slow def test_xcompat(self): import pandas as pd @@ -534,6 +503,27 @@ def test_subplots(self): [self.assert_(label.get_visible()) for label in ax.get_yticklabels()] + @slow + def 
test_plot_scatter(self): + from matplotlib.pylab import close + df = DataFrame(randn(6, 4), + index=list(string.ascii_letters[:6]), + columns=['x', 'y', 'z', 'four']) + + _check_plot_works(df.plot, x='x', y='y', kind='scatter') + _check_plot_works(df.plot, x='x', y='y', kind='scatter', legend=False) + _check_plot_works(df.plot, x='x', y='y', kind='scatter', subplots=True) + _check_plot_works(df.plot, x='x', y='y', kind='scatter', stacked=True) + + df = DataFrame(randn(10, 15), + index=list(string.ascii_letters[:10]), + columns=lrange(15)) + _check_plot_works(df.plot, x=1, y=2, kind='scatter') + + df = DataFrame({'a': [0, 1], 'b': [1, 0]}) + _check_plot_works(df.plot, x='a',y='b',kind='scatter') + + @slow def test_plot_bar(self): from matplotlib.pylab import close