# Copyright 2019 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for Heatmap."""

import pathlib
import shutil
import string
from tempfile import mkdtemp

import numpy as np
import pytest

import matplotlib as mpl
import matplotlib.pyplot as plt

from cirq.devices import grid_qubit
from cirq.vis import heatmap

# (row, col) coordinates of the single-qubit cells shared by most tests below.
# The coordinates are sparse and unsorted, so cell placement is not trivially sequential.
_ROW_COL_LIST = ((0, 5), (8, 1), (7, 0), (13, 5), (1, 6), (3, 2), (2, 8))


def _cell_row_col(path):
    """Returns the integer (row, col) of the heatmap cell drawn by `path`.

    The cell center is the mean of the path's first four vertices; the x
    coordinate is the column and the y coordinate is the row.
    """
    vertices = path.vertices[0:4]
    row = int(round(np.mean([v[1] for v in vertices])))
    col = int(round(np.mean([v[0] for v in vertices])))
    return row, col


def _annotation_texts(ax):
    """Returns a set of ((row, col), text) for every Text artist directly under `ax`.

    Text positions are (x, y) == (col, row), so they are swapped to the (row, col)
    convention used by these tests. The set may also contain texts that are not
    cell annotations (e.g. titles), so callers should compare with subset checks.
    """
    texts = set()
    for artist in ax.get_children():
        if isinstance(artist, mpl.text.Text):
            col, row = artist.get_position()
            texts.add(((row, col), artist.get_text()))
    return texts


@pytest.fixture
def ax():
    figure = mpl.figure.Figure()
    return figure.add_subplot(111)


@pytest.mark.parametrize('tuple_keys', [True, False])
def test_cells_positions(ax, tuple_keys):
    """Each qubit is drawn as one cell centered on its (row, col) grid position."""
    qubits = [grid_qubit.GridQubit(row, col) for (row, col) in _ROW_COL_LIST]
    values = np.random.random(len(qubits))
    test_value_map = {
        (qubit,) if tuple_keys else qubit: value for qubit, value in zip(qubits, values)
    }
    _, collection = heatmap.Heatmap(test_value_map).plot(ax)

    found_qubits = {_cell_row_col(path) for path in collection.get_paths()}
    assert found_qubits == set(_ROW_COL_LIST)


def test_two_qubit_heatmap(ax):
    """The title passed to TwoQubitInteractionHeatmap is applied to the axes."""
    value_map = {
        (grid_qubit.GridQubit(3, 2), grid_qubit.GridQubit(4, 2)): 0.004619111460557768,
        (grid_qubit.GridQubit(4, 1), grid_qubit.GridQubit(4, 2)): 0.0076079162393482835,
    }
    title = "Two Qubit Interaction Heatmap"
    heatmap.TwoQubitInteractionHeatmap(value_map, title=title).plot(ax)
    assert ax.get_title() == title


def test_invalid_args():
    """TwoQubitInteractionHeatmap rejects the `colormap` keyword."""
    value_map = {
        (grid_qubit.GridQubit(3, 2), grid_qubit.GridQubit(4, 2)): 0.004619111460557768,
        (grid_qubit.GridQubit(4, 1), grid_qubit.GridQubit(4, 2)): 0.0076079162393482835,
    }
    with pytest.raises(ValueError, match="invalid argument.*colormap"):
        heatmap.TwoQubitInteractionHeatmap(value_map, colormap='Greys')


def test_two_qubit_nearest_neighbor(ax):
    """Plotting a pair of qubits that are not grid neighbors raises ValueError."""
    value_map = {
        (grid_qubit.GridQubit(3, 2), grid_qubit.GridQubit(4, 2)): 0.004619111460557768,
        (grid_qubit.GridQubit(4, 1), grid_qubit.GridQubit(3, 2)): 0.0076079162393482835,
    }
    with pytest.raises(ValueError, match="not nearest neighbors"):
        heatmap.TwoQubitInteractionHeatmap(value_map, coupler_width=0).plot(ax)


# Test colormaps are the first one in each category in
# https://matplotlib.org/3.1.0/tutorials/colors/colormaps.html.
@pytest.mark.parametrize(
    'colormap_name', ['viridis', 'Greys', 'binary', 'PiYG', 'twilight', 'Pastel1', 'flag']
)
def test_cell_colors(ax, colormap_name):
    """Cell face colors follow the colormap, normalized to [vmin, vmax] and saturated outside."""
    qubits = [grid_qubit.GridQubit(row, col) for (row, col) in _ROW_COL_LIST]
    values = 1.0 + 2.0 * np.random.random(len(qubits))  # [1, 3)
    test_value_map = {(qubit,): value for qubit, value in zip(qubits, values)}
    test_row_col_map = {rc: value for rc, value in zip(_ROW_COL_LIST, values)}
    vmin, vmax = 1.5, 2.5
    random_heatmap = heatmap.Heatmap(
        test_value_map, collection_options={'cmap': colormap_name}, vmin=vmin, vmax=vmax
    )
    _, mesh = random_heatmap.plot(ax)

    # plt.get_cmap is used because mpl.cm.get_cmap was removed in matplotlib 3.9.
    colormap = plt.get_cmap(colormap_name)
    for path, facecolor in zip(mesh.get_paths(), mesh.get_facecolors()):
        value = test_row_col_map[_cell_row_col(path)]
        # Values outside [vmin, vmax] saturate at the ends of the colormap.
        color_scale = np.clip((value - vmin) / (vmax - vmin), 0.0, 1.0)
        expected_color = np.array(colormap(color_scale))
        assert np.all(np.isclose(facecolor, expected_color))


def test_default_annotation(ax):
    """Tests that the default annotation is '.2g' format on float(value)."""
    qubits = [grid_qubit.GridQubit(row, col) for (row, col) in _ROW_COL_LIST]
    values = ['3.752', '42', '-5.27e8', '-7.34e-9', 732, 0.432, 3.9753e28]
    test_value_map = {(qubit,): value for qubit, value in zip(qubits, values)}
    test_row_col_map = {rc: value for rc, value in zip(_ROW_COL_LIST, values)}
    random_heatmap = heatmap.Heatmap(test_value_map)
    random_heatmap.plot(ax)

    expected_texts = {
        (row_col, format(float(value), '.2g')) for row_col, value in test_row_col_map.items()
    }
    assert expected_texts.issubset(_annotation_texts(ax))


@pytest.mark.parametrize('format_string', ['.3e', '.2f', '.4g'])
def test_annotation_position_and_content(ax, format_string):
    """Each cell is annotated at its center with its value in `annotation_format`."""
    qubits = [grid_qubit.GridQubit(row, col) for (row, col) in _ROW_COL_LIST]
    values = np.random.random(len(qubits))
    test_value_map = {(qubit,): value for qubit, value in zip(qubits, values)}
    test_row_col_map = {rc: value for rc, value in zip(_ROW_COL_LIST, values)}
    random_heatmap = heatmap.Heatmap(test_value_map, annotation_format=format_string)
    random_heatmap.plot(ax)

    expected_texts = {
        (row_col, format(value, format_string)) for row_col, value in test_row_col_map.items()
    }
    assert expected_texts.issubset(_annotation_texts(ax))


def test_annotation_map(ax):
    """An explicit annotation_map supplies the texts; cells missing from it get none."""
    qubits = [grid_qubit.GridQubit(*row_col) for row_col in _ROW_COL_LIST]
    values = np.random.random(len(qubits))
    annos = np.random.choice(list(string.ascii_letters), len(qubits))
    test_value_map = {(qubit,): value for qubit, value in zip(qubits, values)}
    # Cell (1, 6) is deliberately left out of the annotation map.
    test_anno_map = {
        (qubit,): anno
        for qubit, row_col, anno in zip(qubits, _ROW_COL_LIST, annos)
        if row_col != (1, 6)
    }
    random_heatmap = heatmap.Heatmap(test_value_map, annotation_map=test_anno_map)
    random_heatmap.plot(ax)

    actual_texts = _annotation_texts(ax)
    assert all(position != (1, 6) for position, _ in actual_texts)
    expected_texts = {
        (row_col, anno) for row_col, anno in zip(_ROW_COL_LIST, annos) if row_col != (1, 6)
    }
    assert expected_texts.issubset(actual_texts)


@pytest.mark.parametrize('format_string', ['.3e', '.2f', '.4g', 's'])
def test_non_float_values(ax, format_string):
    """Values that are not floats are colored via float() and annotated via format()."""

    class Foo:
        def __init__(self, value: float, unit: str):
            self.value = value
            self.unit = unit

        def __float__(self):
            return self.value

        def __format__(self, format_string):
            if format_string == 's':
                return f'{self.value}{self.unit}'
            else:
                return format(self.value, format_string)

    qubits = [grid_qubit.GridQubit(row, col) for (row, col) in _ROW_COL_LIST]
    values = np.random.random(len(qubits))
    units = np.random.choice(list(string.ascii_letters), len(qubits))
    test_value_map = {
        (qubit,): Foo(float(value), unit) for qubit, value, unit in zip(qubits, values, units)
    }
    row_col_map = {
        row_col: Foo(float(value), unit)
        for row_col, value, unit in zip(_ROW_COL_LIST, values, units)
    }
    colormap_name = 'viridis'
    vmin, vmax = 0.0, 1.0
    random_heatmap = heatmap.Heatmap(
        test_value_map,
        collection_options={'cmap': colormap_name},
        vmin=vmin,
        vmax=vmax,
        annotation_format=format_string,
    )

    _, mesh = random_heatmap.plot(ax)

    # plt.get_cmap is used because mpl.cm.get_cmap was removed in matplotlib 3.9.
    colormap = plt.get_cmap(colormap_name)
    for path, facecolor in zip(mesh.get_paths(), mesh.get_facecolors()):
        foo = row_col_map[_cell_row_col(path)]
        color_scale = (foo.value - vmin) / (vmax - vmin)
        expected_color = np.array(colormap(color_scale))
        assert np.all(np.isclose(facecolor, expected_color))

    # Annotation positions are (row, col) coordinates, so they must be looked up in
    # row_col_map; test_value_map is keyed by qubit tuples and would never match.
    annotated_cells = set()
    for position, actual_text in _annotation_texts(ax):
        if position in row_col_map:
            assert actual_text == format(row_col_map[position], format_string)
            annotated_cells.add(position)
    # Guard against the loop above silently checking nothing.
    assert annotated_cells == set(row_col_map)


@pytest.mark.parametrize(
    'position,size,pad',
    [
        ('right', "5%", "2%"),
        ('right', "5%", "10%"),
        ('right', "20%", "2%"),
        ('right', "20%", "10%"),
        ('left', "5%", "2%"),
        ('left', "5%", "10%"),
        ('left', "20%", "2%"),
        ('left', "20%", "10%"),
        ('top', "5%", "2%"),
        ('top', "5%", "10%"),
        ('top', "20%", "2%"),
        ('top', "20%", "10%"),
        ('bottom', "5%", "2%"),
        ('bottom', "5%", "10%"),
        ('bottom', "20%", "2%"),
        ('bottom', "20%", "10%"),
    ],
)
def test_colorbar(ax, position, size, pad):
    """The colorbar is placed at `position` with the requested size and pad."""
    qubits = [grid_qubit.GridQubit(row, col) for (row, col) in _ROW_COL_LIST]
    values = np.random.random(len(qubits))
    test_value_map = {(qubit,): value for qubit, value in zip(qubits, values)}
    random_heatmap = heatmap.Heatmap(test_value_map, plot_colorbar=False)
    fig1, ax1 = plt.subplots()
    random_heatmap.plot(ax1)
    fig2, ax2 = plt.subplots()
    random_heatmap.plot(
        ax2, plot_colorbar=True, colorbar_position=position, colorbar_size=size, colorbar_pad=pad
    )

    # We need to call savefig() explicitly for updating axes position since the figure
    # object has been altered in the HeatMap._plot_colorbar function. The temporary
    # directory is deleted afterwards; mkdtemp leaves cleanup to the caller.
    tmp_dir = pathlib.Path(mkdtemp())
    try:
        fig2.savefig(tmp_dir / 'tmp.png')
    finally:
        shutil.rmtree(tmp_dir, ignore_errors=True)

    # Check that the figure has one more object in it when colorbar is on.
    assert len(fig2.get_children()) == len(fig1.get_children()) + 1

    fig_pos = fig2.get_axes()[0].get_position()
    colorbar_pos = fig2.get_axes()[1].get_position()

    origin_axes_size = (
        fig_pos.xmax - fig_pos.xmin
        if position in ["left", "right"]
        else fig_pos.ymax - fig_pos.ymin
    )
    expected_pad = int(pad.replace("%", "")) / 100 * origin_axes_size
    expected_size = int(size.replace("%", "")) / 100 * origin_axes_size

    if position == "right":
        pad_distance = colorbar_pos.xmin - fig_pos.xmax
        colorbar_size = colorbar_pos.xmax - colorbar_pos.xmin
    elif position == "left":
        pad_distance = fig_pos.xmin - colorbar_pos.xmax
        colorbar_size = colorbar_pos.xmax - colorbar_pos.xmin
    elif position == "top":
        pad_distance = colorbar_pos.ymin - fig_pos.ymax
        colorbar_size = colorbar_pos.ymax - colorbar_pos.ymin
    elif position == "bottom":
        pad_distance = fig_pos.ymin - colorbar_pos.ymax
        colorbar_size = colorbar_pos.ymax - colorbar_pos.ymin

    assert np.isclose(colorbar_size, expected_size)
    assert np.isclose(pad_distance, expected_pad)

    plt.close(fig1)
    plt.close(fig2)


def test_plot_updates_local_config():
    """A title passed to plot() applies to that call only; later plots use the original."""
    value_map_2d = {
        (grid_qubit.GridQubit(3, 2), grid_qubit.GridQubit(4, 2)): 0.004619111460557768,
        (grid_qubit.GridQubit(4, 1), grid_qubit.GridQubit(4, 2)): 0.0076079162393482835,
    }
    value_map_1d = {
        (grid_qubit.GridQubit(3, 2),): 0.004619111460557768,
        (grid_qubit.GridQubit(4, 2),): 0.0076079162393482835,
    }
    original_title = "Two Qubit Interaction Heatmap"
    new_title = "Temporary title for the plot"
    for random_heatmap in [
        heatmap.TwoQubitInteractionHeatmap(value_map_2d, title=original_title),
        heatmap.Heatmap(value_map_1d, title=original_title),
    ]:
        # Figures are closed after each check so pyplot does not accumulate open figures.
        fig, ax = plt.subplots()
        random_heatmap.plot(ax, title=new_title)
        assert ax.get_title() == new_title
        plt.close(fig)

        fig, ax = plt.subplots()
        random_heatmap.plot(ax)
        assert ax.get_title() == original_title
        plt.close(fig)