from __future__ import absolute_import, division

from plotly import exceptions, optional_imports
import plotly.colors as clrs
from plotly.figure_factory import utils
from plotly.graph_objs import graph_objs
from plotly.validators.heatmap import ColorscaleValidator

# Optional imports, may be None for users that only use our core functionality.
np = optional_imports.get_module("numpy")


def _has_labels(labels):
    """
    Tell whether an axis-label sequence was actually supplied.

    Behaves like ``bool(labels)`` for lists and tuples, but does not rely on
    the truth value of the object itself, so array-like inputs whose truth
    value is ambiguous (e.g. numpy arrays) are handled as well.

    :param labels: x or y axis labels, or None.
    :rtype (bool): True if labels is not None and is non-empty.
    """
    if labels is None:
        return False
    try:
        return len(labels) > 0
    except TypeError:
        # Unsized object: fall back to plain truthiness.
        return bool(labels)


def validate_annotated_heatmap(z, x, y, annotation_text):
    """
    Annotated-heatmap-specific validations

    Check that if a text matrix is supplied as a list, it has the same
    dimensions as the z matrix, and that the x / y label lists (when given)
    match the width / length of the z matrix.

    See create_annotated_heatmap() for params

    :raises: (PlotlyError) If z and text matrices do not have the same
        dimensions, or if x / y do not match the shape of z.
    """
    if annotation_text is not None and isinstance(annotation_text, list):
        # Outer lengths are checked by the shared helper, row lengths here.
        utils.validate_equal_length(z, annotation_text)
        for row_index, z_row in enumerate(z):
            if len(z_row) != len(annotation_text[row_index]):
                raise exceptions.PlotlyError(
                    "z and text should have the " "same dimensions"
                )

    if _has_labels(x) and len(x) != len(z[0]):
        raise exceptions.PlotlyError(
            "oops, the x list that you "
            "provided does not match the "
            "width of your z matrix "
        )

    if _has_labels(y) and len(y) != len(z):
        raise exceptions.PlotlyError(
            "oops, the y list that you "
            "provided does not match the "
            "length of your z matrix "
        )


def create_annotated_heatmap(
    z,
    x=None,
    y=None,
    annotation_text=None,
    colorscale="Plasma",
    font_colors=None,
    showscale=False,
    reversescale=False,
    **kwargs
):
    """
    Function that creates annotated heatmaps

    This function adds annotations to each cell of the heatmap.

    :param (list[list]|ndarray) z: z matrix to create heatmap.
    :param (list) x: x axis labels.
    :param (list) y: y axis labels.
    :param (list[list]|ndarray) annotation_text: Text strings for
        annotations. Should have the same dimensions as the z matrix. If no
        text is added, the values of the z matrix are annotated. Default =
        z matrix values.
    :param (list|str) colorscale: heatmap colorscale.
    :param (list) font_colors: List of two color strings: [min_text_color,
        max_text_color] where min_text_color is applied to annotations for
        heatmap values below the midpoint, i.e. zmid if passed in kwargs,
        otherwise (zmin + zmax)/2. If font_colors is not defined, the
        colors are defined logically as black or white depending on the
        heatmap's colorscale.
    :param (bool) showscale: Display colorscale. Default = False
    :param (bool) reversescale: Reverse colorscale. Default = False
    :param kwargs: kwargs passed through plotly.graph_objs.Heatmap.
        These kwargs describe other attributes about the annotated Heatmap
        trace such as the colorscale. For more information on valid kwargs
        call help(plotly.graph_objs.Heatmap)

    :raises: (PlotlyError) If annotation_text, x or y do not match the
        dimensions of z.
    :rtype (Figure): figure holding one heatmap trace and one annotation
        per cell.

    Example 1: Simple annotated heatmap with default configuration

    >>> import plotly.figure_factory as ff

    >>> z = [[0.300000, 0.00000, 0.65, 0.300000],
    ...      [1, 0.100005, 0.45, 0.4300],
    ...      [0.300000, 0.00000, 0.65, 0.300000],
    ...      [1, 0.100005, 0.45, 0.00000]]

    >>> fig = ff.create_annotated_heatmap(z)
    >>> fig.show()
    """

    # Avoiding mutables in the call signature
    font_colors = font_colors if font_colors is not None else []
    validate_annotated_heatmap(z, x, y, annotation_text)

    # validate colorscale
    colorscale_validator = ColorscaleValidator()
    colorscale = colorscale_validator.validate_coerce(colorscale)

    annotations = _AnnotatedHeatmap(
        z, x, y, annotation_text, colorscale, font_colors, reversescale, **kwargs
    ).make_annotations()

    # Attributes shared by the labelled and unlabelled variants of the trace.
    trace = dict(
        type="heatmap",
        z=z,
        colorscale=colorscale,
        showscale=showscale,
        reversescale=reversescale,
        **kwargs
    )

    if _has_labels(x) or _has_labels(y):
        trace.update(x=x, y=y)
        layout = dict(
            annotations=annotations,
            xaxis=dict(ticks="", dtick=1, side="top", gridcolor="rgb(0, 0, 0)"),
            yaxis=dict(ticks="", dtick=1, ticksuffix=" "),
        )
    else:
        # No labels supplied: hide the tick labels on both axes.
        layout = dict(
            annotations=annotations,
            xaxis=dict(
                ticks="", side="top", gridcolor="rgb(0, 0, 0)", showticklabels=False
            ),
            yaxis=dict(ticks="", ticksuffix=" ", showticklabels=False),
        )

    data = [trace]

    return graph_objs.Figure(data=data, layout=layout)


def to_rgb_color_list(color_str, default):
    """
    Convert a color string to its red, green and blue components.

    :param (str) color_str: color such as "rgb(1, 2, 3)",
        "rgba(1, 2, 3, 0.5)" or "#010203".
    :param default: value returned when color_str is neither an rgb(a)
        string nor a hex string.
    :return: list of three ints for rgb(a) strings (alpha is dropped and
        float components are truncated), the result of
        plotly.colors.hex_to_rgb for hex strings, otherwise default.
    """
    if "rgb" in color_str:
        # Strip the "rgb(" / "rgba(" wrapper, then keep only the first three
        # channels so an alpha component does not break parsing.
        channels = color_str.strip("rgba()").split(",")[:3]
        return [int(float(v)) for v in channels]
    elif "#" in color_str:
        return clrs.hex_to_rgb(color_str)
    else:
        return default


def should_use_black_text(background_color):
    """
    Decide whether black text should be used on the given background.

    :param background_color: sequence whose first three items are the red,
        green and blue components of the background.
    :rtype (bool): True when the weighted sum
        0.299 * R + 0.587 * G + 0.114 * B is greater than 186.
    """
    return (
        background_color[0] * 0.299
        + background_color[1] * 0.587
        + background_color[2] * 0.114
    ) > 186


class _AnnotatedHeatmap(object):
    """
    Refer to create_annotated_heatmap() for docstring
    """

    def __init__(
        self, z, x, y, annotation_text, colorscale, font_colors, reversescale, **kwargs
    ):

        self.z = z
        # Without labels, cells are positioned by their column / row index.
        self.x = x if _has_labels(x) else range(len(z[0]))
        self.y = y if _has_labels(y) else range(len(z))
        if annotation_text is not None:
            self.annotation_text = annotation_text
        else:
            self.annotation_text = self.z
        self.colorscale = colorscale
        self.reversescale = reversescale
        self.font_colors = font_colors

        if np and isinstance(self.z, np.ndarray):
            self.zmin = np.amin(self.z)
            self.zmax = np.amax(self.z)
        else:
            self.zmin = min(v for row in self.z for v in row)
            self.zmax = max(v for row in self.z for v in row)

        # Explicit zmin / zmax / zmid kwargs take precedence over the values
        # derived from the data.
        if kwargs.get("zmin", None) is not None:
            self.zmin = kwargs["zmin"]
        if kwargs.get("zmax", None) is not None:
            self.zmax = kwargs["zmax"]

        self.zmid = (self.zmax + self.zmin) / 2

        if kwargs.get("zmid", None) is not None:
            self.zmid = kwargs["zmid"]

    def get_text_color(self):
        """
        Get font color for annotations.

        The annotated heatmap can feature two text colors: min_text_color and
        max_text_color. The min_text_color is applied to annotations for
        heatmap values < zmid, where zmid defaults to (zmin + zmax)/2. The
        user can define these two colors. Otherwise the colors are defined
        logically as black or white depending on the heatmap's colorscale.

        :rtype (string, string) min_text_color, max_text_color: text
            color for annotations for heatmap values < zmid and text color
            for annotations for heatmap values >= zmid
        """
        # Named colorscales that get white text on low values and black text
        # on high values (swapped when reversescale is set).
        # "YIGnBu" / "YIOrRd" are legacy misspellings of "YlGnBu" / "YlOrRd";
        # both spellings are accepted for backward compatibility.
        colorscales = [
            "Greys",
            "Greens",
            "Blues",
            "YlGnBu",
            "YlOrRd",
            "YIGnBu",
            "YIOrRd",
            "RdBu",
            "Picnic",
            "Jet",
            "Hot",
            "Blackbody",
            "Earth",
            "Electric",
            "Viridis",
            "Cividis",
        ]
        # Named colorscales that get black text on low values and white text
        # on high values (swapped when reversescale is set).
        colorscales_reverse = ["Reds"]

        white = "#FFFFFF"
        black = "#000000"
        if self.font_colors:
            min_text_color = self.font_colors[0]
            max_text_color = self.font_colors[-1]
        elif self.colorscale in colorscales and self.reversescale:
            min_text_color = black
            max_text_color = white
        elif self.colorscale in colorscales:
            min_text_color = white
            max_text_color = black
        elif self.colorscale in colorscales_reverse and self.reversescale:
            min_text_color = white
            max_text_color = black
        elif self.colorscale in colorscales_reverse:
            min_text_color = black
            max_text_color = white
        elif isinstance(self.colorscale, list):
            # Explicit [[position, color], ...] scale: choose each text color
            # from the first / last color of the scale.
            min_col = to_rgb_color_list(self.colorscale[0][1], [255, 255, 255])
            max_col = to_rgb_color_list(self.colorscale[-1][1], [255, 255, 255])

            # swap min/max colors if reverse scale
            if self.reversescale:
                min_col, max_col = max_col, min_col

            min_text_color = black if should_use_black_text(min_col) else white
            max_text_color = black if should_use_black_text(max_col) else white
        else:
            min_text_color = black
            max_text_color = black
        return min_text_color, max_text_color

    def make_annotations(self):
        """
        Get annotations for each cell of the heatmap with
        graph_objs.layout.Annotation

        :rtype (list[Annotation]) annotations: list of annotations for each
            cell of the heatmap
        """
        min_text_color, max_text_color = self.get_text_color()
        annotations = []
        for n, row in enumerate(self.z):
            for m, val in enumerate(row):
                font_color = min_text_color if val < self.zmid else max_text_color
                annotations.append(
                    graph_objs.layout.Annotation(
                        text=str(self.annotation_text[n][m]),
                        x=self.x[m],
                        y=self.y[n],
                        xref="x1",
                        yref="y1",
                        font=dict(color=font_color),
                        showarrow=False,
                    )
                )
        return annotations