From 67e380a9a92e2338f1ad1e624fedb07c3ce28124 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Tue, 19 Aug 2014 23:44:39 -0700 Subject: [PATCH] API: support `c` and `colormap` args for DataFrame.plot with kind='scatter' `matplotlib.pyplot.scatter` supports the argument `c` for setting the color of each point. This patch lets you easily set it by giving a column name (currently you need to supply an ndarray to make it work, since pandas doesn't use it): df.plot('x', 'y', c='z', kind='scatter') vs df.plot('x', 'y', c=df['z'].values, kind='scatter') While I was at it, I noticed that `kind='scatter'` did not support the `colormap` argument that some of the other methods support (notably `kind='hexbin'`). So I added it, too. This change should be almost entirely backwards compatible, unless folks are naming columns in their data frame valid matplotlib colors and using the same color name for the `c` argument. A colorbar will also be added automatically if relevant. --- doc/source/v0.15.0.txt | 2 ++ doc/source/visualization.rst | 8 ++++++++ pandas/tests/test_graphics.py | 28 ++++++++++++++++++++++++++ pandas/tools/plotting.py | 38 ++++++++++++++++++++++++++++++----- 4 files changed, 71 insertions(+), 5 deletions(-) diff --git a/doc/source/v0.15.0.txt b/doc/source/v0.15.0.txt index bfd484b363dd2..2871d2f628659 100644 --- a/doc/source/v0.15.0.txt +++ b/doc/source/v0.15.0.txt @@ -435,6 +435,8 @@ Enhancements - Added ``layout`` keyword to ``DataFrame.plot`` (:issue:`6667`) - Allow to pass multiple axes to ``DataFrame.plot``, ``hist`` and ``boxplot`` (:issue:`5353`, :issue:`6970`, :issue:`7069`) +- Added support for ``c``, ``colormap`` and ``colorbar`` arguments for + ``DataFrame.plot`` with ``kind='scatter'`` (:issue:`7780`) - ``PeriodIndex`` supports ``resolution`` as the same as ``DatetimeIndex`` (:issue:`7708`) diff --git a/doc/source/visualization.rst b/doc/source/visualization.rst index 1cce55cd53e11..d845ae38f05c2 100644 --- a/doc/source/visualization.rst +++ b/doc/source/visualization.rst @@ -521,6 +521,14 @@ It is recommended to specify ``color`` and ``label`` keywords to distinguish eac df.plot(kind='scatter', x='c', y='d', color='DarkGreen', label='Group 2', ax=ax); +The keyword ``c`` may be given as the name of a column to provide colors for +each point: + +.. ipython:: python + + @savefig scatter_plot_colored.png + df.plot(kind='scatter', x='a', y='b', c='c', s=50); + You can pass other keywords supported by matplotlib ``scatter``. Below example shows a bubble chart using a dataframe column values as bubble size. diff --git a/pandas/tests/test_graphics.py b/pandas/tests/test_graphics.py index 131edf499ff18..3211998b42300 100644 --- a/pandas/tests/test_graphics.py +++ b/pandas/tests/test_graphics.py @@ -1497,6 +1497,34 @@ def test_plot_scatter(self): axes = df.plot(x='x', y='y', kind='scatter', subplots=True) self._check_axes_shape(axes, axes_num=1, layout=(1, 1)) + @slow + def test_plot_scatter_with_c(self): + df = DataFrame(randn(6, 4), + index=list(string.ascii_letters[:6]), + columns=['x', 'y', 'z', 'four']) + + axes = [df.plot(kind='scatter', x='x', y='y', c='z'), + df.plot(kind='scatter', x=0, y=1, c=2)] + for ax in axes: + # default to RdBu + self.assertEqual(ax.collections[0].cmap.name, 'RdBu') + # n.b. there appears to be no public method to get the colorbar + # label + self.assertEqual(ax.collections[0].colorbar._label, 'z') + + cm = 'cubehelix' + ax = df.plot(kind='scatter', x='x', y='y', c='z', colormap=cm) + self.assertEqual(ax.collections[0].cmap.name, cm) + + # verify turning off colorbar works + ax = df.plot(kind='scatter', x='x', y='y', c='z', colorbar=False) + self.assertIs(ax.collections[0].colorbar, None) + + # verify that we can still plot a solid color + ax = df.plot(x=0, y=1, c='red', kind='scatter') + self.assertIs(ax.collections[0].colorbar, None) + self._check_colors(ax.collections, facecolors=['r']) + @slow def test_plot_bar(self): df = DataFrame(randn(6, 4), diff --git a/pandas/tools/plotting.py b/pandas/tools/plotting.py index 56316ac726c8a..7a68da3ad14f2 100644 --- a/pandas/tools/plotting.py +++ b/pandas/tools/plotting.py @@ -1368,32 +1368,55 @@ def _get_errorbars(self, label=None, index=None, xerr=True, yerr=True): class ScatterPlot(MPLPlot): _layout_type = 'single' - def __init__(self, data, x, y, **kwargs): + def __init__(self, data, x, y, c=None, **kwargs): MPLPlot.__init__(self, data, **kwargs) - self.kwds.setdefault('c', self.plt.rcParams['patch.facecolor']) if x is None or y is None: raise ValueError( 'scatter requires and x and y column') if com.is_integer(x) and not self.data.columns.holds_integer(): x = self.data.columns[x] if com.is_integer(y) and not self.data.columns.holds_integer(): y = self.data.columns[y] + if com.is_integer(c) and not self.data.columns.holds_integer(): + c = self.data.columns[c] self.x = x self.y = y + self.c = c @property def nseries(self): return 1 def _make_plot(self): - x, y, data = self.x, self.y, self.data + import matplotlib.pyplot as plt + + x, y, c, data = self.x, self.y, self.c, self.data ax = self.axes[0] + # plot a colorbar only if a colormap is provided or necessary + cb = self.kwds.pop('colorbar', self.colormap or c in self.data.columns) + + # pandas uses colormap, matplotlib uses cmap. + cmap = self.colormap or 'RdBu' + cmap = plt.cm.get_cmap(cmap) + + if c is None: + c_values = self.plt.rcParams['patch.facecolor'] + elif c in self.data.columns: + c_values = self.data[c].values + else: + c_values = c + if self.legend and hasattr(self, 'label'): label = self.label else: label = None - scatter = ax.scatter(data[x].values, data[y].values, label=label, - **self.kwds) + scatter = ax.scatter(data[x].values, data[y].values, c=c_values, + label=label, cmap=cmap, **self.kwds) + if cb: + img = ax.collections[0] + cb_label = c if c in self.data.columns else '' + self.fig.colorbar(img, ax=ax, label=cb_label) + self._add_legend_handle(scatter, label) errors_x = self._get_errorbars(label=x, index=0, yerr=False) @@ -2259,6 +2282,8 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True, colormap : str or matplotlib colormap object, default None Colormap to select colors from. If string, load colormap with that name from matplotlib. + colorbar : boolean, optional + If True, plot colorbar (only relevant for 'scatter' and 'hexbin' plots) position : float Specify relative alignments for bar plot layout. From 0 (left/bottom-end) to 1 (right/top-end). Default is 0.5 (center) @@ -2285,6 +2310,9 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True, `C` specifies the value at each `(x, y)` point and `reduce_C_function` is a function of one argument that reduces all the values in a bin to a single number (e.g. `mean`, `max`, `sum`, `std`). + + If `kind`='scatter' and the argument `c` is the name of a dataframe column, + the values of that column are used to color each point. """ kind = _get_standard_kind(kind.lower().strip())