1# Copyright 2019 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14import copy
15from dataclasses import astuple, dataclass
16from typing import (
17    Any,
18    Dict,
19    List,
20    Mapping,
21    Optional,
22    overload,
23    Sequence,
24    SupportsFloat,
25    Tuple,
26    Union,
27)
28
29import matplotlib as mpl
30import matplotlib.collections as mpl_collections
31import matplotlib.pyplot as plt
32import numpy as np
33from mpl_toolkits import axes_grid1
34
35from cirq.devices import grid_qubit
36from cirq.vis import vis_utils
37
# Key type of a heatmap's value map: the qubit(s) a value is attached to. Single-qubit
# heatmaps use one GridQubit; two-qubit interaction heatmaps use a pair of neighboring ones.
QubitTuple = Tuple[grid_qubit.GridQubit, ...]

# Polygon vertices as (x, y) pairs in plot coordinates, where x is the grid column and y is
# the grid row.
Polygon = Sequence[Tuple[float, float]]
41
42
@dataclass
class Point:
    """A 2D plot coordinate that supports tuple-style unpacking (``x, y = point``)."""

    x: float
    y: float

    def __iter__(self):
        # Fields are produced in declaration order: x first, then y.
        yield from astuple(self)
50
51
@dataclass
class PolygonUnit:
    """Dataclass to store information about a single polygon unit to plot on the heatmap.

    For single (grid) qubit heatmaps, the polygon is a square.
    For two (grid) qubit interaction heatmaps, the polygon is a hexagon.

    Attributes:
        polygon: Vertices of the polygon to plot, as (x, y) pairs.
        value: The value used for the heatmap coloring.
        center: The center point of the polygon, where annotation text should be printed.
        annot: The annotation string to print on the polygon (qubit cell or coupler), or
            None to print nothing.
    """

    # Field order defines the positional-argument order of the generated __init__.
    polygon: Polygon
    value: float
    center: Point
    annot: Optional[str]
71
72
class Heatmap:
    """Distribution of a value in 2D qubit lattice as a color map."""

    # pylint: disable=function-redefined
    @overload
    def __init__(self, value_map: Mapping[QubitTuple, SupportsFloat], **kwargs):
        pass

    @overload
    def __init__(self, value_map: Mapping[grid_qubit.GridQubit, SupportsFloat], **kwargs):
        pass

    # TODO(#3388) Add documentation for Args.
    # pylint: disable=missing-param-doc
    def __init__(
        self,
        value_map: Union[
            Mapping[QubitTuple, SupportsFloat], Mapping[grid_qubit.GridQubit, SupportsFloat]
        ],
        **kwargs,
    ):
        """2D qubit grid Heatmaps

        Draw 2D qubit grid heatmap with Matplotlib with parameters to configure the properties of
        the plot.

        Args:
            value_map: dictionary
                A dictionary of qubits or QubitTuples as keys and corresponding magnitude as float
                values. It corresponds to the data which should be plotted as a heatmap.

            title: str, default = None
            plot_colorbar: bool, default = True

            annotation_map: dictionary,
                A dictionary of qubits or QubitTuples as keys and corresponding annotation str as
                values. It corresponds to the text that should be added on top of each heatmap
                polygon unit. Single-qubit keys are wrapped into 1-tuples, exactly like the keys
                of value_map, so either form may be used.
            annotation_format: str, default = '.2g'
                Formatting string using which annotation_map will be implicitly constructed by
                applying format(value, annotation_format) for each key in value_map.
                This is ignored if annotation_map is explicitly specified.
            annotation_text_kwargs: Matplotlib Text **kwargs,

            colorbar_position: {'right', 'left', 'top', 'bottom'}, default = 'right'
            colorbar_size: str, default = '5%'
            colorbar_pad: str, default = '2%'
            colorbar_options: Matplotlib colorbar **kwargs, default = None,

            collection_options: Matplotlib PolyCollection **kwargs, default
                                {"cmap" : "viridis"}
            vmin, vmax: colormap scaling floats, default = None

        Raises:
            ValueError: If any keyword argument is not a recognized configuration option.
        """
        # Normalize keys so that every key is a QubitTuple, even for single-qubit maps.
        self._value_map: Mapping[QubitTuple, SupportsFloat] = {
            k if isinstance(k, tuple) else (k,): v for k, v in value_map.items()
        }
        self._validate_kwargs(kwargs)
        # Subclasses may create self._config with their own defaults before calling this
        # initializer; keep those entries instead of replacing the dict.
        if '_config' not in self.__dict__:
            self._config: Dict[str, Any] = {}
        self._config.update(
            {
                "plot_colorbar": True,
                "colorbar_position": "right",
                "colorbar_size": "5%",
                "colorbar_pad": "2%",
                "collection_options": {"cmap": "viridis"},
                "annotation_format": ".2g",
            }
        )
        self._config.update(self._normalize_kwargs(kwargs))

    # pylint: enable=function-redefined,missing-param-doc
    def _extra_valid_kwargs(self) -> List[str]:
        """Returns additional config keys accepted by a subclass (none for the base class)."""
        return []

    def _validate_kwargs(self, kwargs) -> None:
        """Raises ValueError listing every key of `kwargs` that is not a known config option."""
        valid_colorbar_kwargs = [
            "plot_colorbar",
            "colorbar_position",
            "colorbar_size",
            "colorbar_pad",
            "colorbar_options",
        ]
        valid_collection_kwargs = [
            "collection_options",
            "vmin",
            "vmax",
        ]
        valid_heatmap_kwargs = [
            "title",
            "annotation_map",
            "annotation_text_kwargs",
            "annotation_format",
        ]
        valid_kwargs = (
            valid_colorbar_kwargs
            + valid_collection_kwargs
            + valid_heatmap_kwargs
            + self._extra_valid_kwargs()
        )
        invalid_args = [k for k in kwargs if k not in valid_kwargs]
        if invalid_args:
            raise ValueError(f"Received invalid argument(s): {', '.join(invalid_args)}")

    @staticmethod
    def _normalize_kwargs(kwargs: Mapping[str, Any]) -> Dict[str, Any]:
        """Returns a copy of `kwargs` whose `annotation_map` keys are all QubitTuples.

        `value_map` keys are wrapped into 1-tuples in `__init__`, and `_get_annotation_value`
        looks annotations up with those tuple keys. Bare-qubit keys in `annotation_map` must
        therefore be wrapped the same way, or they would never be found.
        """
        normalized = dict(kwargs)
        annotation_map = normalized.get('annotation_map')
        if annotation_map:
            normalized['annotation_map'] = {
                k if isinstance(k, tuple) else (k,): v for k, v in annotation_map.items()
            }
        return normalized

    def update_config(self, **kwargs) -> 'Heatmap':
        """Add/Modify **kwargs args passed during initialisation.

        Returns:
            This heatmap, to allow call chaining.

        Raises:
            ValueError: If any keyword argument is not a recognized configuration option.
        """
        self._validate_kwargs(kwargs)
        self._config.update(self._normalize_kwargs(kwargs))
        return self

    def _qubits_to_polygon(self, qubits: QubitTuple) -> Tuple[Polygon, Point]:
        """Returns a unit square centered on the (single) qubit, and that center point.

        Plot coordinates are (x, y) = (column, row).
        """
        qubit = qubits[0]
        row, col = float(qubit.row), float(qubit.col)
        return (
            [
                (col - 0.5, row - 0.5),
                (col - 0.5, row + 0.5),
                (col + 0.5, row + 0.5),
                (col + 0.5, row - 0.5),
            ],
            Point(col, row),
        )

    def _get_annotation_value(self, key, value) -> Optional[str]:
        """Returns the annotation text for one heatmap entry, or None for no annotation.

        An explicit (non-empty) `annotation_map` takes precedence; otherwise `value` is
        rendered with `annotation_format` when that is set.
        """
        annotation_map = self._config.get('annotation_map')
        if annotation_map:
            return annotation_map.get(key)
        annotation_format = self._config.get('annotation_format')
        if annotation_format:
            try:
                return format(value, annotation_format)
            except (TypeError, ValueError):
                # `value` only promises SupportsFloat; objects whose __format__ rejects the
                # format spec are formatted through their float() conversion instead.
                return format(float(value), annotation_format)
        return None

    def _get_polygon_units(self) -> List[PolygonUnit]:
        """Converts the value map into polygon units, in sorted key order."""
        return [
            PolygonUnit(
                polygon=polygon,
                center=center,
                value=float(value),
                annot=self._get_annotation_value(qubits, value),
            )
            for qubits, value in sorted(self._value_map.items())
            for polygon, center in [self._qubits_to_polygon(qubits)]
        ]

    def _plot_colorbar(
        self, mappable: mpl.cm.ScalarMappable, ax: plt.Axes
    ) -> mpl.colorbar.Colorbar:
        """Plots the colorbar in a new axes appended next to `ax`. Internal."""
        colorbar_ax = axes_grid1.make_axes_locatable(ax).append_axes(
            position=self._config['colorbar_position'],
            size=self._config['colorbar_size'],
            pad=self._config['colorbar_pad'],
        )
        position = self._config['colorbar_position']
        orien = 'vertical' if position in ('left', 'right') else 'horizontal'
        colorbar = ax.figure.colorbar(
            mappable, colorbar_ax, ax, orientation=orien, **self._config.get("colorbar_options", {})
        )
        # NOTE(review): for horizontal colorbars the ticks live on the x axis, so this call
        # presumably only affects vertical colorbars -- confirm whether that is intended.
        colorbar_ax.tick_params(axis='y', direction='out')
        return colorbar

    def _write_annotations(
        self,
        centers_and_annot: List[Tuple[Point, Optional[str]]],
        collection: mpl_collections.Collection,
        ax: plt.Axes,
    ) -> None:
        """Writes annotations at the center of each polygon. Internal.

        `centers_and_annot` is zipped with the collection's face colors, so it must be in the
        same order as the polygons of `collection`.
        """
        for (center, annotation), facecolor in zip(centers_and_annot, collection.get_facecolors()):
            if not annotation:
                continue
            x, y = center
            # Choose black or white text depending on the fill luminance, for readability.
            face_luminance = vis_utils.relative_luminance(facecolor)
            text_color = 'black' if face_luminance > 0.4 else 'white'
            text_kwargs = dict(color=text_color, ha="center", va="center")
            text_kwargs.update(self._config.get('annotation_text_kwargs', {}))
            ax.text(x, y, annotation, **text_kwargs)

    def _plot_on_axis(self, ax: plt.Axes) -> mpl_collections.Collection:
        """Draws the heatmap (polygons, annotations, colorbar, ticks, title) onto `ax`."""
        # Step-1: Convert value_map to a list of polygons to plot.
        polygon_list = self._get_polygon_units()
        collection: mpl_collections.Collection = mpl_collections.PolyCollection(
            [c.polygon for c in polygon_list],
            **self._config.get('collection_options', {}),
        )
        collection.set_clim(self._config.get('vmin'), self._config.get('vmax'))
        collection.set_array(np.array([c.value for c in polygon_list]))
        # Step-2: Plot the polygons
        ax.add_collection(collection)
        # Compute face colors now so annotation text colors can be derived from them.
        collection.update_scalarmappable()
        # Step-3: Write annotation texts
        if self._config.get('annotation_map') or self._config.get('annotation_format'):
            self._write_annotations([(c.center, c.annot) for c in polygon_list], collection, ax)
        ax.set(xlabel='column', ylabel='row')
        # Step-4: Draw colorbar if applicable
        if self._config.get('plot_colorbar'):
            self._plot_colorbar(collection, ax)
        # Step-5: Set min/max limits of x/y axis on the plot.
        rows = {q.row for qubits in self._value_map for q in qubits}
        cols = {q.col for qubits in self._value_map for q in qubits}
        min_row, max_row = min(rows), max(rows)
        min_col, max_col = min(cols), max(cols)
        min_xtick = np.floor(min_col)
        max_xtick = np.ceil(max_col)
        ax.set_xticks(np.arange(min_xtick, max_xtick + 1))
        min_ytick = np.floor(min_row)
        max_ytick = np.ceil(max_row)
        ax.set_yticks(np.arange(min_ytick, max_ytick + 1))
        ax.set_xlim((min_xtick - 0.6, max_xtick + 0.6))
        # The y limits are inverted so that row 0 is drawn at the top.
        ax.set_ylim((max_ytick + 0.6, min_ytick - 0.6))
        # Step-6: Set title
        if self._config.get("title"):
            ax.set_title(self._config["title"], fontweight='bold')
        return collection

    def plot(
        self, ax: Optional[plt.Axes] = None, **kwargs: Any
    ) -> Tuple[plt.Axes, mpl_collections.Collection]:
        """Plots the heatmap on the given Axes.

        Args:
            ax: the Axes to plot on. If not given, a new figure is created,
                plotted on, and shown.
            kwargs: The optional keyword arguments are used to temporarily
                override the values present in the heatmap config. See
                __init__ for more details on the allowed arguments.

        Returns:
            A 2-tuple ``(ax, collection)``. ``ax`` is the `plt.Axes` that
            is plotted on. ``collection`` is the collection of paths drawn and filled.

        Raises:
            ValueError: If any keyword argument is not a recognized configuration option.
        """
        show_plot = ax is None
        if show_plot:
            fig, ax = plt.subplots(figsize=(8, 8))
        original_config = copy.deepcopy(self._config)
        try:
            self.update_config(**kwargs)
            collection = self._plot_on_axis(ax)
        finally:
            # The overrides are temporary: restore the config even if plotting fails.
            self._config = original_config
        if show_plot:
            fig.show()
        return (ax, collection)
318
319
class TwoQubitInteractionHeatmap(Heatmap):
    """Visualizing interactions between neighboring qubits on a 2D grid."""

    # TODO(#3388) Add documentation for Args.
    # pylint: disable=missing-param-doc
    def __init__(self, value_map: Mapping[QubitTuple, SupportsFloat], **kwargs):
        """Heatmap to display two-qubit interaction fidelities.

        Draw 2D qubit-qubit interaction heatmap with Matplotlib with arguments to configure the
        properties of the plot. The valid argument list includes all arguments of cirq.vis.Heatmap()
        plus the following.

        Args:
            coupler_margin: float, default = 0.03
                Inset applied to each coupler hexagon, both at its two tips (away from the qubit
                centers) and on its two long edges.
            coupler_width: float, default = 0.6
                Nominal width of each coupler hexagon, perpendicular to the coupler; the drawn
                width is coupler_width - 2 * coupler_margin. Values <= 0 draw no polygon.
        """
        # Set subclass defaults first; Heatmap.__init__ keeps an existing self._config.
        self._config: Dict[str, Any] = {
            "coupler_margin": 0.03,
            "coupler_width": 0.6,
        }
        super().__init__(value_map, **kwargs)

    # pylint: enable=missing-param-doc
    def _extra_valid_kwargs(self) -> List[str]:
        """Returns the coupler-specific config keys accepted in addition to Heatmap's."""
        return ["coupler_margin", "coupler_width"]

    def _qubits_to_polygon(self, qubits: QubitTuple) -> Tuple[Polygon, Point]:
        """Returns a hexagon joining two neighboring qubits, and the midpoint between them.

        Plot coordinates are (x, y) = (column, row).

        Raises:
            ValueError: If `qubits` is not a pair of nearest-neighbor grid qubits.
        """
        if len(qubits) != 2:
            raise ValueError(f"{qubits} is not supported because it is not a pair of qubits")
        coupler_margin = self._config["coupler_margin"]
        coupler_width = self._config["coupler_width"]
        cwidth = coupler_width / 2.0
        # Distance from the midpoint to where the hexagon's long edges end.
        setback = 0.5 - cwidth
        row1, col1 = map(float, (qubits[0].row, qubits[0].col))
        row2, col2 = map(float, (qubits[1].row, qubits[1].col))
        if abs(row1 - row2) + abs(col1 - col2) != 1:
            raise ValueError(
                f"{qubits[0]}-{qubits[1]} is not supported because they are not nearest neighbors"
            )
        if coupler_width <= 0:
            polygon: Polygon = []
        elif row1 == row2:  # horizontal
            col1, col2 = min(col1, col2), max(col1, col2)
            col_center = (col1 + col2) / 2.0
            polygon = [
                (col1 + coupler_margin, row1),
                (col_center - setback, row1 + cwidth - coupler_margin),
                (col_center + setback, row1 + cwidth - coupler_margin),
                (col2 - coupler_margin, row2),
                (col_center + setback, row1 - cwidth + coupler_margin),
                (col_center - setback, row1 - cwidth + coupler_margin),
            ]
        else:  # vertical: the neighbor check above guarantees col1 == col2 here
            row1, row2 = min(row1, row2), max(row1, row2)
            row_center = (row1 + row2) / 2.0
            polygon = [
                (col1, row1 + coupler_margin),
                (col1 + cwidth - coupler_margin, row_center - setback),
                (col1 + cwidth - coupler_margin, row_center + setback),
                (col2, row2 - coupler_margin),
                (col1 - cwidth + coupler_margin, row_center + setback),
                (col1 - cwidth + coupler_margin, row_center - setback),
            ]

        return (polygon, Point((col1 + col2) / 2.0, (row1 + row2) / 2.0))

    def plot(
        self, ax: Optional[plt.Axes] = None, **kwargs: Any
    ) -> Tuple[plt.Axes, mpl_collections.Collection]:
        """Plots the heatmap on the given Axes.

        Args:
            ax: the Axes to plot on. If not given, a new figure is created,
                plotted on, and shown.
            kwargs: The optional keyword arguments are used to temporarily
                override the values present in the heatmap config. See
                __init__ for more details on the allowed arguments.

        Returns:
            A 2-tuple ``(ax, collection)``. ``ax`` is the `plt.Axes` that
            is plotted on. ``collection`` is the collection of paths drawn and filled.

        Raises:
            ValueError: If any keyword argument is not a recognized configuration option, or
                if a key of the value map is not a pair of nearest-neighbor qubits.
        """
        show_plot = ax is None
        if show_plot:
            fig, ax = plt.subplots(figsize=(8, 8))
        original_config = copy.deepcopy(self._config)
        try:
            self.update_config(**kwargs)
            # First draw every involved qubit as an unannotated cell (binary colormap, dashed
            # light-grey edges, no colorbar), then draw the couplers on top.
            all_qubits = {q for qubit_pair in self._value_map for q in qubit_pair}
            Heatmap({q: 0.0 for q in all_qubits}).plot(
                ax=ax,
                collection_options={
                    'cmap': 'binary',
                    'linewidths': 2,
                    'edgecolor': 'lightgrey',
                    'linestyle': 'dashed',
                },
                plot_colorbar=False,
                annotation_format=None,
            )
            collection = self._plot_on_axis(ax)
        finally:
            # The overrides are temporary: restore the config even if plotting fails.
            self._config = original_config
        if show_plot:
            fig.show()
        return (ax, collection)
420