diff --git a/packages/python/plotly/plotly/figure_factory/_annotated_heatmap.py b/packages/python/plotly/plotly/figure_factory/_annotated_heatmap.py index 9b5bf0a0b17..47e22b72240 100644 --- a/packages/python/plotly/plotly/figure_factory/_annotated_heatmap.py +++ b/packages/python/plotly/plotly/figure_factory/_annotated_heatmap.py @@ -103,10 +103,9 @@ def create_annotated_heatmap( # validate colorscale colorscale_validator = ColorscaleValidator() colorscale = colorscale_validator.validate_coerce(colorscale) - annotations = _AnnotatedHeatmap( z, x, y, annotation_text, colorscale, font_colors, reversescale, **kwargs - ).make_annotations() + ).make_annotations(**kwargs) if x or y: trace = dict( @@ -136,7 +135,7 @@ def create_annotated_heatmap( layout = dict( annotations=annotations, xaxis=dict( - ticks="", side="top", gridcolor="rgb(0, 0, 0)", showticklabels=False + ticks="", side="top", gridcolor="rgb(0, 0, 0)", showticklabels=False, ), yaxis=dict(ticks="", ticksuffix=" ", showticklabels=False), ) @@ -211,7 +210,6 @@ def get_text_color(self): "Blues", "YIGnBu", "YIOrRd", - "RdBu", "Picnic", "Jet", "Hot", @@ -264,22 +262,21 @@ def get_text_color(self): max_text_color = black return min_text_color, max_text_color - def get_z_mid(self): + def get_z_min_max(self): """ - Get the mid value of z matrix + Get the min and max value of z matrix - :rtype (float) z_avg: average val from z matrix + :rtype (tuple): min and max val from z matrix """ if np and isinstance(self.z, np.ndarray): - z_min = np.amin(self.z) - z_max = np.amax(self.z) + z_min = np.nanmin(self.z) + z_max = np.nanmax(self.z) else: z_min = min([v for row in self.z for v in row]) z_max = max([v for row in self.z for v in row]) - z_mid = (z_max + z_min) / 2 - return z_mid + return z_min, z_max - def make_annotations(self): + def make_annotations(self, zmin=None, zmid=None, zmax=None): """ Get annotations for each cell of the heatmap with graph_objs.Annotation @@ -287,11 +284,26 @@ def make_annotations(self): the heatmap """ min_text_color, max_text_color = _AnnotatedHeatmap.get_text_color(self) - z_mid = _AnnotatedHeatmap.get_z_mid(self) + if zmin is None or zmax is None: + zmin, zmax = _AnnotatedHeatmap.get_z_min_max(self) + if zmid is None: + zmid = (zmax + zmin) / 2 + if min_text_color == max_text_color: + # diverging colorscale + mid_text_color = "#000000" + get_font_color = ( + lambda val: mid_text_color + if (zmin + zmid) / 2 < val < (zmid + zmax) / 2 + else min_text_color + ) + else: + get_font_color = ( + lambda val: min_text_color if val < zmid else max_text_color + ) annotations = [] for n, row in enumerate(self.z): for m, val in enumerate(row): - font_color = min_text_color if val < z_mid else max_text_color + font_color = get_font_color(val) annotations.append( graph_objs.layout.Annotation( text=str(self.annotation_text[n][m]),