# Copyright 2019 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from dataclasses import astuple, dataclass
from typing import (
    Any,
    Dict,
    List,
    Mapping,
    Optional,
    overload,
    Sequence,
    SupportsFloat,
    Tuple,
    Union,
)

import matplotlib as mpl
import matplotlib.collections as mpl_collections
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits import axes_grid1

from cirq.devices import grid_qubit
from cirq.vis import vis_utils

QubitTuple = Tuple[grid_qubit.GridQubit, ...]

Polygon = Sequence[Tuple[float, float]]


@dataclass
class Point:
    """A 2D point in plot coordinates (``x`` is the grid column, ``y`` is the grid row)."""

    x: float
    y: float

    def __iter__(self):
        # Allows tuple-unpacking, e.g. ``x, y = point``.
        return iter(astuple(self))


@dataclass
class PolygonUnit:
    """Dataclass to store information about a single polygon unit to plot on the heatmap

    For single (grid) qubit heatmaps, the polygon is a square.
    For two (grid) qubit interaction heatmaps, the polygon is a hexagon.

    Args:
        polygon: Vertices of the polygon to plot.
        value: The value for the heatmap coloring.
        center: The center point of the polygon where annotation text should be printed.
        annot: The annotation string to print on the polygon, or None for no annotation.
    """

    polygon: Polygon
    value: float
    center: Point
    annot: Optional[str]


class Heatmap:
    """Distribution of a value in 2D qubit lattice as a color map."""

    # pylint: disable=function-redefined
    @overload
    def __init__(self, value_map: Mapping[QubitTuple, SupportsFloat], **kwargs):
        pass

    @overload
    def __init__(self, value_map: Mapping[grid_qubit.GridQubit, SupportsFloat], **kwargs):
        pass

    # TODO(#3388) Add documentation for Args.
    # pylint: disable=missing-param-doc
    def __init__(
        self,
        value_map: Union[
            Mapping[QubitTuple, SupportsFloat], Mapping[grid_qubit.GridQubit, SupportsFloat]
        ],
        **kwargs,
    ):
        """2D qubit grid Heatmaps

        Draw 2D qubit grid heatmap with Matplotlib with parameters to configure the properties of
        the plot.

        Args:
            value_map: dictionary
                A dictionary of qubits or QubitTuples as keys and corresponding magnitude as float
                values. It corresponds to the data which should be plotted as a heatmap.

            title: str, default = None
            plot_colorbar: bool, default = True

            annotation_map: dictionary,
                A dictionary of QubitTuples as keys and corresponding annotation str as values. It
                corresponds to the text that should be added on top of each heatmap polygon unit.
                For single-qubit units a bare qubit key is accepted as well.
            annotation_format: str, default = '.2g'
                Formatting string using which annotation_map will be implicitly constructed by
                applying format(value, annotation_format) for each key in value_map.
                This is ignored if annotation_map is explicitly specified.
            annotation_text_kwargs: Matplotlib Text **kwargs,

            colorbar_position: {'right', 'left', 'top', 'bottom'}, default = 'right'
            colorbar_size: str, default = '5%'
            colorbar_pad: str, default = '2%'
            colorbar_options: Matplotlib colorbar **kwargs, default = None,

            collection_options: Matplotlib PolyCollection **kwargs, default
                                {"cmap" : "viridis"}
            vmin, vmax: colormap scaling floats, default = None

        Raises:
            ValueError: If an unrecognized keyword argument is given.
        """
        # Normalize keys so single qubits become 1-tuples; the rest of the class works on tuples.
        self._value_map: Mapping[QubitTuple, SupportsFloat] = {
            k if isinstance(k, tuple) else (k,): v for k, v in value_map.items()
        }
        self._validate_kwargs(kwargs)
        # Subclasses may pre-populate ``_config`` with their own defaults before calling this.
        if '_config' not in self.__dict__:
            self._config: Dict[str, Any] = {}
        self._config.update(
            {
                "plot_colorbar": True,
                "colorbar_position": "right",
                "colorbar_size": "5%",
                "colorbar_pad": "2%",
                "collection_options": {"cmap": "viridis"},
                "annotation_format": ".2g",
            }
        )
        self._config.update(kwargs)

    # pylint: enable=function-redefined,missing-param-doc
    def _extra_valid_kwargs(self) -> List[str]:
        """Hook for subclasses to accept additional configuration keys."""
        return []

    def _validate_kwargs(self, kwargs) -> None:
        """Raises ValueError naming every key in ``kwargs`` that is not a known option."""
        valid_colorbar_kwargs = [
            "plot_colorbar",
            "colorbar_position",
            "colorbar_size",
            "colorbar_pad",
            "colorbar_options",
        ]
        valid_collection_kwargs = [
            "collection_options",
            "vmin",
            "vmax",
        ]
        valid_heatmap_kwargs = [
            "title",
            "annotation_map",
            "annotation_text_kwargs",
            "annotation_format",
        ]
        valid_kwargs = (
            valid_colorbar_kwargs
            + valid_collection_kwargs
            + valid_heatmap_kwargs
            + self._extra_valid_kwargs()
        )
        invalid_args = [k for k in kwargs if k not in valid_kwargs]
        if invalid_args:
            raise ValueError(f"Received invalid argument(s): {', '.join(invalid_args)}")

    def update_config(self, **kwargs) -> 'Heatmap':
        """Add/Modify **kwargs args passed during initialisation."""
        self._validate_kwargs(kwargs)
        self._config.update(kwargs)
        return self

    def _qubits_to_polygon(self, qubits: QubitTuple) -> Tuple[Polygon, Point]:
        """Returns the unit square centered on the first qubit, and that center point.

        Plot coordinates are (column, row). Only ``qubits[0]`` is used.
        """
        qubit = qubits[0]
        row, col = float(qubit.row), float(qubit.col)
        return (
            [
                (col - 0.5, row - 0.5),
                (col - 0.5, row + 0.5),
                (col + 0.5, row + 0.5),
                (col + 0.5, row - 0.5),
            ],
            Point(col, row),
        )

    def _get_annotation_value(self, key, value) -> Optional[str]:
        """Returns the annotation text for ``key``, or None if nothing should be printed.

        An explicit ``annotation_map`` takes precedence over ``annotation_format``.
        """
        annotation_map = self._config.get('annotation_map')
        if annotation_map:
            if key in annotation_map:
                return annotation_map[key]
            # value_map keys were normalized to 1-tuples in __init__; accept the
            # bare-qubit form in annotation_map as well so both spellings work.
            if len(key) == 1:
                return annotation_map.get(key[0])
            return None
        annotation_format = self._config.get('annotation_format')
        if annotation_format:
            try:
                return format(value, annotation_format)
            except (TypeError, ValueError):
                # Values that only implement __float__ may reject the format spec.
                return format(float(value), annotation_format)
        return None

    def _get_polygon_units(self) -> List[PolygonUnit]:
        """Builds one PolygonUnit per value_map entry, in sorted key order."""
        polygon_unit_list: List[PolygonUnit] = []
        for qubits, value in sorted(self._value_map.items()):
            polygon, center = self._qubits_to_polygon(qubits)
            polygon_unit_list.append(
                PolygonUnit(
                    polygon=polygon,
                    center=center,
                    value=float(value),
                    annot=self._get_annotation_value(qubits, value),
                )
            )
        return polygon_unit_list

    def _plot_colorbar(
        self, mappable: mpl.cm.ScalarMappable, ax: plt.Axes
    ) -> mpl.colorbar.Colorbar:
        """Plots the colorbar. Internal."""
        position = self._config['colorbar_position']
        colorbar_ax = axes_grid1.make_axes_locatable(ax).append_axes(
            position=position,
            size=self._config['colorbar_size'],
            pad=self._config['colorbar_pad'],
        )
        orientation = 'vertical' if position in ('left', 'right') else 'horizontal'
        colorbar = ax.figure.colorbar(
            mappable,
            cax=colorbar_ax,
            ax=ax,
            orientation=orientation,
            # colorbar_options is documented to default to None, so tolerate an explicit None.
            **(self._config.get("colorbar_options") or {}),
        )
        # The color scale runs along y for vertical colorbars and along x for horizontal ones.
        colorbar_ax.tick_params(axis='y' if orientation == 'vertical' else 'x', direction='out')
        return colorbar

    def _write_annotations(
        self,
        centers_and_annot: List[Tuple[Point, Optional[str]]],
        collection: mpl_collections.Collection,
        ax: plt.Axes,
    ) -> None:
        """Writes annotations to the center of cells. Internal."""
        user_text_kwargs = self._config.get('annotation_text_kwargs') or {}
        for (center, annotation), facecolor in zip(centers_and_annot, collection.get_facecolors()):
            if not annotation:
                continue
            x, y = center
            # Choose a text color that contrasts with the polygon's fill color.
            face_luminance = vis_utils.relative_luminance(facecolor)
            text_color = 'black' if face_luminance > 0.4 else 'white'
            text_kwargs = dict(color=text_color, ha="center", va="center")
            text_kwargs.update(user_text_kwargs)
            ax.text(x, y, annotation, **text_kwargs)

    def _plot_on_axis(self, ax: plt.Axes) -> mpl_collections.Collection:
        """Draws the heatmap (polygons, annotations, colorbar, limits, title) on ``ax``."""
        # Step-1: Convert value_map to a list of polygons to plot.
        polygon_list = self._get_polygon_units()
        collection: mpl_collections.Collection = mpl_collections.PolyCollection(
            [c.polygon for c in polygon_list],
            **(self._config.get('collection_options') or {}),
        )
        collection.set_clim(self._config.get('vmin'), self._config.get('vmax'))
        collection.set_array(np.array([c.value for c in polygon_list]))
        # Step-2: Plot the polygons
        ax.add_collection(collection)
        # Compute face colors now so annotation text colors can be derived from them.
        collection.update_scalarmappable()
        # Step-3: Write annotation texts
        if self._config.get('annotation_map') or self._config.get('annotation_format'):
            self._write_annotations([(c.center, c.annot) for c in polygon_list], collection, ax)
        ax.set(xlabel='column', ylabel='row')
        # Step-4: Draw colorbar if applicable
        if self._config.get('plot_colorbar'):
            self._plot_colorbar(collection, ax)
        # Step-5: Set min/max limits of x/y axis on the plot.
        rows = {q.row for qubits in self._value_map for q in qubits}
        cols = {q.col for qubits in self._value_map for q in qubits}
        min_row, max_row = min(rows), max(rows)
        min_col, max_col = min(cols), max(cols)
        min_xtick = np.floor(min_col)
        max_xtick = np.ceil(max_col)
        ax.set_xticks(np.arange(min_xtick, max_xtick + 1))
        min_ytick = np.floor(min_row)
        max_ytick = np.ceil(max_row)
        ax.set_yticks(np.arange(min_ytick, max_ytick + 1))
        ax.set_xlim((min_xtick - 0.6, max_xtick + 0.6))
        # Inverted y-limits so row 0 is drawn at the top.
        ax.set_ylim((max_ytick + 0.6, min_ytick - 0.6))
        # Step-6: Set title
        if self._config.get("title"):
            ax.set_title(self._config["title"], fontweight='bold')
        return collection

    def plot(
        self, ax: Optional[plt.Axes] = None, **kwargs: Any
    ) -> Tuple[plt.Axes, mpl_collections.Collection]:
        """Plots the heatmap on the given Axes.
        Args:
            ax: the Axes to plot on. If not given, a new figure is created,
                plotted on, and shown.
            kwargs: The optional keyword arguments are used to temporarily
                override the values present in the heatmap config. See
                __init__ for more details on the allowed arguments.
        Returns:
            A 2-tuple ``(ax, collection)``. ``ax`` is the `plt.Axes` that
            is plotted on. ``collection`` is the collection of paths drawn and filled.
        """
        show_plot = ax is None
        if ax is None:
            fig, ax = plt.subplots(figsize=(8, 8))
        original_config = copy.deepcopy(self._config)
        self.update_config(**kwargs)
        try:
            collection = self._plot_on_axis(ax)
        finally:
            # Overrides are temporary: restore them even if plotting fails.
            self._config = original_config
        if show_plot:
            fig.show()
        return (ax, collection)


class TwoQubitInteractionHeatmap(Heatmap):
    """Visualizing interactions between neighboring qubits on a 2D grid."""

    # TODO(#3388) Add documentation for Args.
    # pylint: disable=missing-param-doc
    def __init__(self, value_map: Mapping[QubitTuple, SupportsFloat], **kwargs):
        """Heatmap to display two-qubit interaction fidelities.

        Draw 2D qubit-qubit interaction heatmap with Matplotlib with arguments to configure the
        properties of the plot. The valid argument list includes all arguments of cirq.vis.Heatmap()
        plus the following.

        Args:
            coupler_margin: float, default = 0.03
            coupler_width: float, default = 0.6
        """
        # Set before super().__init__, which keeps a pre-existing _config and layers on top.
        self._config: Dict[str, Any] = {
            "coupler_margin": 0.03,
            "coupler_width": 0.6,
        }
        super().__init__(value_map, **kwargs)

    # pylint: enable=missing-param-doc
    def _extra_valid_kwargs(self) -> List[str]:
        return ["coupler_margin", "coupler_width"]

    def _qubits_to_polygon(self, qubits: QubitTuple) -> Tuple[Polygon, Point]:
        """Returns the hexagonal coupler between two adjacent qubits and its center point.

        Raises:
            ValueError: If ``qubits`` is not a pair of nearest-neighbor grid qubits.
        """
        if len(qubits) != 2:
            raise ValueError(f"Expected a pair of qubits, got {len(qubits)}: {qubits}")
        coupler_margin = self._config["coupler_margin"]
        coupler_width = self._config["coupler_width"]
        cwidth = coupler_width / 2.0
        setback = 0.5 - cwidth
        row1, col1 = map(float, (qubits[0].row, qubits[0].col))
        row2, col2 = map(float, (qubits[1].row, qubits[1].col))
        if abs(row1 - row2) + abs(col1 - col2) != 1:
            raise ValueError(
                f"{qubits[0]}-{qubits[1]} is not supported because they are not nearest neighbors"
            )
        if coupler_width <= 0:
            polygon: Polygon = []
        elif row1 == row2:  # horizontal
            col1, col2 = min(col1, col2), max(col1, col2)
            col_center = (col1 + col2) / 2.0
            polygon = [
                (col1 + coupler_margin, row1),
                (col_center - setback, row1 + cwidth - coupler_margin),
                (col_center + setback, row1 + cwidth - coupler_margin),
                (col2 - coupler_margin, row2),
                (col_center + setback, row1 - cwidth + coupler_margin),
                (col_center - setback, row1 - cwidth + coupler_margin),
            ]
        else:  # vertical: the neighbor check above guarantees col1 == col2 here
            row1, row2 = min(row1, row2), max(row1, row2)
            row_center = (row1 + row2) / 2.0
            polygon = [
                (col1, row1 + coupler_margin),
                (col1 + cwidth - coupler_margin, row_center - setback),
                (col1 + cwidth - coupler_margin, row_center + setback),
                (col2, row2 - coupler_margin),
                (col1 - cwidth + coupler_margin, row_center + setback),
                (col1 - cwidth + coupler_margin, row_center - setback),
            ]

        return (polygon, Point((col1 + col2) / 2.0, (row1 + row2) / 2.0))

    def plot(
        self, ax: Optional[plt.Axes] = None, **kwargs: Any
    ) -> Tuple[plt.Axes, mpl_collections.Collection]:
        """Plots the heatmap on the given Axes.
        Args:
            ax: the Axes to plot on. If not given, a new figure is created,
                plotted on, and shown.
            kwargs: The optional keyword arguments are used to temporarily
                override the values present in the heatmap config. See
                __init__ for more details on the allowed arguments.
        Returns:
            A 2-tuple ``(ax, collection)``. ``ax`` is the `plt.Axes` that
            is plotted on. ``collection`` is the collection of paths drawn and filled.
        """
        show_plot = ax is None
        if ax is None:
            fig, ax = plt.subplots(figsize=(8, 8))
        original_config = copy.deepcopy(self._config)
        self.update_config(**kwargs)
        try:
            # Draw the participating qubits as a plain dashed-outline background grid.
            qubits = {q for qubit_pair in self._value_map for q in qubit_pair}
            Heatmap({q: 0.0 for q in qubits}).plot(
                ax=ax,
                collection_options={
                    'cmap': 'binary',
                    'linewidths': 2,
                    'edgecolor': 'lightgrey',
                    'linestyle': 'dashed',
                },
                plot_colorbar=False,
                annotation_format=None,
            )
            collection = self._plot_on_axis(ax)
        finally:
            # Overrides are temporary: restore them even if plotting fails.
            self._config = original_config
        if show_plot:
            fig.show()
        return (ax, collection)