1from __future__ import absolute_import, division
2
3from plotly import exceptions, optional_imports
4import plotly.colors as clrs
5from plotly.figure_factory import utils
6from plotly.graph_objs import graph_objs
7from plotly.validators.heatmap import ColorscaleValidator
8
9# Optional imports, may be None for users that only use our core functionality.
10np = optional_imports.get_module("numpy")
11
12
def validate_annotated_heatmap(z, x, y, annotation_text):
    """
    Annotated-heatmap-specific validations

    Check that if a text matrix is supplied (list of lists or 2D ndarray), it
    has the same dimensions as the z matrix, and that supplied x / y label
    sequences match the width / length of the z matrix.

    See plotly.figure_factory.create_annotated_heatmap() for params

    :raises: (PlotlyError) If z and text matrices do not have the same
        dimensions, or if x (y) does not match the width (length) of z.
    """
    # The public docstring accepts both lists and ndarrays for the text matrix;
    # previously only lists were checked.
    text_is_matrix = isinstance(annotation_text, list) or (
        np is not None and isinstance(annotation_text, np.ndarray)
    )
    if text_is_matrix:
        utils.validate_equal_length(z, annotation_text)
        # Row counts are equal at this point, so zip pairs every z row with
        # its text row.
        for z_row, text_row in zip(z, annotation_text):
            if len(z_row) != len(text_row):
                raise exceptions.PlotlyError(
                    "z and text should have the " "same dimensions"
                )

    # Explicit "given and non-empty" checks: ``if x:`` raises for numpy arrays
    # and pandas indexes, whose truth value is ambiguous.
    if x is not None and len(x) > 0:
        if len(x) != len(z[0]):
            raise exceptions.PlotlyError(
                "oops, the x list that you "
                "provided does not match the "
                "width of your z matrix "
            )

    if y is not None and len(y) > 0:
        if len(y) != len(z):
            raise exceptions.PlotlyError(
                "oops, the y list that you "
                "provided does not match the "
                "length of your z matrix "
            )
48
49
def create_annotated_heatmap(
    z,
    x=None,
    y=None,
    annotation_text=None,
    colorscale="Plasma",
    font_colors=None,
    showscale=False,
    reversescale=False,
    **kwargs
):
    """
    Function that creates annotated heatmaps

    This function adds annotations to each cell of the heatmap.

    :param (list[list]|ndarray) z: z matrix to create heatmap.
    :param (list|ndarray) x: x axis labels. Must match the width of z.
    :param (list|ndarray) y: y axis labels. Must match the length of z.
    :param (list[list]|ndarray) annotation_text: Text strings for
        annotations. Should have the same dimensions as the z matrix. If no
        text is added, the values of the z matrix are annotated. Default =
        z matrix values.
    :param (list|str) colorscale: heatmap colorscale.
    :param (list) font_colors: List of two color strings: [min_text_color,
        max_text_color] where min_text_color is applied to annotations for
        heatmap values below the midpoint of the z range, (zmin + zmax)/2
        (or ``zmid`` when passed in kwargs). If font_colors is not
        defined, the colors are defined logically as black or white
        depending on the heatmap's colorscale.
    :param (bool) showscale: Display colorscale. Default = False
    :param (bool) reversescale: Reverse colorscale. Default = False
    :param kwargs: kwargs passed through plotly.graph_objs.Heatmap.
        These kwargs describe other attributes about the annotated Heatmap
        trace such as zmin, zmax and zmid. For more information on valid
        kwargs call help(plotly.graph_objs.Heatmap)
    :rtype (plotly.graph_objs.Figure): figure with one heatmap trace and one
        annotation per cell.

    Example 1: Simple annotated heatmap with default configuration

    >>> import plotly.figure_factory as ff

    >>> z = [[0.300000, 0.00000, 0.65, 0.300000],
    ...      [1, 0.100005, 0.45, 0.4300],
    ...      [0.300000, 0.00000, 0.65, 0.300000],
    ...      [1, 0.100005, 0.45, 0.00000]]

    >>> fig = ff.create_annotated_heatmap(z)
    >>> fig.show()
    """

    # Avoiding mutables in the call signature
    font_colors = font_colors if font_colors is not None else []
    validate_annotated_heatmap(z, x, y, annotation_text)

    # validate colorscale
    colorscale_validator = ColorscaleValidator()
    colorscale = colorscale_validator.validate_coerce(colorscale)

    annotations = _AnnotatedHeatmap(
        z, x, y, annotation_text, colorscale, font_colors, reversescale, **kwargs
    ).make_annotations()

    # ``x or y`` raises for numpy arrays / pandas indexes (ambiguous truth
    # value), so test for "given and non-empty" explicitly.
    has_x = x is not None and len(x) > 0
    has_y = y is not None and len(y) > 0

    if has_x or has_y:
        # Labelled axes: show a tick per cell.
        trace = dict(
            type="heatmap",
            z=z,
            x=x,
            y=y,
            colorscale=colorscale,
            showscale=showscale,
            reversescale=reversescale,
            **kwargs
        )
        layout = dict(
            annotations=annotations,
            xaxis=dict(ticks="", dtick=1, side="top", gridcolor="rgb(0, 0, 0)"),
            yaxis=dict(ticks="", dtick=1, ticksuffix="  "),
        )
    else:
        # No labels supplied: hide the (index-valued) tick labels entirely.
        trace = dict(
            type="heatmap",
            z=z,
            colorscale=colorscale,
            showscale=showscale,
            reversescale=reversescale,
            **kwargs
        )
        layout = dict(
            annotations=annotations,
            xaxis=dict(
                ticks="", side="top", gridcolor="rgb(0, 0, 0)", showticklabels=False
            ),
            yaxis=dict(ticks="", ticksuffix="  ", showticklabels=False),
        )

    data = [trace]

    return graph_objs.Figure(data=data, layout=layout)
147
148
def to_rgb_color_list(color_str, default):
    """
    Convert a color string to its RGB components.

    :param (str) color_str: color given as an ``rgb(...)`` / ``rgba(...)``
        string or as a hex string containing ``#``.
    :param default: value returned when ``color_str`` is in neither format
        (e.g. a named CSS color such as "red").
    :rtype: for rgb/rgba strings, a list of the first three components
        rounded to int (any alpha channel is dropped); for hex strings,
        whatever ``plotly.colors.hex_to_rgb`` returns; otherwise ``default``.
    """
    if "rgb" in color_str:
        # ``str.strip("rgb()")`` strips a *character set*, so for "rgba(...)"
        # it left the "a(" behind and ``int`` crashed; float components such
        # as "12.5" crashed ``int`` as well. Remove the function prefix and
        # parentheses explicitly, then parse each component as a float.
        inner = color_str.strip().lstrip("rgba").strip("()")
        components = [int(round(float(part))) for part in inner.split(",")]
        # Only R, G, B are meaningful for contrast decisions.
        return components[:3]
    elif "#" in color_str:
        return clrs.hex_to_rgb(color_str)
    else:
        return default
156
157
def should_use_black_text(background_color):
    """
    Decide whether black text is more legible than white on a background.

    Brightness is estimated by weighting the first three components of the
    color (red, green, blue) by 0.299, 0.587 and 0.114. Black text is chosen
    when that weighted sum exceeds 186.

    :param background_color: indexable color whose items 0, 1 and 2 are the
        red, green and blue components (presumably on a 0-255 scale, as
        returned by to_rgb_color_list).
    :rtype (bool): True when black text should be used, False for white.
    """
    red = background_color[0]
    green = background_color[1]
    blue = background_color[2]
    brightness = red * 0.299 + green * 0.587 + blue * 0.114
    return brightness > 186
164
165
166class _AnnotatedHeatmap(object):
167    """
168    Refer to TraceFactory.create_annotated_heatmap() for docstring
169    """
170
171    def __init__(
172        self, z, x, y, annotation_text, colorscale, font_colors, reversescale, **kwargs
173    ):
174
175        self.z = z
176        if x:
177            self.x = x
178        else:
179            self.x = range(len(z[0]))
180        if y:
181            self.y = y
182        else:
183            self.y = range(len(z))
184        if annotation_text is not None:
185            self.annotation_text = annotation_text
186        else:
187            self.annotation_text = self.z
188        self.colorscale = colorscale
189        self.reversescale = reversescale
190        self.font_colors = font_colors
191
192        if np and isinstance(self.z, np.ndarray):
193            self.zmin = np.amin(self.z)
194            self.zmax = np.amax(self.z)
195        else:
196            self.zmin = min([v for row in self.z for v in row])
197            self.zmax = max([v for row in self.z for v in row])
198
199        if kwargs.get("zmin", None) is not None:
200            self.zmin = kwargs["zmin"]
201        if kwargs.get("zmax", None) is not None:
202            self.zmax = kwargs["zmax"]
203
204        self.zmid = (self.zmax + self.zmin) / 2
205
206        if kwargs.get("zmid", None) is not None:
207            self.zmid = kwargs["zmid"]
208
209    def get_text_color(self):
210        """
211        Get font color for annotations.
212
213        The annotated heatmap can feature two text colors: min_text_color and
214        max_text_color. The min_text_color is applied to annotations for
215        heatmap values < (max_value - min_value)/2. The user can define these
216        two colors. Otherwise the colors are defined logically as black or
217        white depending on the heatmap's colorscale.
218
219        :rtype (string, string) min_text_color, max_text_color: text
220            color for annotations for heatmap values <
221            (max_value - min_value)/2 and text color for annotations for
222            heatmap values >= (max_value - min_value)/2
223        """
224        # Plotly colorscales ranging from a lighter shade to a darker shade
225        colorscales = [
226            "Greys",
227            "Greens",
228            "Blues",
229            "YIGnBu",
230            "YIOrRd",
231            "RdBu",
232            "Picnic",
233            "Jet",
234            "Hot",
235            "Blackbody",
236            "Earth",
237            "Electric",
238            "Viridis",
239            "Cividis",
240        ]
241        # Plotly colorscales ranging from a darker shade to a lighter shade
242        colorscales_reverse = ["Reds"]
243
244        white = "#FFFFFF"
245        black = "#000000"
246        if self.font_colors:
247            min_text_color = self.font_colors[0]
248            max_text_color = self.font_colors[-1]
249        elif self.colorscale in colorscales and self.reversescale:
250            min_text_color = black
251            max_text_color = white
252        elif self.colorscale in colorscales:
253            min_text_color = white
254            max_text_color = black
255        elif self.colorscale in colorscales_reverse and self.reversescale:
256            min_text_color = white
257            max_text_color = black
258        elif self.colorscale in colorscales_reverse:
259            min_text_color = black
260            max_text_color = white
261        elif isinstance(self.colorscale, list):
262
263            min_col = to_rgb_color_list(self.colorscale[0][1], [255, 255, 255])
264            max_col = to_rgb_color_list(self.colorscale[-1][1], [255, 255, 255])
265
266            # swap min/max colors if reverse scale
267            if self.reversescale:
268                min_col, max_col = max_col, min_col
269
270            if should_use_black_text(min_col):
271                min_text_color = black
272            else:
273                min_text_color = white
274
275            if should_use_black_text(max_col):
276                max_text_color = black
277            else:
278                max_text_color = white
279        else:
280            min_text_color = black
281            max_text_color = black
282        return min_text_color, max_text_color
283
284    def make_annotations(self):
285        """
286        Get annotations for each cell of the heatmap with graph_objs.Annotation
287
288        :rtype (list[dict]) annotations: list of annotations for each cell of
289            the heatmap
290        """
291        min_text_color, max_text_color = _AnnotatedHeatmap.get_text_color(self)
292        annotations = []
293        for n, row in enumerate(self.z):
294            for m, val in enumerate(row):
295                font_color = min_text_color if val < self.zmid else max_text_color
296                annotations.append(
297                    graph_objs.layout.Annotation(
298                        text=str(self.annotation_text[n][m]),
299                        x=self.x[m],
300                        y=self.y[n],
301                        xref="x1",
302                        yref="y1",
303                        font=dict(color=font_color),
304                        showarrow=False,
305                    )
306                )
307        return annotations
308