From 32576a2291a98971c21efba48fa01340a1bf824a Mon Sep 17 00:00:00 2001 From: Vytautas Jancauskas Date: Sun, 13 May 2012 00:17:31 +0300 Subject: [PATCH 1/5] Changes to plotting scatter matrix diagonals --- pandas/tools/plotting.py | 122 ++++++++++++++++++++++----------------- 1 file changed, 68 insertions(+), 54 deletions(-) diff --git a/pandas/tools/plotting.py b/pandas/tools/plotting.py index 8168e1367f962..ebb00e0e2efe9 100644 --- a/pandas/tools/plotting.py +++ b/pandas/tools/plotting.py @@ -3,6 +3,7 @@ from itertools import izip import numpy as np +from scipy import stats from pandas.util.decorators import cache_readonly import pandas.core.common as com @@ -12,7 +13,7 @@ from pandas.tseries.offsets import DateOffset def scatter_matrix(frame, alpha=0.5, figsize=None, ax=None, grid=False, - **kwds): + diagonal='hist', **kwds): """ Draw a matrix of scatter plots. @@ -36,64 +37,77 @@ def scatter_matrix(frame, alpha=0.5, figsize=None, ax=None, grid=False, for i, a in zip(range(n), df.columns): for j, b in zip(range(n), df.columns): - axes[i, j].scatter(df[b], df[a], alpha=alpha, **kwds) - axes[i, j].set_xlabel('') - axes[i, j].set_ylabel('') - axes[i, j].set_xticklabels([]) - axes[i, j].set_yticklabels([]) - ticks = df.index - - is_datetype = ticks.inferred_type in ('datetime', 'date', + if i == j: + # Deal with the diagonal by drawing a histogram there. + if diagonal == 'hist': + axes[i, j].hist(df[a]) + elif diagonal == 'kde': + y = df[a] + gkde = stats.gaussian_kde(y) + ind = np.linspace(min(y), max(y), 1000) + axes[i, j].plot(ind, gkde.evaluate(ind), **kwds) + axes[i, j].yaxis.set_visible(False) + axes[i, j].xaxis.set_visible(False) + if i == 0 and j == 0: + axes[i, j].yaxis.set_ticks_position('left') + axes[i, j].yaxis.set_label_position('left') + axes[i, j].yaxis.set_visible(True) + if i == n - 1 and j == n - 1: + axes[i, j].yaxis.set_ticks_position('right') + axes[i, j].yaxis.set_label_position('right') + axes[i, j].yaxis.set_visible(True) + else: + axes[i, j].scatter(df[b], df[a], alpha=alpha, **kwds) + axes[i, j].set_xlabel('') + axes[i, j].set_ylabel('') + axes[i, j].set_xticklabels([]) + axes[i, j].set_yticklabels([]) + ticks = df.index + + is_datetype = ticks.inferred_type in ('datetime', 'date', 'datetime64') - if ticks.is_numeric() or is_datetype: - """ - Matplotlib supports numeric values or datetime objects as - xaxis values. Taking LBYL approach here, by the time - matplotlib raises exception when using non numeric/datetime - values for xaxis, several actions are already taken by plt. - """ - ticks = ticks._mpl_repr() - - # setup labels - if i == 0 and j % 2 == 1: - axes[i, j].set_xlabel(b, visible=True) - #axes[i, j].xaxis.set_visible(True) - axes[i, j].set_xlabel(b) - axes[i, j].set_xticklabels(ticks) - axes[i, j].xaxis.set_ticks_position('top') - axes[i, j].xaxis.set_label_position('top') - if i == n - 1 and j % 2 == 0: - axes[i, j].set_xlabel(b, visible=True) - #axes[i, j].xaxis.set_visible(True) - axes[i, j].set_xlabel(b) - axes[i, j].set_xticklabels(ticks) - axes[i, j].xaxis.set_ticks_position('bottom') - axes[i, j].xaxis.set_label_position('bottom') - if j == 0 and i % 2 == 0: - axes[i, j].set_ylabel(a, visible=True) - #axes[i, j].yaxis.set_visible(True) - axes[i, j].set_ylabel(a) - axes[i, j].set_yticklabels(ticks) - axes[i, j].yaxis.set_ticks_position('left') - axes[i, j].yaxis.set_label_position('left') - if j == n - 1 and i % 2 == 1: - axes[i, j].set_ylabel(a, visible=True) - #axes[i, j].yaxis.set_visible(True) - axes[i, j].set_ylabel(a) - axes[i, j].set_yticklabels(ticks) - axes[i, j].yaxis.set_ticks_position('right') - axes[i, j].yaxis.set_label_position('right') + if ticks.is_numeric() or is_datetype: + """ + Matplotlib supports numeric values or datetime objects as + xaxis values. Taking LBYL approach here, by the time + matplotlib raises exception when using non numeric/datetime + values for xaxis, several actions are already taken by plt. + """ + ticks = ticks._mpl_repr() + + # setup labels + if i == 0 and j % 2 == 1: + axes[i, j].set_xlabel(b, visible=True) + #axes[i, j].xaxis.set_visible(True) + axes[i, j].set_xlabel(b) + axes[i, j].set_xticklabels(ticks) + axes[i, j].xaxis.set_ticks_position('top') + axes[i, j].xaxis.set_label_position('top') + if i == n - 1 and j % 2 == 0: + axes[i, j].set_xlabel(b, visible=True) + #axes[i, j].xaxis.set_visible(True) + axes[i, j].set_xlabel(b) + axes[i, j].set_xticklabels(ticks) + axes[i, j].xaxis.set_ticks_position('bottom') + axes[i, j].xaxis.set_label_position('bottom') + if j == 0 and i % 2 == 0: + axes[i, j].set_ylabel(a, visible=True) + #axes[i, j].yaxis.set_visible(True) + axes[i, j].set_ylabel(a) + axes[i, j].set_yticklabels(ticks) + axes[i, j].yaxis.set_ticks_position('left') + axes[i, j].yaxis.set_label_position('left') + if j == n - 1 and i % 2 == 1: + axes[i, j].set_ylabel(a, visible=True) + #axes[i, j].yaxis.set_visible(True) + axes[i, j].set_ylabel(a) + axes[i, j].set_yticklabels(ticks) + axes[i, j].yaxis.set_ticks_position('right') + axes[i, j].yaxis.set_label_position('right') axes[i, j].grid(b=grid) - # ensure {x,y}lim off diagonal are the same as diagonal - for i in range(n): - for j in range(n): - if i != j: - axes[i, j].set_xlim(axes[j, j].get_xlim()) - axes[i, j].set_ylim(axes[i, i].get_ylim()) - return axes def _gca(): From 31809a62ffcbb2923adea07b7a4d0e2d53628eec Mon Sep 17 00:00:00 2001 From: Vytautas Jancauskas Date: Mon, 14 May 2012 23:47:07 +0300 Subject: [PATCH 2/5] Changed xtick, ytick labels --- pandas/tools/plotting.py | 105 ++++++++++++++++++--------------------- 1 file changed, 48 insertions(+), 57 deletions(-) diff --git a/pandas/tools/plotting.py b/pandas/tools/plotting.py index ebb00e0e2efe9..857c1c7145bd5 100644 --- a/pandas/tools/plotting.py +++ b/pandas/tools/plotting.py @@ -46,65 +46,56 @@ def scatter_matrix(frame, alpha=0.5, figsize=None, ax=None, grid=False, gkde = stats.gaussian_kde(y) ind = np.linspace(min(y), max(y), 1000) axes[i, j].plot(ind, gkde.evaluate(ind), **kwds) - axes[i, j].yaxis.set_visible(False) - axes[i, j].xaxis.set_visible(False) - if i == 0 and j == 0: - axes[i, j].yaxis.set_ticks_position('left') - axes[i, j].yaxis.set_label_position('left') - axes[i, j].yaxis.set_visible(True) - if i == n - 1 and j == n - 1: - axes[i, j].yaxis.set_ticks_position('right') - axes[i, j].yaxis.set_label_position('right') - axes[i, j].yaxis.set_visible(True) else: axes[i, j].scatter(df[b], df[a], alpha=alpha, **kwds) - axes[i, j].set_xlabel('') - axes[i, j].set_ylabel('') - axes[i, j].set_xticklabels([]) - axes[i, j].set_yticklabels([]) - ticks = df.index - - is_datetype = ticks.inferred_type in ('datetime', 'date', - 'datetime64') - - if ticks.is_numeric() or is_datetype: - """ - Matplotlib supports numeric values or datetime objects as - xaxis values. Taking LBYL approach here, by the time - matplotlib raises exception when using non numeric/datetime - values for xaxis, several actions are already taken by plt. - """ - ticks = ticks._mpl_repr() - - # setup labels - if i == 0 and j % 2 == 1: - axes[i, j].set_xlabel(b, visible=True) - #axes[i, j].xaxis.set_visible(True) - axes[i, j].set_xlabel(b) - axes[i, j].set_xticklabels(ticks) - axes[i, j].xaxis.set_ticks_position('top') - axes[i, j].xaxis.set_label_position('top') - if i == n - 1 and j % 2 == 0: - axes[i, j].set_xlabel(b, visible=True) - #axes[i, j].xaxis.set_visible(True) - axes[i, j].set_xlabel(b) - axes[i, j].set_xticklabels(ticks) - axes[i, j].xaxis.set_ticks_position('bottom') - axes[i, j].xaxis.set_label_position('bottom') - if j == 0 and i % 2 == 0: - axes[i, j].set_ylabel(a, visible=True) - #axes[i, j].yaxis.set_visible(True) - axes[i, j].set_ylabel(a) - axes[i, j].set_yticklabels(ticks) - axes[i, j].yaxis.set_ticks_position('left') - axes[i, j].yaxis.set_label_position('left') - if j == n - 1 and i % 2 == 1: - axes[i, j].set_ylabel(a, visible=True) - #axes[i, j].yaxis.set_visible(True) - axes[i, j].set_ylabel(a) - axes[i, j].set_yticklabels(ticks) - axes[i, j].yaxis.set_ticks_position('right') - axes[i, j].yaxis.set_label_position('right') + + axes[i, j].set_xlabel('') + axes[i, j].set_ylabel('') + axes[i, j].set_xticklabels([]) + axes[i, j].set_yticklabels([]) + ticks = df.index + + is_datetype = ticks.inferred_type in ('datetime', 'date', + 'datetime64') + + if ticks.is_numeric() or is_datetype: + """ + Matplotlib supports numeric values or datetime objects as + xaxis values. Taking LBYL approach here, by the time + matplotlib raises exception when using non numeric/datetime + values for xaxis, several actions are already taken by plt. + """ + ticks = ticks._mpl_repr() + + # setup labels + if i == 0 and j % 2 == 1: + axes[i, j].set_xlabel(b, visible=True) + #axes[i, j].xaxis.set_visible(True) + axes[i, j].set_xlabel(b) + axes[i, j].set_xticklabels(ticks) + axes[i, j].xaxis.set_ticks_position('top') + axes[i, j].xaxis.set_label_position('top') + if i == n - 1 and j % 2 == 0: + axes[i, j].set_xlabel(b, visible=True) + #axes[i, j].xaxis.set_visible(True) + axes[i, j].set_xlabel(b) + axes[i, j].set_xticklabels(ticks) + axes[i, j].xaxis.set_ticks_position('bottom') + axes[i, j].xaxis.set_label_position('bottom') + if j == 0 and i % 2 == 0: + axes[i, j].set_ylabel(a, visible=True) + #axes[i, j].yaxis.set_visible(True) + axes[i, j].set_ylabel(a) + axes[i, j].set_yticklabels(ticks) + axes[i, j].yaxis.set_ticks_position('left') + axes[i, j].yaxis.set_label_position('left') + if j == n - 1 and i % 2 == 1: + axes[i, j].set_ylabel(a, visible=True) + #axes[i, j].yaxis.set_visible(True) + axes[i, j].set_ylabel(a) + axes[i, j].set_yticklabels(ticks) + axes[i, j].yaxis.set_ticks_position('right') + axes[i, j].yaxis.set_label_position('right') axes[i, j].grid(b=grid) From c7dcecb98d4fe9f10e3dd62a608ed8da5a0527d3 Mon Sep 17 00:00:00 2001 From: Vytautas Jancauskas Date: Tue, 15 May 2012 00:12:52 +0300 Subject: [PATCH 3/5] Added simple test cases --- pandas/tests/test_graphics.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pandas/tests/test_graphics.py b/pandas/tests/test_graphics.py index 8e987f35d42e7..6fe1f93448671 100644 --- a/pandas/tests/test_graphics.py +++ b/pandas/tests/test_graphics.py @@ -214,6 +214,8 @@ def scat(**kwds): _check_plot_works(scat) _check_plot_works(scat, marker='+') _check_plot_works(scat, vmin=0) + _check_plot_works(scat, diagonal='kde') + _check_plot_works(scat, diagonal='hist') def scat2(x, y, by=None, ax=None, figsize=None): return plt.scatter_plot(df, x, y, by, ax, figsize=None) From 555ad56e6e7ed290401dbbd110359dee528f520e Mon Sep 17 00:00:00 2001 From: Vytautas Jancauskas Date: Wed, 16 May 2012 18:44:14 +0300 Subject: [PATCH 4/5] Updated plotting.py scatter_matrix docstring to describe all the parameters --- pandas/tools/plotting.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/pandas/tools/plotting.py b/pandas/tools/plotting.py index 857c1c7145bd5..6e0bad14a580d 100644 --- a/pandas/tools/plotting.py +++ b/pandas/tools/plotting.py @@ -19,6 +19,13 @@ def scatter_matrix(frame, alpha=0.5, figsize=None, ax=None, grid=False, Parameters ---------- + alpha : amount of transparency applied + figsize : a tuple (width, height) in inches + ax : Matplotlib axis object + grid : setting this to True will show the grid + diagonal : pick between 'kde' and 'hist' for + either Kernel Density Estimation or Histogram + plon in the diagonal kwds : other plotting keyword arguments To be passed to scatter function @@ -48,7 +55,7 @@ def scatter_matrix(frame, alpha=0.5, figsize=None, ax=None, grid=False, axes[i, j].plot(ind, gkde.evaluate(ind), **kwds) else: axes[i, j].scatter(df[b], df[a], alpha=alpha, **kwds) - + axes[i, j].set_xlabel('') axes[i, j].set_ylabel('') axes[i, j].set_xticklabels([]) From 552879256dfea4931379489a7e38821698118f05 Mon Sep 17 00:00:00 2001 From: Vytautas Jancauskas Date: Wed, 16 May 2012 19:12:16 +0300 Subject: [PATCH 5/5] Added scatter_matrix examples to visualization.rst --- doc/source/visualization.rst | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/doc/source/visualization.rst b/doc/source/visualization.rst index be969f3796935..6c035b816a9e9 100644 --- a/doc/source/visualization.rst +++ b/doc/source/visualization.rst @@ -241,5 +241,8 @@ Scatter plot matrix from pandas.tools.plotting import scatter_matrix df = DataFrame(np.random.randn(1000, 4), columns=['a', 'b', 'c', 'd']) - @savefig scatter_matrix_ex.png width=6in - scatter_matrix(df, alpha=0.2, figsize=(8, 8)) + @savefig scatter_matrix_kde.png width=6in + scatter_matrix(df, alpha=0.2, figsize=(8, 8), diagonal='kde') + + @savefig scatter_matrix_hist.png width=6in + scatter_matrix(df, alpha=0.2, figsize=(8, 8), diagonal='hist') \ No newline at end of file