1# Copyright 2019 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Tests for Heatmap."""
15
16import pathlib
17import string
18from tempfile import mkdtemp
19
20import numpy as np
21import pytest
22
23import matplotlib as mpl
24import matplotlib.pyplot as plt
25
26from cirq.devices import grid_qubit
27from cirq.vis import heatmap
28
29
@pytest.fixture
def ax():
    """Return an Axes on a freshly created standalone Figure.

    The Figure is constructed directly rather than through ``pyplot``, so it is
    not registered with pyplot's global figure manager.
    """
    standalone_figure = mpl.figure.Figure()
    return standalone_figure.add_subplot(1, 1, 1)
34
35
@pytest.mark.parametrize('tuple_keys', [True, False])
def test_cells_positions(ax, tuple_keys):
    """Each polygon drawn by Heatmap is centered on its qubit's (row, col).

    Value-map keys are supplied either as 1-tuples ``(qubit,)`` or as bare
    qubits, depending on ``tuple_keys``.
    """
    expected_cells = ((0, 5), (8, 1), (7, 0), (13, 5), (1, 6), (3, 2), (2, 8))
    grid = [grid_qubit.GridQubit(r, c) for r, c in expected_cells]
    cell_values = np.random.random(len(grid))

    value_map = {}
    for q, v in zip(grid, cell_values):
        value_map[(q,) if tuple_keys else q] = v

    _, polygons = heatmap.Heatmap(value_map).plot(ax)

    def cell_of(path):
        # Average the first four path vertices to find the cell center;
        # the y coordinate is the row and the x coordinate is the column.
        corners = path.vertices[:4]
        return int(round(np.mean(corners[:, 1]))), int(round(np.mean(corners[:, 0])))

    assert {cell_of(p) for p in polygons.get_paths()} == set(expected_cells)
53
54
def test_two_qubit_heatmap(ax):
    """A TwoQubitInteractionHeatmap renders with the title given at construction."""
    q32, q42, q41 = (grid_qubit.GridQubit(r, c) for r, c in ((3, 2), (4, 2), (4, 1)))
    interactions = {
        (q32, q42): 0.004619111460557768,
        (q41, q42): 0.0076079162393482835,
    }
    expected_title = "Two Qubit Interaction Heatmap"
    two_qubit_map = heatmap.TwoQubitInteractionHeatmap(interactions, title=expected_title)
    two_qubit_map.plot(ax)
    assert ax.get_title() == expected_title
63
64
def test_invalid_args():
    """Passing ``colormap`` to TwoQubitInteractionHeatmap is rejected with ValueError."""
    q32 = grid_qubit.GridQubit(3, 2)
    q42 = grid_qubit.GridQubit(4, 2)
    q41 = grid_qubit.GridQubit(4, 1)
    interactions = {
        (q32, q42): 0.004619111460557768,
        (q41, q42): 0.0076079162393482835,
    }
    with pytest.raises(ValueError, match="invalid argument.*colormap"):
        heatmap.TwoQubitInteractionHeatmap(interactions, colormap='Greys')
72
73
def test_two_qubit_nearest_neighbor(ax):
    """Plotting a coupler between non-adjacent qubits raises ValueError.

    (4, 1) and (3, 2) differ in both row and column, i.e. they are diagonal
    rather than nearest neighbors on the grid.
    """
    q32 = grid_qubit.GridQubit(3, 2)
    q42 = grid_qubit.GridQubit(4, 2)
    q41 = grid_qubit.GridQubit(4, 1)
    interactions = {
        (q32, q42): 0.004619111460557768,
        (q41, q32): 0.0076079162393482835,
    }
    with pytest.raises(ValueError, match="not nearest neighbors"):
        diagonal_map = heatmap.TwoQubitInteractionHeatmap(interactions, coupler_width=0)
        diagonal_map.plot(ax)
81
82
# Test colormaps are the first one in each category in
# https://matplotlib.org/3.1.0/tutorials/colors/colormaps.html.
@pytest.mark.parametrize(
    'colormap_name', ['viridis', 'Greys', 'binary', 'PiYG', 'twilight', 'Pastel1', 'flag']
)
def test_cell_colors(ax, colormap_name):
    """Cell face colors follow ``cmap`` normalized to [vmin, vmax], clipped at both ends."""
    row_col_list = ((0, 5), (8, 1), (7, 0), (13, 5), (1, 6), (3, 2), (2, 8))
    qubits = [grid_qubit.GridQubit(row, col) for (row, col) in row_col_list]
    # Values are drawn from [1, 3), so they can fall below vmin or above vmax,
    # which exercises the clipping at both ends of the color scale.
    values = 1.0 + 2.0 * np.random.random(len(qubits))
    test_value_map = {(qubit,): value for qubit, value in zip(qubits, values)}
    test_row_col_map = {rc: value for rc, value in zip(row_col_list, values)}
    vmin, vmax = 1.5, 2.5
    random_heatmap = heatmap.Heatmap(
        test_value_map, collection_options={'cmap': colormap_name}, vmin=vmin, vmax=vmax
    )
    _, mesh = random_heatmap.plot(ax)

    # ``matplotlib.cm.get_cmap`` was deprecated in matplotlib 3.7 and removed in
    # 3.9; ``pyplot.get_cmap`` is the lookup that works across versions.
    colormap = plt.get_cmap(colormap_name)
    for path, facecolor in zip(mesh.get_paths(), mesh.get_facecolors()):
        # The cell center (mean of the first four vertices) is at x=col, y=row.
        vertices = path.vertices[0:4]
        row = int(round(np.mean([v[1] for v in vertices])))
        col = int(round(np.mean([v[0] for v in vertices])))
        value = test_row_col_map[(row, col)]
        color_scale = np.clip((value - vmin) / (vmax - vmin), 0.0, 1.0)
        expected_color = np.array(colormap(color_scale))
        assert np.all(np.isclose(facecolor, expected_color))
113
114
def test_default_annotation(ax):
    """Without an explicit format, each annotation is ``format(float(value), '.2g')``.

    Values include numeric strings and large/small magnitudes to check that
    the float conversion happens before formatting.
    """
    cells = ((0, 5), (8, 1), (7, 0), (13, 5), (1, 6), (3, 2), (2, 8))
    raw_values = ['3.752', '42', '-5.27e8', '-7.34e-9', 732, 0.432, 3.9753e28]
    value_map = {(grid_qubit.GridQubit(r, c),): v for (r, c), v in zip(cells, raw_values)}
    heatmap.Heatmap(value_map).plot(ax)

    rendered = set()
    for child in ax.get_children():
        if not isinstance(child, mpl.text.Text):
            continue
        # Text is positioned at (x, y) == (col, row).
        x, y = child.get_position()
        rendered.add(((y, x), child.get_text()))

    expected = {(cell, format(float(v), '.2g')) for cell, v in zip(cells, raw_values)}
    assert expected <= rendered
134
135
@pytest.mark.parametrize('format_string', ['.3e', '.2f', '.4g'])
def test_annotation_position_and_content(ax, format_string):
    """Annotations sit at each cell center and are rendered with ``annotation_format``."""
    cells = ((0, 5), (8, 1), (7, 0), (13, 5), (1, 6), (3, 2), (2, 8))
    cell_values = np.random.random(len(cells))
    value_map = {
        (grid_qubit.GridQubit(r, c),): v for (r, c), v in zip(cells, cell_values)
    }
    heatmap.Heatmap(value_map, annotation_format=format_string).plot(ax)

    rendered = set()
    for child in ax.get_children():
        if not isinstance(child, mpl.text.Text):
            continue
        # Text is positioned at (x, y) == (col, row).
        x, y = child.get_position()
        rendered.add(((y, x), child.get_text()))

    expected = {(cell, format(v, format_string)) for cell, v in zip(cells, cell_values)}
    assert expected <= rendered
155
156
def test_annotation_map(ax):
    """``annotation_map`` supplies per-cell labels; cells absent from it get no text.

    The cell at (1, 6) is deliberately left out of the annotation map.
    """
    skipped = (1, 6)
    cells = [(0, 5), (8, 1), (7, 0), (13, 5), (1, 6), (3, 2), (2, 8)]
    grid = [grid_qubit.GridQubit(*cell) for cell in cells]
    cell_values = np.random.random(len(grid))
    labels = np.random.choice(list(string.ascii_letters), len(grid))
    value_map = {(q,): v for q, v in zip(grid, cell_values)}
    label_map = {
        (q,): label for q, cell, label in zip(grid, cells, labels) if cell != skipped
    }
    heatmap.Heatmap(value_map, annotation_map=label_map).plot(ax)

    rendered = set()
    for child in ax.get_children():
        if not isinstance(child, mpl.text.Text):
            continue
        # Text is positioned at (x, y) == (col, row).
        x, y = child.get_position()
        assert (y, x) != skipped
        rendered.add(((y, x), child.get_text()))

    expected = {(cell, label) for cell, label in zip(cells, labels) if cell != skipped}
    assert expected <= rendered
180
181
@pytest.mark.parametrize('format_string', ['.3e', '.2f', '.4g', 's'])
def test_non_float_values(ax, format_string):
    """Heatmap accepts values that are merely float-convertible and formattable.

    Cell colors must come from ``float(value)`` and annotation text from
    ``format(value, annotation_format)``, including custom format specs such
    as ``'s'`` that only the value's own ``__format__`` understands.
    """

    class Foo:
        """A float-like value with a unit suffix, rendered by the ``'s'`` spec."""

        def __init__(self, value: float, unit: str):
            self.value = value
            self.unit = unit

        def __float__(self):
            return self.value

        def __format__(self, format_string):
            if format_string == 's':
                return f'{self.value}{self.unit}'
            else:
                return format(self.value, format_string)

    row_col_list = ((0, 5), (8, 1), (7, 0), (13, 5), (1, 6), (3, 2), (2, 8))
    qubits = [grid_qubit.GridQubit(row, col) for (row, col) in row_col_list]
    values = np.random.random(len(qubits))
    units = np.random.choice(list(string.ascii_letters), len(qubits))
    test_value_map = {
        (qubit,): Foo(float(value), unit) for qubit, value, unit in zip(qubits, values, units)
    }
    # The same values keyed by (row, col), used to look up what the cell or
    # annotation at a given grid position should show.
    row_col_map = {
        row_col: Foo(float(value), unit)
        for row_col, value, unit in zip(row_col_list, values, units)
    }
    colormap_name = 'viridis'
    vmin, vmax = 0.0, 1.0
    random_heatmap = heatmap.Heatmap(
        test_value_map,
        collection_options={'cmap': colormap_name},
        vmin=vmin,
        vmax=vmax,
        annotation_format=format_string,
    )

    _, mesh = random_heatmap.plot(ax)

    # ``matplotlib.cm.get_cmap`` was deprecated in matplotlib 3.7 and removed in 3.9.
    colormap = plt.get_cmap(colormap_name)
    for path, facecolor in zip(mesh.get_paths(), mesh.get_facecolors()):
        # The cell center (mean of the first four vertices) is at x=col, y=row.
        vertices = path.vertices[0:4]
        row = int(round(np.mean([v[1] for v in vertices])))
        col = int(round(np.mean([v[0] for v in vertices])))
        foo = row_col_map[(row, col)]
        # Values lie in [0, 1) == [vmin, vmax), so no clipping is needed here.
        color_scale = (foo.value - vmin) / (vmax - vmin)
        expected_color = np.array(colormap(color_scale))
        assert np.all(np.isclose(facecolor, expected_color))

    # Text positions are (x, y) == (col, row), so look them up in the
    # (row, col)-keyed map. ``test_value_map`` is keyed by ``(GridQubit,)``
    # tuples; a lookup there never matches and would skip every check.
    annotated_cells = set()
    for artist in ax.get_children():
        if isinstance(artist, mpl.text.Text):
            col, row = artist.get_position()
            foo = row_col_map.get((row, col))
            if foo is None:
                # Not a cell annotation (e.g. an axes title text).
                continue
            assert artist.get_text() == format(foo, format_string)
            annotated_cells.add((row, col))
    # Every cell must have been annotated and checked.
    assert annotated_cells == set(row_col_list)
239
240
@pytest.mark.parametrize(
    'position,size,pad',
    [
        ('right', "5%", "2%"),
        ('right', "5%", "10%"),
        ('right', "20%", "2%"),
        ('right', "20%", "10%"),
        ('left', "5%", "2%"),
        ('left', "5%", "10%"),
        ('left', "20%", "2%"),
        ('left', "20%", "10%"),
        ('top', "5%", "2%"),
        ('top', "5%", "10%"),
        ('top', "20%", "2%"),
        ('top', "20%", "10%"),
        ('bottom', "5%", "2%"),
        ('bottom', "5%", "10%"),
        ('bottom', "20%", "2%"),
        ('bottom', "20%", "10%"),
    ],
)
def test_colorbar(ax, position, size, pad):
    """The colorbar lands on ``position`` with the requested ``size`` and ``pad``.

    ``size`` and ``pad`` are percentages of the main axes' width (for 'left' and
    'right') or height (for 'top' and 'bottom').
    """
    row_col_list = ((0, 5), (8, 1), (7, 0), (13, 5), (1, 6), (3, 2), (2, 8))
    qubits = [grid_qubit.GridQubit(row, col) for (row, col) in row_col_list]
    values = np.random.random(len(qubits))
    test_value_map = {(qubit,): value for qubit, value in zip(qubits, values)}
    random_heatmap = heatmap.Heatmap(test_value_map, plot_colorbar=False)

    # pyplot keeps every figure it creates alive until it is closed; close them
    # in ``finally`` so a failing assertion does not leak them into later tests.
    pyplot_figures = []
    try:
        fig1, ax1 = plt.subplots()
        pyplot_figures.append(fig1)
        random_heatmap.plot(ax1)
        fig2, ax2 = plt.subplots()
        pyplot_figures.append(fig2)
        random_heatmap.plot(
            ax2,
            plot_colorbar=True,
            colorbar_position=position,
            colorbar_size=size,
            colorbar_pad=pad,
        )

        # We need to call savefig() explicitly for updating axes position since the figure
        # object has been altered in the HeatMap._plot_colorbar function. mkdtemp() does
        # not clean up after itself, so remove the file and directory afterwards.
        tmp_png = pathlib.Path(mkdtemp()) / 'tmp.png'
        try:
            fig2.savefig(tmp_png)
        finally:
            if tmp_png.exists():
                tmp_png.unlink()
            tmp_png.parent.rmdir()

        # Check that the figure has one more object in it when colorbar is on.
        assert len(fig2.get_children()) == len(fig1.get_children()) + 1

        fig_pos = fig2.get_axes()[0].get_position()
        colorbar_pos = fig2.get_axes()[1].get_position()

        # Size and pad are measured along the axis perpendicular to the side
        # the colorbar is attached to.
        origin_axes_size = (
            fig_pos.xmax - fig_pos.xmin
            if position in ["left", "right"]
            else fig_pos.ymax - fig_pos.ymin
        )
        expected_pad = int(pad.replace("%", "")) / 100 * origin_axes_size
        expected_size = int(size.replace("%", "")) / 100 * origin_axes_size

        if position == "right":
            pad_distance = colorbar_pos.xmin - fig_pos.xmax
            colorbar_size = colorbar_pos.xmax - colorbar_pos.xmin
        elif position == "left":
            pad_distance = fig_pos.xmin - colorbar_pos.xmax
            colorbar_size = colorbar_pos.xmax - colorbar_pos.xmin
        elif position == "top":
            pad_distance = colorbar_pos.ymin - fig_pos.ymax
            colorbar_size = colorbar_pos.ymax - colorbar_pos.ymin
        elif position == "bottom":
            pad_distance = fig_pos.ymin - colorbar_pos.ymax
            colorbar_size = colorbar_pos.ymax - colorbar_pos.ymin

        assert np.isclose(colorbar_size, expected_size)
        assert np.isclose(pad_distance, expected_pad)
    finally:
        for fig in pyplot_figures:
            plt.close(fig)
312
313
def test_plot_updates_local_config():
    """A ``title`` passed to ``plot`` applies to that call only.

    A later ``plot`` call without a title falls back to the title given at
    construction time. This is checked for both heatmap classes.
    """
    value_map_2d = {
        (grid_qubit.GridQubit(3, 2), grid_qubit.GridQubit(4, 2)): 0.004619111460557768,
        (grid_qubit.GridQubit(4, 1), grid_qubit.GridQubit(4, 2)): 0.0076079162393482835,
    }
    value_map_1d = {
        (grid_qubit.GridQubit(3, 2),): 0.004619111460557768,
        (grid_qubit.GridQubit(4, 2),): 0.0076079162393482835,
    }
    original_title = "Two Qubit Interaction Heatmap"
    new_title = "Temporary title for the plot"
    for random_heatmap in [
        heatmap.TwoQubitInteractionHeatmap(value_map_2d, title=original_title),
        heatmap.Heatmap(value_map_1d, title=original_title),
    ]:
        # Use standalone Figures (as the ``ax`` fixture does) instead of
        # plt.subplots(): they are not registered with pyplot, so nothing is
        # left open after the test.
        ax = mpl.figure.Figure().add_subplot(111)
        random_heatmap.plot(ax, title=new_title)
        assert ax.get_title() == new_title
        ax = mpl.figure.Figure().add_subplot(111)
        random_heatmap.plot(ax)
        assert ax.get_title() == original_title
335