14import copy
15from dataclasses import astuple, dataclass
16from typing import (
17    Any,
18    Dict,
19    List,
20    Mapping,
21    Optional,
22    overload,
23    Sequence,
24    SupportsFloat,
25    Tuple,
26    Union,
29import matplotlib as mpl
30import matplotlib.collections as mpl_collections
31import matplotlib.pyplot as plt
32import numpy as np
33from mpl_toolkits import axes_grid1
35from cirq.devices import grid_qubit
36from cirq.vis import vis_utils
# Key type of a heatmap value map: a tuple of grid qubits. Single-qubit heatmaps use
# 1-tuples (bare GridQubit keys are wrapped into 1-tuples by Heatmap.__init__), while
# interaction heatmaps use pairs of neighboring qubits.
QubitTuple = Tuple[grid_qubit.GridQubit, ...]
# Ordered vertices of a polygon drawn on the heatmap, as (x, y) = (column, row) coordinates.
Polygon = Sequence[Tuple[float, float]]
class Point:
    """A 2D position on the heatmap plot that unpacks as ``x, y``."""

    x: float
    y: float

    def __iter__(self):
        # Emit the field values in declaration order so `x, y = point` works.
        yield from astuple(self)
class PolygonUnit:
    """Information about a single polygon unit to plot on the heatmap.

    For single (grid) qubit heatmaps, the polygon is a square.
    For two (grid) qubit interaction heatmaps, the polygon is a hexagon.

    Attributes:
        polygon: Vertices of the polygon to plot.
        value: The value used for the heatmap coloring of this polygon.
        center: The center point of the polygon, where annotation text is printed.
        annot: The annotation string to print on the polygon, or None for no annotation.
    """

    polygon: Polygon
    value: float
    center: Point
    annot: Optional[str]
class Heatmap:
    """Distribution of a value in 2D qubit lattice as a color map."""

    # pylint: disable=function-redefined
    @overload
    def __init__(self, value_map: Mapping[QubitTuple, SupportsFloat], **kwargs):
        pass

    @overload
    def __init__(self, value_map: Mapping[grid_qubit.GridQubit, SupportsFloat], **kwargs):
        pass

    # TODO(#3388) Add documentation for Args.
    # pylint: disable=missing-param-doc
    def __init__(
        self,
        value_map: Union[
            Mapping[QubitTuple, SupportsFloat], Mapping[grid_qubit.GridQubit, SupportsFloat]
        ],
        **kwargs,
    ):
        """2D qubit grid Heatmaps

        Draw 2D qubit grid heatmap with Matplotlib with parameters to configure the properties of
        the plot.

        Args:
            value_map: dictionary
                A dictionary of qubits or QubitTuples as keys and corresponding magnitude as float
                values. It corresponds to the data which should be plotted as a heatmap.

            title: str, default = None
            plot_colorbar: bool, default = True

            annotation_map: dictionary,
                A dictionary of QubitTuples as keys and corresponding annotation str as values. It
                corresponds to the text that should be added on top of each heatmap polygon unit.
                For single-qubit heatmaps, a bare qubit may also be used as the key.
            annotation_format: str, default = '.2g'
                Formatting string using which annotation_map will be implicitly constructed by
                applying format(value, annotation_format) for each key in value_map.
                This is ignored if annotation_map is explicitly specified.
            annotation_text_kwargs: Matplotlib Text **kwargs, default = None

            colorbar_position: {'right', 'left', 'top', 'bottom'}, default = 'right'
            colorbar_size: str, default = '5%'
            colorbar_pad: str, default = '2%'
            colorbar_options: Matplotlib colorbar **kwargs, default = None

            collection_options: Matplotlib PolyCollection **kwargs, default
                                {"cmap" : "viridis"}
            vmin, vmax: colormap scaling floats, default = None

        Raises:
            ValueError: If any keyword argument is not a recognized option.
        """
        # Normalize bare GridQubit keys to 1-tuples so every key is a QubitTuple.
        self._value_map: Mapping[QubitTuple, SupportsFloat] = {
            k if isinstance(k, tuple) else (k,): v for k, v in value_map.items()
        }
        self._validate_kwargs(kwargs)
        # A subclass may have pre-populated self._config with its own defaults before
        # calling this initializer; keep those rather than overwriting them.
        if '_config' not in self.__dict__:
            self._config: Dict[str, Any] = {}
        self._config.update(
            {
                "plot_colorbar": True,
                "colorbar_position": "right",
                "colorbar_size": "5%",
                "colorbar_pad": "2%",
                "collection_options": {"cmap": "viridis"},
                "annotation_format": ".2g",
            }
        )
        self._config.update(kwargs)

    # pylint: enable=function-redefined,missing-param-doc
    def _extra_valid_kwargs(self) -> List[str]:
        """Returns additional config keys accepted by a subclass (none for the base class)."""
        return []

    def _validate_kwargs(self, kwargs) -> None:
        """Raises ValueError listing every key of `kwargs` that is not a known config option."""
        valid_colorbar_kwargs = [
            "plot_colorbar",
            "colorbar_position",
            "colorbar_size",
            "colorbar_pad",
            "colorbar_options",
        ]
        valid_collection_kwargs = [
            "collection_options",
            "vmin",
            "vmax",
        ]
        valid_heatmap_kwargs = [
            "title",
            "annotation_map",
            "annotation_text_kwargs",
            "annotation_format",
        ]
        valid_kwargs = set(
            valid_colorbar_kwargs
            + valid_collection_kwargs
            + valid_heatmap_kwargs
            + self._extra_valid_kwargs()
        )
        invalid_args = [k for k in kwargs if k not in valid_kwargs]
        if invalid_args:
            raise ValueError(f"Received invalid argument(s): {', '.join(invalid_args)}")

    def update_config(self, **kwargs) -> 'Heatmap':
        """Add/Modify **kwargs args passed during initialisation.

        Returns:
            This heatmap, to allow call chaining.

        Raises:
            ValueError: If any keyword argument is not a recognized option. In that case the
                config is left unchanged.
        """
        self._validate_kwargs(kwargs)
        self._config.update(kwargs)
        return self

    def _qubits_to_polygon(self, qubits: QubitTuple) -> Tuple[Polygon, Point]:
        """Returns the unit square centered on the first qubit, and its center.

        Coordinates are (x, y) = (column, row).
        """
        qubit = qubits[0]
        row, col = float(qubit.row), float(qubit.col)
        return (
            [
                (col - 0.5, row - 0.5),
                (col - 0.5, row + 0.5),
                (col + 0.5, row + 0.5),
                (col + 0.5, row - 0.5),
            ],
            Point(col, row),
        )

    def _get_annotation_value(self, key, value) -> Optional[str]:
        """Returns the annotation text for `key`, or None if it should not be annotated.

        An explicit `annotation_map` takes precedence over `annotation_format`.
        """
        annotation_map = self._config.get('annotation_map')
        if annotation_map:
            if key in annotation_map:
                return annotation_map[key]
            # value_map accepts bare GridQubit keys for single-qubit heatmaps; accept the same
            # convention here so such annotation maps are not silently ignored.
            if len(key) == 1:
                return annotation_map.get(key[0])
            return None
        annotation_format = self._config.get('annotation_format')
        if annotation_format:
            try:
                return format(value, annotation_format)
            except (TypeError, ValueError):
                # Values are only required to support __float__; their own __format__ may
                # reject the format spec, so fall back to formatting the float value.
                return format(float(value), annotation_format)
        return None

    def _get_polygon_units(self) -> List[PolygonUnit]:
        """Builds one PolygonUnit per value_map entry, ordered by qubit key."""
        polygon_unit_list: List[PolygonUnit] = []
        for qubits, value in sorted(self._value_map.items()):
            polygon, center = self._qubits_to_polygon(qubits)
            polygon_unit_list.append(
                PolygonUnit(
                    polygon=polygon,
                    center=center,
                    value=float(value),
                    annot=self._get_annotation_value(qubits, value),
                )
            )
        return polygon_unit_list

    def _plot_colorbar(
        self, mappable: mpl.cm.ScalarMappable, ax: plt.Axes
    ) -> mpl.colorbar.Colorbar:
        """Plots the colorbar in a new axes appended beside `ax`. Internal."""
        position = self._config['colorbar_position']
        colorbar_ax = axes_grid1.make_axes_locatable(ax).append_axes(
            position=position,
            size=self._config['colorbar_size'],
            pad=self._config['colorbar_pad'],
        )
        orientation = 'vertical' if position in ('left', 'right') else 'horizontal'
        # `colorbar_options` is documented with a default of None; treat None as no options.
        colorbar_options = self._config.get("colorbar_options") or {}
        colorbar = ax.figure.colorbar(
            mappable, cax=colorbar_ax, ax=ax, orientation=orientation, **colorbar_options
        )
        colorbar_ax.tick_params(axis='y', direction='out')
        return colorbar

    def _write_annotations(
        self,
        centers_and_annot: List[Tuple[Point, Optional[str]]],
        collection: mpl_collections.Collection,
        ax: plt.Axes,
    ) -> None:
        """Writes annotations to the center of cells. Internal."""
        # User text kwargs are the same for every cell; None means no extra kwargs.
        extra_text_kwargs = self._config.get('annotation_text_kwargs') or {}
        for (center, annotation), facecolor in zip(centers_and_annot, collection.get_facecolors()):
            if not annotation:
                continue
            # `center` is the polygon center computed by _qubits_to_polygon.
            x, y = center
            # Choose a text color that contrasts with the cell's fill color.
            face_luminance = vis_utils.relative_luminance(facecolor)
            text_color = 'black' if face_luminance > 0.4 else 'white'
            text_kwargs = dict(color=text_color, ha="center", va="center")
            text_kwargs.update(extra_text_kwargs)
            ax.text(x, y, annotation, **text_kwargs)

    def _plot_on_axis(self, ax: plt.Axes) -> mpl_collections.Collection:
        """Draws the heatmap onto `ax` using the current config and returns the collection."""
        # Step-1: Convert value_map to a list of polygons to plot.
        polygon_list = self._get_polygon_units()
        collection: mpl_collections.Collection = mpl_collections.PolyCollection(
            [c.polygon for c in polygon_list],
            **(self._config.get('collection_options') or {}),
        )
        collection.set_clim(self._config.get('vmin'), self._config.get('vmax'))
        collection.set_array(np.array([c.value for c in polygon_list]))
        # Step-2: Plot the polygons
        ax.add_collection(collection)
        # Compute the face colors now so annotations can pick a contrasting text color.
        collection.update_scalarmappable()
        # Step-3: Write annotation texts
        if self._config.get('annotation_map') or self._config.get('annotation_format'):
            self._write_annotations([(c.center, c.annot) for c in polygon_list], collection, ax)
        ax.set(xlabel='column', ylabel='row')
        # Step-4: Draw colorbar if applicable
        if self._config.get('plot_colorbar'):
            self._plot_colorbar(collection, ax)
        # Step-5: Set min/max limits of x/y axis on the plot.
        rows = {q.row for qubits in self._value_map for q in qubits}
        cols = {q.col for qubits in self._value_map for q in qubits}
        min_row, max_row = min(rows), max(rows)
        min_col, max_col = min(cols), max(cols)
        min_xtick = np.floor(min_col)
        max_xtick = np.ceil(max_col)
        ax.set_xticks(np.arange(min_xtick, max_xtick + 1))
        min_ytick = np.floor(min_row)
        max_ytick = np.ceil(max_row)
        ax.set_yticks(np.arange(min_ytick, max_ytick + 1))
        ax.set_xlim((min_xtick - 0.6, max_xtick + 0.6))
        # The y limits are reversed so that row numbers increase downwards.
        ax.set_ylim((max_ytick + 0.6, min_ytick - 0.6))
        # Step-6: Set title
        if self._config.get("title"):
            ax.set_title(self._config["title"], fontweight='bold')
        return collection

    def plot(
        self, ax: Optional[plt.Axes] = None, **kwargs: Any
    ) -> Tuple[plt.Axes, mpl_collections.Collection]:
        """Plots the heatmap on the given Axes.

        Args:
            ax: the Axes to plot on. If not given, a new figure is created,
                plotted on, and shown.
            kwargs: The optional keyword arguments are used to temporarily
                override the values present in the heatmap config. See
                __init__ for more details on the allowed arguments.

        Returns:
            A 2-tuple ``(ax, collection)``. ``ax`` is the `plt.Axes` that
            is plotted on. ``collection`` is the collection of paths drawn and filled.

        Raises:
            ValueError: If any keyword argument is not a recognized option.
        """
        show_plot = ax is None
        if show_plot:
            fig, ax = plt.subplots(figsize=(8, 8))
        original_config = copy.deepcopy(self._config)
        try:
            self.update_config(**kwargs)
            collection = self._plot_on_axis(ax)
        finally:
            # The kwargs are temporary overrides: restore the config even if plotting fails.
            self._config = original_config
        if show_plot:
            fig.show()
        return (ax, collection)
class TwoQubitInteractionHeatmap(Heatmap):
    """Visualizing interactions between neighboring qubits on a 2D grid."""

    # TODO(#3388) Add documentation for Args.
    # pylint: disable=missing-param-doc
    def __init__(self, value_map: Mapping[QubitTuple, SupportsFloat], **kwargs):
        """Heatmap to display two-qubit interaction fidelities.

        Draw 2D qubit-qubit interaction heatmap with Matplotlib with arguments to configure the
        properties of the plot. The valid argument list includes all arguments of cirq.vis.Heatmap()
        plus the following.

        Args:
            coupler_margin: float, default = 0.03
            coupler_width: float, default = 0.6
        """
        # Subclass-specific defaults are set first; Heatmap.__init__ keeps an existing
        # self._config and layers the base defaults and user kwargs on top of it.
        self._config: Dict[str, Any] = {
            "coupler_margin": 0.03,
            "coupler_width": 0.6,
        }
        super().__init__(value_map, **kwargs)

    # pylint: enable=missing-param-doc
    def _extra_valid_kwargs(self) -> List[str]:
        """Returns the coupler geometry options accepted in addition to Heatmap's."""
        return ["coupler_margin", "coupler_width"]

    def _qubits_to_polygon(self, qubits: QubitTuple) -> Tuple[Polygon, Point]:
        """Returns the hexagon drawn over the coupler between two qubits, and its center.

        Coordinates are (x, y) = (column, row). The polygon is empty when `coupler_width` is
        not positive.

        Raises:
            ValueError: If `qubits` is not a pair of nearest-neighbor grid qubits.
        """
        if len(qubits) != 2:
            raise ValueError(f"Expected a pair of qubits as a heatmap key, got {qubits}")
        coupler_margin = self._config["coupler_margin"]
        coupler_width = self._config["coupler_width"]
        cwidth = coupler_width / 2.0
        # Distance from the coupler midpoint to where the hexagon reaches its full width.
        setback = 0.5 - cwidth
        row1, col1 = map(float, (qubits[0].row, qubits[0].col))
        row2, col2 = map(float, (qubits[1].row, qubits[1].col))
        if abs(row1 - row2) + abs(col1 - col2) != 1:
            raise ValueError(
                f"{qubits[0]}-{qubits[1]} is not supported because they are not nearest neighbors"
            )
        if coupler_width <= 0:
            polygon: Polygon = []
        elif row1 == row2:  # horizontal
            col1, col2 = min(col1, col2), max(col1, col2)
            col_center = (col1 + col2) / 2.0
            polygon = [
                (col1 + coupler_margin, row1),
                (col_center - setback, row1 + cwidth - coupler_margin),
                (col_center + setback, row1 + cwidth - coupler_margin),
                (col2 - coupler_margin, row2),
                (col_center + setback, row1 - cwidth + coupler_margin),
                (col_center - setback, row1 - cwidth + coupler_margin),
            ]
        elif col1 == col2:  # vertical
            row1, row2 = min(row1, row2), max(row1, row2)
            row_center = (row1 + row2) / 2.0
            polygon = [
                (col1, row1 + coupler_margin),
                (col1 + cwidth - coupler_margin, row_center - setback),
                (col1 + cwidth - coupler_margin, row_center + setback),
                (col2, row2 - coupler_margin),
                (col1 - cwidth + coupler_margin, row_center + setback),
                (col1 - cwidth + coupler_margin, row_center - setback),
            ]

        return (polygon, Point((col1 + col2) / 2.0, (row1 + row2) / 2.0))

    def plot(
        self, ax: Optional[plt.Axes] = None, **kwargs: Any
    ) -> Tuple[plt.Axes, mpl_collections.Collection]:
        """Plots the heatmap on the given Axes.

        Args:
            ax: the Axes to plot on. If not given, a new figure is created,
                plotted on, and shown.
            kwargs: The optional keyword arguments are used to temporarily
                override the values present in the heatmap config. See
                __init__ for more details on the allowed arguments.

        Returns:
            A 2-tuple ``(ax, collection)``. ``ax`` is the `plt.Axes` that
            is plotted on. ``collection`` is the collection of paths drawn and filled.

        Raises:
            ValueError: If any keyword argument is not a recognized option, or a key of the
                value map is not a pair of nearest-neighbor qubits.
        """
        show_plot = ax is None
        if show_plot:
            fig, ax = plt.subplots(figsize=(8, 8))
        original_config = copy.deepcopy(self._config)
        try:
            self.update_config(**kwargs)
            # Draw every involved qubit as a background cell (dashed light-grey outline,
            # no colorbar, no annotations) underneath the coupler polygons.
            qubits = {q for qubit_pair in self._value_map for q in qubit_pair}
            Heatmap({q: 0.0 for q in qubits}).plot(
                ax=ax,
                collection_options={
                    'cmap': 'binary',
                    'linewidths': 2,
                    'edgecolor': 'lightgrey',
                    'linestyle': 'dashed',
                },
                plot_colorbar=False,
                annotation_format=None,
            )
            collection = self._plot_on_axis(ax)
        finally:
            # The kwargs are temporary overrides: restore the config even if plotting fails.
            self._config = original_config
        if show_plot:
            fig.show()
        return (ax, collection)