diff --git a/doc/source/visualization.rst b/doc/source/visualization.rst index be969f3796935..6c035b816a9e9 100644 --- a/doc/source/visualization.rst +++ b/doc/source/visualization.rst @@ -241,5 +241,8 @@ Scatter plot matrix from pandas.tools.plotting import scatter_matrix df = DataFrame(np.random.randn(1000, 4), columns=['a', 'b', 'c', 'd']) - @savefig scatter_matrix_ex.png width=6in - scatter_matrix(df, alpha=0.2, figsize=(8, 8)) + @savefig scatter_matrix_kde.png width=6in + scatter_matrix(df, alpha=0.2, figsize=(8, 8), diagonal='kde') + + @savefig scatter_matrix_hist.png width=6in + scatter_matrix(df, alpha=0.2, figsize=(8, 8), diagonal='hist') \ No newline at end of file diff --git a/pandas/tests/test_graphics.py b/pandas/tests/test_graphics.py index 8e987f35d42e7..6fe1f93448671 100644 --- a/pandas/tests/test_graphics.py +++ b/pandas/tests/test_graphics.py @@ -214,6 +214,8 @@ def scat(**kwds): _check_plot_works(scat) _check_plot_works(scat, marker='+') _check_plot_works(scat, vmin=0) + _check_plot_works(scat, diagonal='kde') + _check_plot_works(scat, diagonal='hist') def scat2(x, y, by=None, ax=None, figsize=None): return plt.scatter_plot(df, x, y, by, ax, figsize=None) diff --git a/pandas/tools/plotting.py b/pandas/tools/plotting.py index 8168e1367f962..6e0bad14a580d 100644 --- a/pandas/tools/plotting.py +++ b/pandas/tools/plotting.py @@ -3,6 +3,7 @@ from itertools import izip import numpy as np +from scipy import stats from pandas.util.decorators import cache_readonly import pandas.core.common as com @@ -12,12 +13,19 @@ from pandas.tseries.offsets import DateOffset def scatter_matrix(frame, alpha=0.5, figsize=None, ax=None, grid=False, - **kwds): + diagonal='hist', **kwds): """ Draw a matrix of scatter plots. 
Parameters ---------- + alpha : amount of transparency applied + figsize : a tuple (width, height) in inches + ax : Matplotlib axis object + grid : setting this to True will show the grid + diagonal : pick between 'kde' and 'hist' for + either Kernel Density Estimation or Histogram + plot in the diagonal kwds : other plotting keyword arguments To be passed to scatter function @@ -36,7 +44,18 @@ def scatter_matrix(frame, alpha=0.5, figsize=None, ax=None, grid=False, for i, a in zip(range(n), df.columns): for j, b in zip(range(n), df.columns): - axes[i, j].scatter(df[b], df[a], alpha=alpha, **kwds) + if i == j: + # Deal with the diagonal by drawing a histogram or a kernel density estimate there. + if diagonal == 'hist': + axes[i, j].hist(df[a]) + elif diagonal == 'kde': + y = df[a] + gkde = stats.gaussian_kde(y) + ind = np.linspace(min(y), max(y), 1000) + axes[i, j].plot(ind, gkde.evaluate(ind), **kwds) + else: + axes[i, j].scatter(df[b], df[a], alpha=alpha, **kwds) + axes[i, j].set_xlabel('') axes[i, j].set_ylabel('') axes[i, j].set_xticklabels([]) @@ -44,7 +63,7 @@ def scatter_matrix(frame, alpha=0.5, figsize=None, ax=None, grid=False, ticks = df.index is_datetype = ticks.inferred_type in ('datetime', 'date', - 'datetime64') + 'datetime64') if ticks.is_numeric() or is_datetype: """ @@ -87,13 +106,6 @@ def scatter_matrix(frame, alpha=0.5, figsize=None, ax=None, grid=False, axes[i, j].grid(b=grid) - # ensure {x,y}lim off diagonal are the same as diagonal - for i in range(n): - for j in range(n): - if i != j: - axes[i, j].set_xlim(axes[j, j].get_xlim()) - axes[i, j].set_ylim(axes[i, i].get_ylim()) - return axes def _gca():