diff --git a/doc/source/v0.15.0.txt b/doc/source/v0.15.0.txt index bfd484b363dd2..2871d2f628659 100644 --- a/doc/source/v0.15.0.txt +++ b/doc/source/v0.15.0.txt @@ -435,6 +435,8 @@ Enhancements - Added ``layout`` keyword to ``DataFrame.plot`` (:issue:`6667`) - Allow to pass multiple axes to ``DataFrame.plot``, ``hist`` and ``boxplot`` (:issue:`5353`, :issue:`6970`, :issue:`7069`) +- Added support for ``c``, ``colormap`` and ``colorbar`` arguments for + ``DataFrame.plot`` with ``kind='scatter'`` (:issue:`7780`) - ``PeriodIndex`` supports ``resolution`` as the same as ``DatetimeIndex`` (:issue:`7708`) diff --git a/doc/source/visualization.rst b/doc/source/visualization.rst index 1cce55cd53e11..d845ae38f05c2 100644 --- a/doc/source/visualization.rst +++ b/doc/source/visualization.rst @@ -521,6 +521,14 @@ It is recommended to specify ``color`` and ``label`` keywords to distinguish eac df.plot(kind='scatter', x='c', y='d', color='DarkGreen', label='Group 2', ax=ax); +The keyword ``c`` may be given as the name of a column to provide colors for +each point: + +.. ipython:: python + + @savefig scatter_plot_colored.png + df.plot(kind='scatter', x='a', y='b', c='c', s=50); + You can pass other keywords supported by matplotlib ``scatter``. Below example shows a bubble chart using a dataframe column values as bubble size. diff --git a/pandas/tests/test_graphics.py b/pandas/tests/test_graphics.py index 131edf499ff18..3211998b42300 100644 --- a/pandas/tests/test_graphics.py +++ b/pandas/tests/test_graphics.py @@ -1497,6 +1497,34 @@ def test_plot_scatter(self): axes = df.plot(x='x', y='y', kind='scatter', subplots=True) self._check_axes_shape(axes, axes_num=1, layout=(1, 1)) + @slow + def test_plot_scatter_with_c(self): + df = DataFrame(randn(6, 4), + index=list(string.ascii_letters[:6]), + columns=['x', 'y', 'z', 'four']) + + axes = [df.plot(kind='scatter', x='x', y='y', c='z'), + df.plot(kind='scatter', x=0, y=1, c=2)] + for ax in axes: + # default to RdBu + self.assertEqual(ax.collections[0].cmap.name, 'RdBu') + # n.b. there appears to be no public method to get the colorbar + # label + self.assertEqual(ax.collections[0].colorbar._label, 'z') + + cm = 'cubehelix' + ax = df.plot(kind='scatter', x='x', y='y', c='z', colormap=cm) + self.assertEqual(ax.collections[0].cmap.name, cm) + + # verify turning off colorbar works + ax = df.plot(kind='scatter', x='x', y='y', c='z', colorbar=False) + self.assertIs(ax.collections[0].colorbar, None) + + # verify that we can still plot a solid color + ax = df.plot(x=0, y=1, c='red', kind='scatter') + self.assertIs(ax.collections[0].colorbar, None) + self._check_colors(ax.collections, facecolors=['r']) + @slow def test_plot_bar(self): df = DataFrame(randn(6, 4), diff --git a/pandas/tools/plotting.py b/pandas/tools/plotting.py index 56316ac726c8a..7a68da3ad14f2 100644 --- a/pandas/tools/plotting.py +++ b/pandas/tools/plotting.py @@ -1368,32 +1368,55 @@ def _get_errorbars(self, label=None, index=None, xerr=True, yerr=True): class ScatterPlot(MPLPlot): _layout_type = 'single' - def __init__(self, data, x, y, **kwargs): + def __init__(self, data, x, y, c=None, **kwargs): MPLPlot.__init__(self, data, **kwargs) - self.kwds.setdefault('c', self.plt.rcParams['patch.facecolor']) if x is None or y is None: raise ValueError( 'scatter requires and x and y column') if com.is_integer(x) and not self.data.columns.holds_integer(): x = self.data.columns[x] if com.is_integer(y) and not self.data.columns.holds_integer(): y = self.data.columns[y] + if com.is_integer(c) and not self.data.columns.holds_integer(): + c = self.data.columns[c] self.x = x self.y = y + self.c = c @property def nseries(self): return 1 def _make_plot(self): - x, y, data = self.x, self.y, self.data + import matplotlib.pyplot as plt + + x, y, c, data = self.x, self.y, self.c, self.data ax = self.axes[0] + # plot a colorbar only if a colormap is provided or necessary + cb = self.kwds.pop('colorbar', self.colormap or c in self.data.columns) + + # pandas uses colormap, matplotlib uses cmap. + cmap = self.colormap or 'RdBu' + cmap = plt.cm.get_cmap(cmap) + + if c is None: + c_values = self.plt.rcParams['patch.facecolor'] + elif c in self.data.columns: + c_values = self.data[c].values + else: + c_values = c + if self.legend and hasattr(self, 'label'): label = self.label else: label = None - scatter = ax.scatter(data[x].values, data[y].values, label=label, - **self.kwds) + scatter = ax.scatter(data[x].values, data[y].values, c=c_values, + label=label, cmap=cmap, **self.kwds) + if cb: + img = ax.collections[0] + cb_label = c if c in self.data.columns else '' + self.fig.colorbar(img, ax=ax, label=cb_label) + self._add_legend_handle(scatter, label) errors_x = self._get_errorbars(label=x, index=0, yerr=False) @@ -2259,6 +2282,8 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True, colormap : str or matplotlib colormap object, default None Colormap to select colors from. If string, load colormap with that name from matplotlib. + colorbar : boolean, optional + If True, plot colorbar (only relevant for 'scatter' and 'hexbin' plots) position : float Specify relative alignments for bar plot layout. From 0 (left/bottom-end) to 1 (right/top-end). Default is 0.5 (center) @@ -2285,6 +2310,9 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True, `C` specifies the value at each `(x, y)` point and `reduce_C_function` is a function of one argument that reduces all the values in a bin to a single number (e.g. `mean`, `max`, `sum`, `std`). + + If `kind`='scatter' and the argument `c` is the name of a dataframe column, + the values of that column are used to color each point. """ kind = _get_standard_kind(kind.lower().strip())