Skip to content

Scatterplot Update for #3473 #4930

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 21 additions & 31 deletions pandas/tests/test_graphics.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,37 +419,6 @@ def test_explicit_label(self):
ax = df.plot(x='a', y='b', label='LABEL')
self.assertEqual(ax.xaxis.get_label().get_text(), 'LABEL')

@slow
def test_plot_xy(self):
    """Compare ``df.plot`` called with ``x``/``y`` against reference plots.

    Each pair is handed to ``self._check_data``: the first plot is produced
    through the ``x``/``y`` keywords, the second by re-indexing the frame
    (or selecting the column) and plotting that directly.
    """
    # Imported here as in the original; not referenced again in this test.
    import matplotlib.pyplot as plt

    # columns.inferred_type == 'string'
    df = tm.makeTimeDataFrame()
    string_label_cases = [
        (lambda: df.plot(x=0, y=1),
         lambda: df.set_index('A')['B'].plot()),
        (lambda: df.plot(x=0),
         lambda: df.set_index('A').plot()),
        (lambda: df.plot(y=0),
         lambda: df.B.plot()),
        (lambda: df.plot(x='A', y='B'),
         lambda: df.set_index('A').B.plot()),
        (lambda: df.plot(x='A'),
         lambda: df.set_index('A').plot()),
        (lambda: df.plot(y='B'),
         lambda: df.B.plot()),
    ]
    # The plot under test is always drawn before its reference plot.
    for draw_actual, draw_expected in string_label_cases:
        self._check_data(draw_actual(), draw_expected())

    # columns.inferred_type == 'integer'
    df.columns = lrange(1, len(df.columns) + 1)
    integer_label_cases = [
        (lambda: df.plot(x=1, y=2),
         lambda: df.set_index(1)[2].plot()),
        (lambda: df.plot(x=1),
         lambda: df.set_index(1).plot()),
        (lambda: df.plot(y=1),
         lambda: df[1].plot()),
    ]
    for draw_actual, draw_expected in integer_label_cases:
        self._check_data(draw_actual(), draw_expected())

    # figsize and title
    titled_ax = df.plot(x=1, y=2, title='Test', figsize=(16, 8))

    self.assertEqual(titled_ax.title.get_text(), 'Test')
    rounded_size = np.round(titled_ax.figure.get_size_inches())
    assert_array_equal(rounded_size, np.array((16., 8.)))

    # columns.inferred_type == 'mixed'
    # TODO add MultiIndex test

@slow
def test_xcompat(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this should have been removed?

import pandas as pd
Expand Down Expand Up @@ -534,6 +503,27 @@ def test_subplots(self):
[self.assert_(label.get_visible())
for label in ax.get_yticklabels()]

@slow
def test_plot_scatter(self):
    """Smoke-test ``DataFrame.plot(kind='scatter')``.

    Every call is routed through ``_check_plot_works``; this test does not
    itself assert anything about the axes that are drawn.
    """
    df = DataFrame(randn(6, 4),
                   index=list(string.ascii_letters[:6]),
                   columns=['x', 'y', 'z', 'four'])

    # String column labels, alone and combined with the legend, subplots
    # and stacked keywords.
    _check_plot_works(df.plot, x='x', y='y', kind='scatter')
    _check_plot_works(df.plot, x='x', y='y', kind='scatter', legend=False)
    _check_plot_works(df.plot, x='x', y='y', kind='scatter', subplots=True)
    _check_plot_works(df.plot, x='x', y='y', kind='scatter', stacked=True)

    # Integer column labels.
    df = DataFrame(randn(10, 15),
                   index=list(string.ascii_letters[:10]),
                   columns=lrange(15))
    _check_plot_works(df.plot, x=1, y=2, kind='scatter')

    # Minimal two-row frame built from a dict.
    df = DataFrame({'a': [0, 1], 'b': [1, 0]})
    _check_plot_works(df.plot, x='a', y='b', kind='scatter')


@slow
def test_plot_bar(self):
from matplotlib.pylab import close
Expand Down
75 changes: 57 additions & 18 deletions pandas/tools/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -1193,7 +1193,29 @@ def _post_plot_logic(self):
for ax in self.axes:
ax.legend(loc='best')


class ScatterPlot(MPLPlot):
def __init__(self, data, **kwargs):
    """Initialise the scatter plot and verify both coordinates were given.

    Parameters
    ----------
    data : object
        Passed straight through to ``MPLPlot.__init__``.
    **kwargs
        Passed straight through to ``MPLPlot.__init__``; must contain
        both ``x`` and ``y``.

    Raises
    ------
    ValueError
        If ``x`` or ``y`` is absent from the keyword arguments.
    """
    MPLPlot.__init__(self, data, **kwargs)
    # Both keys are required. Testing with ``and`` would only reject calls
    # where neither was supplied, letting a lone ``x`` or ``y`` through
    # to fail later when both are looked up at plot time.
    if 'x' not in kwargs or 'y' not in kwargs:
        raise ValueError('Scatterplot requires an X and Y column')

def _make_plot(self):
plotf = self._get_plot_function()
colors = self._get_colors()

for i, (label, y) in enumerate(self._iter_data()):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you have y assigned here, but then below you bind y using self.kwds['y']?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same deal. 'y' is the column where y values come from. Then you iterate through the rows and use 'y'

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you're misunderstanding how name binding works. See my second comment below.

ax = self._get_ax(i)
#kwds = self.kwds.copy()
x, y = self.kwds['x'], self.kwds['y']
#print x, y
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these two

ax = ax.scatter(x, y)
style = self._get_style(i, label)

def _post_plot_logic(self):
if self.subplots and self.legend:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So you only want a legend to be shown if subplots is True and legend is True? What if subplots is False? Then the legend keyword argument would have no effect.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure about this. I used the line or barplot class as template. I'm not really familiar with subplotting

for ax in self.axes:
ax.legend(loc='best')

class LinePlot(MPLPlot):

def __init__(self, data, **kwargs):
Expand Down Expand Up @@ -1554,7 +1576,7 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True,
secondary_y=False, **kwds):

"""
Make line or bar plot of DataFrame's series with the index on the x-axis
Make line, bar, or scatter plots of DataFrame series with the index on the x-axis
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typo: scater -> scatter

using matplotlib / pylab.

Parameters
Expand Down Expand Up @@ -1585,10 +1607,11 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True,
ax : matplotlib axis object, default None
style : list or dict
matplotlib line style per column
kind : {'line', 'bar', 'barh', 'kde', 'density'}
kind : {'line', 'bar', 'barh', 'kde', 'density', 'scatter'}
bar : vertical bar plot
barh : horizontal bar plot
kde/density : Kernel Density Estimation plot
scatter : scatter plot
logx : boolean, default False
For line plots, use log scaling on x axis
logy : boolean, default False
Expand Down Expand Up @@ -1624,6 +1647,8 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True,
klass = BarPlot
elif kind == 'kde':
klass = KdePlot
elif kind == 'scatter':
klass = ScatterPlot
else:
raise ValueError('Invalid chart type given %s' % kind)

Expand All @@ -1639,21 +1664,35 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True,
label = kwds.pop('label', label)
ser = frame[y]
ser.index.name = label
return plot_series(ser, label=label, kind=kind,
use_index=use_index,
rot=rot, xticks=xticks, yticks=yticks,
xlim=xlim, ylim=ylim, ax=ax, style=style,
grid=grid, logx=logx, logy=logy,
secondary_y=secondary_y, title=title,
figsize=figsize, fontsize=fontsize, **kwds)

plot_obj = klass(frame, kind=kind, subplots=subplots, rot=rot,
legend=legend, ax=ax, style=style, fontsize=fontsize,
use_index=use_index, sharex=sharex, sharey=sharey,
xticks=xticks, yticks=yticks, xlim=xlim, ylim=ylim,
title=title, grid=grid, figsize=figsize, logx=logx,
logy=logy, sort_columns=sort_columns,
secondary_y=secondary_y, **kwds)
if kind != 'scatter':
return plot_series(ser, label=label, kind=kind,
use_index=use_index,
rot=rot, xticks=xticks, yticks=yticks,
xlim=xlim, ylim=ylim, ax=ax, style=style,
grid=grid, logx=logx, logy=logy,
secondary_y=secondary_y, title=title,
figsize=figsize, fontsize=fontsize, **kwds)
if kind == 'scatter':
plot_obj = klass(frame, x=frame.index, y=ser,
kind=kind, subplots=subplots, rot=rot,
legend=legend, ax=ax, style=style, fontsize=fontsize,
use_index=use_index, sharex=sharex, sharey=sharey,
xticks=xticks, yticks=yticks, xlim=xlim, ylim=ylim,
title=title, grid=grid, figsize=figsize, logx=logx,
logy=logy, sort_columns=sort_columns,
secondary_y=secondary_y, **kwds)

else:
plot_obj = klass(frame, kind=kind, subplots=subplots, rot=rot,
legend=legend, ax=ax, style=style, fontsize=fontsize,
use_index=use_index, sharex=sharex, sharey=sharey,
xticks=xticks, yticks=yticks, xlim=xlim, ylim=ylim,
title=title, grid=grid, figsize=figsize, logx=logx,
logy=logy, sort_columns=sort_columns,
secondary_y=secondary_y, **kwds)



plot_obj.generate()
plot_obj.draw()
if subplots:
Expand Down