Skip to content

Commit 79607c5

Browse files
committed
Merge branch 'scatterplot' of https://github.com/zachcp/pandas into zachcp-scatterplot
2 parents d80894c + dfd92c7 commit 79607c5

File tree

4 files changed

+94
-38
lines changed

4 files changed

+94
-38
lines changed

doc/source/release.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,7 @@ API Changes
210210
- Default export for ``to_clipboard`` is now csv with a sep of `\t` for
211211
compat (:issue:`3368`)
212212
- ``at`` now will enlarge the object inplace (and return the same) (:issue:`2578`)
213+
- ``DataFrame.plot`` now draws a scatter plot of column ``x`` versus column ``y`` when passed ``kind='scatter'`` (:issue:`2215`)
213214

214215
- ``HDFStore``
215216

doc/source/v0.13.0.txt

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -537,6 +537,7 @@ Enhancements
537537
- ``to_csv`` now takes a ``date_format`` keyword argument that specifies how
538538
output datetime objects should be formatted. Datetimes encountered in the
539539
index, columns, and values will all have this formatting applied. (:issue:`4313`)
540+
- ``DataFrame.plot`` now draws a scatter plot of column ``x`` versus column ``y`` when passed ``kind='scatter'`` (:issue:`2215`)
540541

541542
.. _whatsnew_0130.experimental:
542543

@@ -654,7 +655,7 @@ Experimental
654655
against extremely large datasets. :ref:`See the docs <io.bigquery>`
655656

656657
.. code-block:: python
657-
658+
658659
from pandas.io import gbq
659660

660661
# A query to select the average monthly temperatures in the
@@ -665,8 +666,8 @@ Experimental
665666
query = """SELECT station_number as STATION,
666667
month as MONTH, AVG(mean_temp) as MEAN_TEMP
667668
FROM publicdata:samples.gsod
668-
WHERE YEAR = 2000
669-
GROUP BY STATION, MONTH
669+
WHERE YEAR = 2000
670+
GROUP BY STATION, MONTH
670671
ORDER BY STATION, MONTH ASC"""
671672

672673
# Fetch the result set for this query
@@ -675,7 +676,7 @@ Experimental
675676
# To find this, see your dashboard:
676677
# https://code.google.com/apis/console/b/0/?noredirect
677678
projectid = xxxxxxxxx;
678-
679+
679680
df = gbq.read_gbq(query, project_id = projectid)
680681

681682
# Use pandas to process and reshape the dataset
@@ -686,9 +687,9 @@ Experimental
686687

687688
The resulting dataframe is::
688689

689-
> df3
690+
> df3
690691
Min Tem Mean Temp Max Temp
691-
MONTH
692+
MONTH
692693
1 -53.336667 39.827892 89.770968
693694
2 -49.837500 43.685219 93.437932
694695
3 -77.926087 48.708355 96.099998

pandas/tests/test_graphics.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -449,7 +449,7 @@ def test_plot_xy(self):
449449

450450
# columns.inferred_type == 'mixed'
451451
# TODO add MultiIndex test
452-
452+
453453
@slow
454454
def test_xcompat(self):
455455
import pandas as pd
@@ -534,6 +534,21 @@ def test_subplots(self):
534534
[self.assert_(label.get_visible())
535535
for label in ax.get_yticklabels()]
536536

537+
@slow
def test_plot_scatter(self):
    # Exercise ``DataFrame.plot(kind='scatter')``: it must work when the
    # x/y columns are given by label or by integer position, and it must
    # reject calls that omit either axis.
    df = DataFrame(randn(6, 4),
                   index=list(string.ascii_letters[:6]),
                   columns=['x', 'y', 'z', 'four'])

    # columns addressed by label, then by integer position
    _check_plot_works(df.plot, x='x', y='y', kind='scatter')
    _check_plot_works(df.plot, x=1, y=2, kind='scatter')

    # a scatter plot needs both axes; leaving one out is an error
    with tm.assertRaises(ValueError):
        df.plot(x='x', kind='scatter')
    with tm.assertRaises(ValueError):
        df.plot(y='y', kind='scatter')
551+
537552
@slow
538553
def test_plot_bar(self):
539554
from matplotlib.pylab import close

pandas/tools/plotting.py

Lines changed: 70 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,6 @@ def _gcf():
322322
import matplotlib.pyplot as plt
323323
return plt.gcf()
324324

325-
326325
def _get_marker_compat(marker):
327326
import matplotlib.lines as mlines
328327
import matplotlib as mpl
@@ -1201,7 +1200,32 @@ def _post_plot_logic(self):
12011200
for ax in self.axes:
12021201
ax.legend(loc='best')
12031202

1204-
1203+
class ScatterPlot(MPLPlot):
    """
    Scatter plot of one column of the plotted data against another.

    Parameters
    ----------
    data : object
        Passed through unchanged to ``MPLPlot.__init__``.
    x, y : column label or integer
        Columns drawn on the x and y axes.  An integer is used as a
        position into ``data.columns`` unless the columns themselves hold
        integers, in which case it is used as a label.
    **kwargs
        Forwarded to ``MPLPlot.__init__``.

    Raises
    ------
    ValueError
        If either ``x`` or ``y`` is None.
    """

    def __init__(self, data, x, y, **kwargs):
        MPLPlot.__init__(self, data, **kwargs)
        # Unless the caller chose ``c`` explicitly, fall back to the
        # rcParams patch face colour; ``self.kwds`` is later handed to
        # ``ax.scatter`` in ``_make_plot``.
        self.kwds.setdefault('c', self.plt.rcParams['patch.facecolor'])
        if x is None or y is None:
            raise ValueError('scatter requires an x and y column')

        columns = self.data.columns

        def as_label(key):
            # Translate an integer position into the column label, but
            # only when integers cannot themselves be column labels.
            if com.is_integer(key) and not columns.holds_integer():
                return columns[key]
            return key

        self.x = as_label(x)
        self.y = as_label(y)

    def _make_plot(self):
        """Draw the x/y columns as a scatter on the first axes."""
        x, y, data = self.x, self.y, self.data
        ax = self.axes[0]
        ax.scatter(data[x].values, data[y].values, **self.kwds)

    def _post_plot_logic(self):
        """Label both axes of the first axes with the plotted column names."""
        ax = self.axes[0]
        x, y = self.x, self.y
        ax.set_ylabel(com.pprint_thing(y))
        ax.set_xlabel(com.pprint_thing(x))
1228+
12051229
class LinePlot(MPLPlot):
12061230

12071231
def __init__(self, data, **kwargs):
@@ -1562,7 +1586,7 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True,
15621586
secondary_y=False, **kwds):
15631587

15641588
"""
1565-
Make line or bar plot of DataFrame's series with the index on the x-axis
1589+
Make line, bar, or scatter plots of DataFrame series with the index on the x-axis
15661590
using matplotlib / pylab.
15671591
15681592
Parameters
@@ -1593,10 +1617,11 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True,
15931617
ax : matplotlib axis object, default None
15941618
style : list or dict
15951619
matplotlib line style per column
1596-
kind : {'line', 'bar', 'barh', 'kde', 'density'}
1620+
kind : {'line', 'bar', 'barh', 'kde', 'density', 'scatter'}
15971621
bar : vertical bar plot
15981622
barh : horizontal bar plot
15991623
kde/density : Kernel Density Estimation plot
1624+
scatter: scatter plot
16001625
logx : boolean, default False
16011626
For line plots, use log scaling on x axis
16021627
logy : boolean, default False
@@ -1632,36 +1657,50 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True,
16321657
klass = BarPlot
16331658
elif kind == 'kde':
16341659
klass = KdePlot
1660+
elif kind == 'scatter':
1661+
klass = ScatterPlot
16351662
else:
16361663
raise ValueError('Invalid chart type given %s' % kind)
16371664

1638-
if x is not None:
1639-
if com.is_integer(x) and not frame.columns.holds_integer():
1640-
x = frame.columns[x]
1641-
frame = frame.set_index(x)
1642-
1643-
if y is not None:
1644-
if com.is_integer(y) and not frame.columns.holds_integer():
1645-
y = frame.columns[y]
1646-
label = x if x is not None else frame.index.name
1647-
label = kwds.pop('label', label)
1648-
ser = frame[y]
1649-
ser.index.name = label
1650-
return plot_series(ser, label=label, kind=kind,
1651-
use_index=use_index,
1652-
rot=rot, xticks=xticks, yticks=yticks,
1653-
xlim=xlim, ylim=ylim, ax=ax, style=style,
1654-
grid=grid, logx=logx, logy=logy,
1655-
secondary_y=secondary_y, title=title,
1656-
figsize=figsize, fontsize=fontsize, **kwds)
1657-
1658-
plot_obj = klass(frame, kind=kind, subplots=subplots, rot=rot,
1659-
legend=legend, ax=ax, style=style, fontsize=fontsize,
1660-
use_index=use_index, sharex=sharex, sharey=sharey,
1661-
xticks=xticks, yticks=yticks, xlim=xlim, ylim=ylim,
1662-
title=title, grid=grid, figsize=figsize, logx=logx,
1663-
logy=logy, sort_columns=sort_columns,
1664-
secondary_y=secondary_y, **kwds)
1665+
if kind == 'scatter':
1666+
plot_obj = klass(frame, x=x, y=y, kind=kind, subplots=subplots,
1667+
rot=rot,legend=legend, ax=ax, style=style,
1668+
fontsize=fontsize, use_index=use_index, sharex=sharex,
1669+
sharey=sharey, xticks=xticks, yticks=yticks,
1670+
xlim=xlim, ylim=ylim, title=title, grid=grid,
1671+
figsize=figsize, logx=logx, logy=logy,
1672+
sort_columns=sort_columns, secondary_y=secondary_y,
1673+
**kwds)
1674+
else:
1675+
if x is not None:
1676+
if com.is_integer(x) and not frame.columns.holds_integer():
1677+
x = frame.columns[x]
1678+
frame = frame.set_index(x)
1679+
1680+
if y is not None:
1681+
if com.is_integer(y) and not frame.columns.holds_integer():
1682+
y = frame.columns[y]
1683+
label = x if x is not None else frame.index.name
1684+
label = kwds.pop('label', label)
1685+
ser = frame[y]
1686+
ser.index.name = label
1687+
return plot_series(ser, label=label, kind=kind,
1688+
use_index=use_index,
1689+
rot=rot, xticks=xticks, yticks=yticks,
1690+
xlim=xlim, ylim=ylim, ax=ax, style=style,
1691+
grid=grid, logx=logx, logy=logy,
1692+
secondary_y=secondary_y, title=title,
1693+
figsize=figsize, fontsize=fontsize, **kwds)
1694+
1695+
else:
1696+
plot_obj = klass(frame, kind=kind, subplots=subplots, rot=rot,
1697+
legend=legend, ax=ax, style=style, fontsize=fontsize,
1698+
use_index=use_index, sharex=sharex, sharey=sharey,
1699+
xticks=xticks, yticks=yticks, xlim=xlim, ylim=ylim,
1700+
title=title, grid=grid, figsize=figsize, logx=logx,
1701+
logy=logy, sort_columns=sort_columns,
1702+
secondary_y=secondary_y, **kwds)
1703+
16651704
plot_obj.generate()
16661705
plot_obj.draw()
16671706
if subplots:

0 commit comments

Comments
 (0)