diff --git a/packages/python/plotly/plotly/figure_factory/_annotated_heatmap.py b/packages/python/plotly/plotly/figure_factory/_annotated_heatmap.py index 9b5bf0a0b17..e6253c43afe 100644 --- a/packages/python/plotly/plotly/figure_factory/_annotated_heatmap.py +++ b/packages/python/plotly/plotly/figure_factory/_annotated_heatmap.py @@ -189,6 +189,23 @@ def __init__( self.reversescale = reversescale self.font_colors = font_colors + if np and isinstance(self.z, np.ndarray): + self.zmin = np.amin(self.z) + self.zmax = np.amax(self.z) + else: + self.zmin = min([v for row in self.z for v in row]) + self.zmax = max([v for row in self.z for v in row]) + + if kwargs.get("zmin", None) is not None: + self.zmin = kwargs["zmin"] + if kwargs.get("zmax", None) is not None: + self.zmax = kwargs["zmax"] + + self.zmid = (self.zmax + self.zmin) / 2 + + if kwargs.get("zmid", None) is not None: + self.zmid = kwargs["zmid"] + def get_text_color(self): """ Get font color for annotations. @@ -264,21 +281,6 @@ def get_text_color(self): max_text_color = black return min_text_color, max_text_color - def get_z_mid(self): - """ - Get the mid value of z matrix - - :rtype (float) z_avg: average val from z matrix - """ - if np and isinstance(self.z, np.ndarray): - z_min = np.amin(self.z) - z_max = np.amax(self.z) - else: - z_min = min([v for row in self.z for v in row]) - z_max = max([v for row in self.z for v in row]) - z_mid = (z_max + z_min) / 2 - return z_mid - def make_annotations(self): """ Get annotations for each cell of the heatmap with graph_objs.Annotation @@ -287,11 +289,10 @@ def make_annotations(self): the heatmap """ min_text_color, max_text_color = _AnnotatedHeatmap.get_text_color(self) - z_mid = _AnnotatedHeatmap.get_z_mid(self) annotations = [] for n, row in enumerate(self.z): for m, val in enumerate(row): - font_color = min_text_color if val < z_mid else max_text_color + font_color = min_text_color if val < self.zmid else max_text_color annotations.append( graph_objs.layout.Annotation( text=str(self.annotation_text[n][m]),