Skip to content

Commit 1d2b988

Browse files
authored
Merge pull request #577 from grst/custom-dendrogram
Allow the use of custom linkage/distance functions for creating dendrograms.
2 parents 96b7f75 + 7f94b93 commit 1d2b988

File tree

1 file changed

+15
-7
lines changed

1 file changed

+15
-7
lines changed

plotly/tools.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6052,14 +6052,18 @@ def create_distplot(hist_data, group_labels,
60526052

60536053
@staticmethod
60546054
def create_dendrogram(X, orientation="bottom", labels=None,
6055-
colorscale=None):
6055+
colorscale=None, distfun=scs.distance.pdist,
6056+
linkagefun=lambda x: sch.linkage(x, 'complete')):
60566057
"""
60576058
BETA function that returns a dendrogram Plotly figure object.
60586059
60596060
:param (ndarray) X: Matrix of observations as array of arrays
60606061
:param (str) orientation: 'top', 'right', 'bottom', or 'left'
60616062
:param (list) labels: List of axis category labels(observation labels)
60626063
:param (list) colorscale: Optional colorscale for dendrogram tree
6064+
:param (function) distfun: Function to compute the pairwise distance from the observations
6065+
:param (function) linkagefun: Function to compute the linkage matrix from the pairwise distances
6066+
60636067
clusters
60646068
60656069
Example 1: Simple bottom oriented dendrogram
@@ -6115,7 +6119,8 @@ def create_dendrogram(X, orientation="bottom", labels=None,
61156119
if len(s) != 2:
61166120
exceptions.PlotlyError("X should be 2-dimensional array.")
61176121

6118-
dendrogram = _Dendrogram(X, orientation, labels, colorscale)
6122+
dendrogram = _Dendrogram(X, orientation, labels, colorscale,
6123+
distfun=distfun, linkagefun=linkagefun)
61196124

61206125
return {'layout': dendrogram.layout,
61216126
'data': dendrogram.data}
@@ -7042,7 +7047,8 @@ class _Dendrogram(FigureFactory):
70427047
"""Refer to FigureFactory.create_dendrogram() for docstring."""
70437048

70447049
def __init__(self, X, orientation='bottom', labels=None, colorscale=None,
7045-
width="100%", height="100%", xaxis='xaxis', yaxis='yaxis'):
7050+
width="100%", height="100%", xaxis='xaxis', yaxis='yaxis',
7051+
distfun=scs.distance.pdist, linkagefun=lambda x: sch.linkage(x, 'complete')):
70467052
# TODO: protected until #282
70477053
from plotly.graph_objs import graph_objs
70487054
self.orientation = orientation
@@ -7065,7 +7071,7 @@ def __init__(self, X, orientation='bottom', labels=None, colorscale=None,
70657071
self.sign[self.yaxis] = -1
70667072

70677073
(dd_traces, xvals, yvals,
7068-
ordered_labels, leaves) = self.get_dendrogram_traces(X, colorscale)
7074+
ordered_labels, leaves) = self.get_dendrogram_traces(X, colorscale, distfun, linkagefun)
70697075

70707076
self.labels = ordered_labels
70717077
self.leaves = leaves
@@ -7174,12 +7180,14 @@ def set_figure_layout(self, width, height):
71747180

71757181
return self.layout
71767182

7177-
def get_dendrogram_traces(self, X, colorscale):
7183+
def get_dendrogram_traces(self, X, colorscale, distfun, linkagefun):
71787184
"""
71797185
Calculates all the elements needed for plotting a dendrogram.
71807186
71817187
:param (ndarray) X: Matrix of observations as array of arrays
71827188
:param (list) colorscale: Color scale for dendrogram tree clusters
7189+
:param (function) distfun: Function to compute the pairwise distance from the observations
7190+
:param (function) linkagefun: Function to compute the linkage matrix from the pairwise distances
71837191
:rtype (tuple): Contains all the traces in the following order:
71847192
(a) trace_list: List of Plotly trace objects for dendrogram tree
71857193
(b) icoord: All X points of the dendrogram tree as array of arrays
@@ -7193,8 +7201,8 @@ def get_dendrogram_traces(self, X, colorscale):
71937201
"""
71947202
# TODO: protected until #282
71957203
from plotly.graph_objs import graph_objs
7196-
d = scs.distance.pdist(X)
7197-
Z = sch.linkage(d, method='complete')
7204+
d = distfun(X)
7205+
Z = linkagefun(d)
71987206
P = sch.dendrogram(Z, orientation=self.orientation,
71997207
labels=self.labels, no_plot=True)
72007208

0 commit comments

Comments
 (0)