# Copyright 2019 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import dataclasses
import datetime
import importlib
import io
import json
import os
import pathlib
import sys
import warnings
from typing import Dict, Iterator, List, Optional, Tuple
from unittest import mock

import numpy as np
import pandas as pd
import pytest
import sympy

import cirq
from cirq._compat import proper_eq
from cirq.protocols import json_serialization
from cirq.testing.json import ModuleJsonTestSpec, spec_for, assert_json_roundtrip_works

REPO_ROOT = pathlib.Path(__file__).parent.parent.parent.parent


@dataclasses.dataclass
class _ModuleDeprecation:
    """Deprecation settings for a module that used to live under `cirq.`.

    Attributes:
        old_name: The deprecated dotted name (e.g. "cirq.aqt"). Repr test data
            whose text contains this name is evaluated under
            `deprecation_assertion`.
        deprecation_assertion: Context manager asserting that the deprecation
            warning for `old_name` is emitted.
    """

    old_name: str
    deprecation_assertion: contextlib.AbstractContextManager


# tested modules and their deprecation settings
TESTED_MODULES: Dict[str, Optional[_ModuleDeprecation]] = {
    'cirq_aqt': _ModuleDeprecation(
        old_name="cirq.aqt",
        deprecation_assertion=cirq.testing.assert_deprecated(
            "cirq.aqt", deadline="v0.14", count=None
        ),
    ),
    'cirq_ionq': _ModuleDeprecation(
        old_name="cirq.ionq",
        deprecation_assertion=cirq.testing.assert_deprecated(
            "cirq.ionq", deadline="v0.14", count=None
        ),
    ),
    'cirq_google': _ModuleDeprecation(
        old_name="cirq.google",
        deprecation_assertion=cirq.testing.assert_deprecated(
            "cirq.google", deadline="v0.14", count=None
        ),
    ),
    'cirq_pasqal': _ModuleDeprecation(
        old_name="cirq.pasqal",
        deprecation_assertion=cirq.testing.assert_deprecated(
            "cirq.pasqal", deadline="v0.14", count=None
        ),
    ),
    'cirq_rigetti': None,
    'cirq.protocols': None,
    'non_existent_should_be_fine': None,
}


# pyQuil 3.0, necessary for cirq_rigetti module requires
# python >= 3.7
if sys.version_info < (3, 7):  # pragma: no cover
    del TESTED_MODULES['cirq_rigetti']


def _get_testspecs_for_modules() -> List[ModuleJsonTestSpec]:
    """Returns the JSON test spec of every module in TESTED_MODULES.

    Modules for which `spec_for` raises ModuleNotFoundError (e.g. optional
    modules that are not installed) are skipped.
    """
    modules = []
    for m in TESTED_MODULES:
        try:
            modules.append(spec_for(m))
        except ModuleNotFoundError:
            # for optional modules it is okay to skip
            pass
    return modules


MODULE_TEST_SPECS = _get_testspecs_for_modules()


@contextlib.contextmanager
def _without_cirq_testing_env() -> Iterator[None]:
    """Temporarily removes only the CIRQ_TESTING environment variable.

    `mock.patch.dict(os.environ)` snapshots the environment and restores it on
    exit, so CIRQ_TESTING (if it was set) comes back afterwards. Every other
    variable is left untouched. Passing `clear='CIRQ_TESTING'` to
    `mock.patch.dict` would not do this: `clear` is a boolean flag, and a
    non-empty string wipes the whole environment.

    Because this is a generator-based context manager, it can also be used as
    a decorator: `@_without_cirq_testing_env()`.
    """
    with mock.patch.dict(os.environ):
        os.environ.pop('CIRQ_TESTING', None)
        yield


def test_line_qubit_roundtrip():
    q1 = cirq.LineQubit(12)
    assert_json_roundtrip_works(
        q1,
        text_should_be="""{
  "cirq_type": "LineQubit",
  "x": 12
}""",
    )


def test_gridqubit_roundtrip():
    q = cirq.GridQubit(15, 18)
    assert_json_roundtrip_works(
        q,
        text_should_be="""{
  "cirq_type": "GridQubit",
  "row": 15,
  "col": 18
}""",
    )


def test_op_roundtrip():
    q = cirq.LineQubit(5)
    op1 = cirq.rx(0.123).on(q)
    assert_json_roundtrip_works(
        op1,
        text_should_be="""{
  "cirq_type": "GateOperation",
  "gate": {
    "cirq_type": "Rx",
    "rads": 0.123
  },
  "qubits": [
    {
      "cirq_type": "LineQubit",
      "x": 5
    }
  ]
}""",
    )


def test_op_roundtrip_filename(tmpdir):
    filename = f'{tmpdir}/op.json'
    q = cirq.LineQubit(5)
    op1 = cirq.rx(0.123).on(q)
    cirq.to_json(op1, filename)
    assert os.path.exists(filename)
    op2 = cirq.read_json(filename)
    assert op1 == op2

    gzip_filename = f'{tmpdir}/op.gz'
    cirq.to_json_gzip(op1, gzip_filename)
    assert os.path.exists(gzip_filename)
    op3 = cirq.read_json_gzip(gzip_filename)
    assert op1 == op3


def test_op_roundtrip_file_obj(tmpdir):
    filename = f'{tmpdir}/op.json'
    q = cirq.LineQubit(5)
    op1 = cirq.rx(0.123).on(q)
    with open(filename, 'w+') as file:
        cirq.to_json(op1, file)
        assert os.path.exists(filename)
        # Rewind the same handle so the data just written is read back.
        file.seek(0)
        op2 = cirq.read_json(file)
        assert op1 == op2

    gzip_filename = f'{tmpdir}/op.gz'
    with open(gzip_filename, 'w+b') as gzip_file:
        cirq.to_json_gzip(op1, gzip_file)
        assert os.path.exists(gzip_filename)
        gzip_file.seek(0)
        op3 = cirq.read_json_gzip(gzip_file)
        assert op1 == op3


def test_fail_to_resolve():
    buffer = io.StringIO()
    buffer.write(
        """
    {
      "cirq_type": "MyCustomClass",
      "data": [1, 2, 3]
    }
    """
    )
    buffer.seek(0)

    with pytest.raises(
        ValueError, match="Could not resolve type 'MyCustomClass' during deserialization"
    ):
        cirq.read_json(buffer)


QUBITS = cirq.LineQubit.range(5)
Q0, Q1, Q2, Q3, Q4 = QUBITS

# TODO: Include cirq.rx in the Circuit test case file.
# Github issue: https://github.com/quantumlib/Cirq/issues/2014
# Note that even the following doesn't work because theta gets
# multiplied by 1/pi:
# cirq.Circuit(cirq.rx(sympy.Symbol('theta')).on(Q0)),

### MODULE CONSISTENCY tests


@pytest.mark.parametrize('mod_spec', MODULE_TEST_SPECS, ids=repr)
# during test setup deprecated submodules are inspected and trigger the
# deprecation error in testing. It is cleaner to just turn it off than to assert
# deprecation for each submodule.
@_without_cirq_testing_env()
def test_shouldnt_be_serialized_no_superfluous(mod_spec: ModuleJsonTestSpec):
    # everything in the list should be ignored for a reason
    names = set(mod_spec.get_all_names())
    missing_names = set(mod_spec.should_not_be_serialized).difference(names)
    assert len(missing_names) == 0, (
        f"Defined as \"shouldn't be serialized\", "
        f"but missing from {mod_spec}: \n"
        f"{missing_names}"
    )


@pytest.mark.parametrize('mod_spec', MODULE_TEST_SPECS, ids=repr)
# during test setup deprecated submodules are inspected and trigger the
# deprecation error in testing. It is cleaner to just turn it off than to assert
# deprecation for each submodule.
@_without_cirq_testing_env()
def test_not_yet_serializable_no_superfluous(mod_spec: ModuleJsonTestSpec):
    # everything in the list should be ignored for a reason
    names = set(mod_spec.get_all_names())
    missing_names = set(mod_spec.not_yet_serializable).difference(names)
    assert (
        len(missing_names) == 0
    ), f"Defined as Not yet serializable, but missing from {mod_spec}: \n{missing_names}"


@pytest.mark.parametrize('mod_spec', MODULE_TEST_SPECS, ids=repr)
def test_mutually_exclusive_blacklist(mod_spec: ModuleJsonTestSpec):
    common = set(mod_spec.should_not_be_serialized) & set(mod_spec.not_yet_serializable)
    assert len(common) == 0, (
        f"Defined in both {mod_spec.name} 'Not yet serializable' "
        f"and 'Should not be serialized' lists: {common}"
    )


@pytest.mark.parametrize('mod_spec', MODULE_TEST_SPECS, ids=repr)
def test_resolver_cache_vs_should_not_serialize(mod_spec: ModuleJsonTestSpec):
    resolver_cache_types = {n for (n, _) in mod_spec.get_resolver_cache_types()}
    common = set(mod_spec.should_not_be_serialized) & resolver_cache_types

    assert len(common) == 0, (
        f"Defined in both {mod_spec.name} Resolver "
        f"Cache and should not be serialized: "
        f"{common}"
    )


@pytest.mark.parametrize('mod_spec', MODULE_TEST_SPECS, ids=repr)
def test_resolver_cache_vs_not_yet_serializable(mod_spec: ModuleJsonTestSpec):
    resolver_cache_types = {n for (n, _) in mod_spec.get_resolver_cache_types()}
    common = set(mod_spec.not_yet_serializable) & resolver_cache_types

    assert len(common) == 0, (
        f"Issue with the JSON config of {mod_spec.name}.\n"
        f"Types are listed in both"
        f" {mod_spec.name}.json_resolver_cache.py and in the 'not_yet_serializable' list in"
        f" {mod_spec.test_data_path}/spec.py: "
        f"\n {common}"
    )


def test_builtins():
    assert_json_roundtrip_works(True)
    assert_json_roundtrip_works(1)
    assert_json_roundtrip_works(1 + 2j)
    assert_json_roundtrip_works(
        {
            'test': [123, 5.5],
            'key2': 'asdf',
            '3': None,
            '0.0': [],
        }
    )


def test_numpy():
    x = np.ones(1)[0]

    assert_json_roundtrip_works(x.astype(np.int8))
    assert_json_roundtrip_works(x.astype(np.int16))
    assert_json_roundtrip_works(x.astype(np.int32))
    assert_json_roundtrip_works(x.astype(np.int64))
    assert_json_roundtrip_works(x.astype(np.uint8))
    assert_json_roundtrip_works(x.astype(np.uint16))
    assert_json_roundtrip_works(x.astype(np.uint32))
    assert_json_roundtrip_works(x.astype(np.uint64))
    assert_json_roundtrip_works(x.astype(np.float32))
    assert_json_roundtrip_works(x.astype(np.float64))
    assert_json_roundtrip_works(x.astype(np.complex64))
    assert_json_roundtrip_works(x.astype(np.complex128))

    assert_json_roundtrip_works(np.ones((11, 5)))
    assert_json_roundtrip_works(np.arange(3))


def test_pandas():
    assert_json_roundtrip_works(
        pd.DataFrame(data=[[1, 2, 3], [4, 5, 6]], columns=['x', 'y', 'z'], index=[2, 5])
    )
    assert_json_roundtrip_works(pd.Index([1, 2, 3], name='test'))
    assert_json_roundtrip_works(
        pd.MultiIndex.from_tuples([(1, 2), (3, 4), (5, 6)], names=['alice', 'bob'])
    )

    assert_json_roundtrip_works(
        pd.DataFrame(
            index=pd.Index([1, 2, 3], name='test'),
            data=[[11, 21.0], [12, 22.0], [13, 23.0]],
            columns=['a', 'b'],
        )
    )
    assert_json_roundtrip_works(
        pd.DataFrame(
            index=pd.MultiIndex.from_tuples([(1, 2), (2, 3), (3, 4)], names=['x', 'y']),
            data=[[11, 21.0], [12, 22.0], [13, 23.0]],
            columns=pd.Index(['a', 'b'], name='c'),
        )
    )


def test_sympy():
    # Raw values.
    assert_json_roundtrip_works(sympy.Symbol('theta'))
    assert_json_roundtrip_works(sympy.Integer(5))
    assert_json_roundtrip_works(sympy.Rational(2, 3))
    assert_json_roundtrip_works(sympy.Float(1.1))

    # Basic operations.
    s = sympy.Symbol('s')
    t = sympy.Symbol('t')
    assert_json_roundtrip_works(t + s)
    assert_json_roundtrip_works(t * s)
    assert_json_roundtrip_works(t / s)
    assert_json_roundtrip_works(t - s)
    assert_json_roundtrip_works(t ** s)

    # Linear combinations.
    assert_json_roundtrip_works(t * 2)
    assert_json_roundtrip_works(4 * t + 3 * s + 2)

    assert_json_roundtrip_works(sympy.pi)
    assert_json_roundtrip_works(sympy.E)
    assert_json_roundtrip_works(sympy.EulerGamma)


class SBKImpl(cirq.SerializableByKey):
    """A test implementation of SerializableByKey."""

    def __init__(
        self,
        name: str,
        data_list: Optional[List] = None,
        data_tuple: Optional[Tuple] = None,
        data_dict: Optional[Dict] = None,
    ):
        self.name = name
        self.data_list = data_list or []
        self.data_tuple = data_tuple or ()
        self.data_dict = data_dict or {}

    def __eq__(self, other):
        if not isinstance(other, SBKImpl):
            return False
        return (
            self.name == other.name
            and self.data_list == other.data_list
            and self.data_tuple == other.data_tuple
            and self.data_dict == other.data_dict
        )

    def _json_dict_(self):
        return {
            "cirq_type": "SBKImpl",
            "name": self.name,
            "data_list": self.data_list,
            "data_tuple": self.data_tuple,
            "data_dict": self.data_dict,
        }

    @classmethod
    def _from_json_dict_(cls, name, data_list, data_tuple, data_dict, **kwargs):
        # JSON has no tuple type, so data_tuple is converted back to a tuple.
        return cls(name, data_list, tuple(data_tuple), data_dict)


def test_context_serialization():
    def custom_resolver(name):
        if name == 'SBKImpl':
            return SBKImpl

    test_resolvers = [custom_resolver] + cirq.DEFAULT_RESOLVERS

    sbki_empty = SBKImpl('sbki_empty')
    assert_json_roundtrip_works(sbki_empty, resolvers=test_resolvers)

    sbki_list = SBKImpl('sbki_list', data_list=[sbki_empty, sbki_empty])
    assert_json_roundtrip_works(sbki_list, resolvers=test_resolvers)

    sbki_tuple = SBKImpl('sbki_tuple', data_tuple=(sbki_list, sbki_list))
    assert_json_roundtrip_works(sbki_tuple, resolvers=test_resolvers)

    sbki_dict = SBKImpl('sbki_dict', data_dict={'a': sbki_tuple, 'b': sbki_tuple})
    assert_json_roundtrip_works(sbki_dict, resolvers=test_resolvers)

    sbki_json = str(cirq.to_json(sbki_dict))
    # There should be exactly one context item for each previous SBKImpl.
    assert sbki_json.count('"cirq_type": "_SerializedContext"') == 4
    # There should be exactly two key items for each of sbki_(empty|list|tuple),
    # plus one for the top-level sbki_dict.
    assert sbki_json.count('"cirq_type": "_SerializedKey"') == 7
    # The final object should be a _SerializedKey for sbki_dict. It is compared
    # as parsed JSON so the check does not depend on indentation.
    final_obj_idx = sbki_json.rfind('{')
    final_obj = sbki_json[final_obj_idx : sbki_json.find('}', final_obj_idx) + 1]
    assert json.loads(final_obj) == {"cirq_type": "_SerializedKey", "key": 4}

    list_sbki = [sbki_dict]
    assert_json_roundtrip_works(list_sbki, resolvers=test_resolvers)

    dict_sbki = {'a': sbki_dict}
    assert_json_roundtrip_works(dict_sbki, resolvers=test_resolvers)

    assert sbki_list != json_serialization._SerializedKey(sbki_list)

    # Serialization keys have unique suffixes.
    sbki_other_list = SBKImpl('sbki_list', data_list=[sbki_list])
    assert_json_roundtrip_works(sbki_other_list, resolvers=test_resolvers)


def test_internal_serializer_types():
    sbki = SBKImpl('test_key')
    key = 1
    test_key = json_serialization._SerializedKey(key)
    test_context = json_serialization._SerializedContext(sbki, 1)
    test_serialization = json_serialization._ContextualSerialization(sbki)

    key_json = test_key._json_dict_()
    with pytest.raises(TypeError, match='_from_json_dict_'):
        _ = json_serialization._SerializedKey._from_json_dict_(**key_json)

    context_json = test_context._json_dict_()
    with pytest.raises(TypeError, match='_from_json_dict_'):
        _ = json_serialization._SerializedContext._from_json_dict_(**context_json)

    serialization_json = test_serialization._json_dict_()
    with pytest.raises(TypeError, match='_from_json_dict_'):
        _ = json_serialization._ContextualSerialization._from_json_dict_(**serialization_json)


# during test setup deprecated submodules are inspected and trigger the
# deprecation error in testing. It is cleaner to just turn it off than to assert
# deprecation for each submodule.
def _list_public_classes_for_tested_modules():
    """Returns (mod_spec, object name, class) triples for every tested module.

    Called at collection time to parametrize test_json_test_data_coverage.
    """
    # Only CIRQ_TESTING is removed, and only for this call: mock.patch.dict
    # restores os.environ on exit. The rest of the environment is untouched.
    with mock.patch.dict(os.environ), warnings.catch_warnings():
        os.environ.pop('CIRQ_TESTING', None)
        # to remove DeprecationWarning noise during test collection
        warnings.simplefilter("ignore")
        return [
            (mod_spec, o, n)
            for mod_spec in MODULE_TEST_SPECS
            for (o, n) in mod_spec.find_classes_that_should_serialize()
        ]


@pytest.mark.parametrize(
    'mod_spec,cirq_obj_name,cls',
    _list_public_classes_for_tested_modules(),
)
def test_json_test_data_coverage(mod_spec: ModuleJsonTestSpec, cirq_obj_name: str, cls):
    if cirq_obj_name in mod_spec.tested_elsewhere:
        pytest.skip("Tested elsewhere.")

    if cirq_obj_name in mod_spec.not_yet_serializable:
        # pytest.xfail raises, so nothing below runs for these objects.
        pytest.xfail(reason="Not serializable (yet)")

    test_data_path = mod_spec.test_data_path
    rel_path = test_data_path.relative_to(REPO_ROOT)
    mod_path = mod_spec.name.replace(".", "/")
    rel_resolver_cache_path = f"{mod_path}/json_resolver_cache.py"
    json_path = test_data_path / f'{cirq_obj_name}.json'
    json_path2 = test_data_path / f'{cirq_obj_name}.json_inward'
    deprecation_deadline = mod_spec.deprecated.get(cirq_obj_name)

    if not json_path.exists() and not json_path2.exists():
        # coverage: ignore
        pytest.fail(
            f"Hello intrepid developer. There is a new public or "
            f"serializable object named '{cirq_obj_name}' in the module '{mod_spec.name}' "
            f"that does not have associated test data.\n"
            f"\n"
            f"You must create the file\n"
            f"    {rel_path}/{cirq_obj_name}.json\n"
            f"and the file\n"
            f"    {rel_path}/{cirq_obj_name}.repr\n"
            f"in order to guarantee this public object is, and will "
            f"remain, serializable.\n"
            f"\n"
            f"The content of the .repr file should be the string returned "
            f"by `repr(obj)` where `obj` is a test {cirq_obj_name} value "
            f"or list of such values. To get this to work you may need to "
            f"implement a __repr__ method for {cirq_obj_name}. The repr "
            f"must be a parsable python expression that evaluates to "
            f"something equal to `obj`."
            f"\n"
            f"The content of the .json file should be the string returned "
            f"by `cirq.to_json(obj)` where `obj` is the same object or "
            f"list of test objects.\n"
            f"To get this to work you likely need "
            f"to add {cirq_obj_name} to the "
            f"`_class_resolver_dictionary` method in "
            f"the {rel_resolver_cache_path} source file. "
            f"You may also need to add a _json_dict_ method to "
            f"{cirq_obj_name}. In some cases you will also need to add a "
            f"_from_json_dict_ class method to the {cirq_obj_name} class."
            f"\n"
            f"For more information on JSON serialization, please read the "
            f"docstring for cirq.protocols.SupportsJSON. If this object or "
            f"class is not appropriate for serialization, add its name to "
            f"the `should_not_be_serialized` list in the TestSpec defined in the "
            f"{rel_path}/spec.py source file."
        )

    repr_file = test_data_path / f'{cirq_obj_name}.repr'
    if repr_file.exists() and cls is not None:
        objs = _eval_repr_data_file(repr_file, deprecation_deadline=deprecation_deadline)
        if not isinstance(objs, list):
            objs = [objs]

        for obj in objs:
            # Exact type match is intended: subclasses belong in .repr_inward.
            assert type(obj) is cls, (
                f"Value in {test_data_path}/{cirq_obj_name}.repr must be of "
                f"exact type {cls}, or a list of instances of that type. But "
                f"the value (or one of the list entries) had type "
                f"{type(obj)}.\n"
                f"\n"
                f"If using a value of the wrong type is intended, move the "
                f"value to {test_data_path}/{cirq_obj_name}.repr_inward\n"
                f"\n"
                f"Value with wrong type:\n{obj!r}."
            )


def test_to_from_strings():
    x_json_text = """{
  "cirq_type": "_PauliX",
  "exponent": 1.0,
  "global_shift": 0.0
}"""
    assert cirq.to_json(cirq.X) == x_json_text
    assert cirq.read_json(json_text=x_json_text) == cirq.X

    with pytest.raises(ValueError, match='specify ONE'):
        cirq.read_json(io.StringIO(), json_text=x_json_text)


def test_to_from_json_gzip():
    a, b = cirq.LineQubit.range(2)
    test_circuit = cirq.Circuit(cirq.H(a), cirq.CX(a, b))
    gzip_data = cirq.to_json_gzip(test_circuit)
    unzip_circuit = cirq.read_json_gzip(gzip_raw=gzip_data)
    assert test_circuit == unzip_circuit

    with pytest.raises(ValueError):
        _ = cirq.read_json_gzip(io.StringIO(), gzip_raw=gzip_data)
    with pytest.raises(ValueError):
        _ = cirq.read_json_gzip()


def _eval_repr_data_file(path: pathlib.Path, deprecation_deadline: Optional[str]):
    """Evaluates a .repr test-data file and returns the resulting object.

    The file text is passed to `eval` with cirq, pd, sympy, np, datetime and
    every importable module in TESTED_MODULES in scope. The test data comes
    from the repository itself; do not point this at untrusted files.

    Args:
        path: The .repr / .repr_inward file to evaluate.
        deprecation_deadline: If set, evaluation must emit a deprecation
            warning with this deadline.
    """
    content = path.read_text()
    ctx_managers: List[contextlib.AbstractContextManager] = [contextlib.suppress()]
    if deprecation_deadline:
        # we ignore coverage here, because sometimes there are no deprecations at all in any of the
        # modules
        # coverage: ignore
        ctx_managers = [cirq.testing.assert_deprecated(deadline=deprecation_deadline, count=None)]

    for deprecation in TESTED_MODULES.values():
        if deprecation is not None and deprecation.old_name in content:
            # NOTE(review): `deprecation_assertion` is a single context-manager
            # instance created once at import time. If
            # `cirq.testing.assert_deprecated` returns a generator-based
            # (single-use) manager, a second repr file mentioning the same old
            # module name could not re-enter it. Confirm before relying on it.
            ctx_managers.append(deprecation.deprecation_assertion)

    imports = {
        'cirq': cirq,
        'pd': pd,
        'sympy': sympy,
        'np': np,
        'datetime': datetime,
    }

    for m in TESTED_MODULES:
        try:
            imports[m] = importlib.import_module(m)
        except ImportError:
            # Optional modules that are not installed are simply unavailable.
            pass

    with contextlib.ExitStack() as stack:
        for ctx_manager in ctx_managers:
            stack.enter_context(ctx_manager)
        obj = eval(
            content,
            imports,
            {},
        )
        return obj


def assert_repr_and_json_test_data_agree(
    mod_spec: ModuleJsonTestSpec,
    repr_path: pathlib.Path,
    json_path: pathlib.Path,
    inward_only: bool,
    deprecation_deadline: Optional[str],
):
    """Checks that a JSON test-data file and its repr counterpart agree.

    Does nothing when neither file exists.

    Args:
        mod_spec: Spec of the module that owns the test data.
        repr_path: The .repr (or .repr_inward) file.
        json_path: The .json (or .json_inward) file.
        inward_only: If False, additionally checks that serializing the repr
            object reproduces the JSON file (compared as parsed JSON values).
        deprecation_deadline: If set, deserializing the JSON and evaluating
            the repr must emit deprecation warnings with this deadline.

    Raises:
        IOError: If either file cannot be read or parsed.
        ValueError: If deserialization fails for a reason other than an
            unresolvable type.
    """
    if not repr_path.exists() and not json_path.exists():
        return

    rel_repr_path = f'{repr_path.relative_to(REPO_ROOT)}'
    rel_json_path = f'{json_path.relative_to(REPO_ROOT)}'

    try:
        json_from_file = json_path.read_text()
        ctx_manager = (
            cirq.testing.assert_deprecated(deadline=deprecation_deadline, count=None)
            if deprecation_deadline
            else contextlib.suppress()
        )
        with ctx_manager:
            json_obj = cirq.read_json(json_text=json_from_file)
    except ValueError as ex:  # coverage: ignore
        # coverage: ignore
        if "Could not resolve type" in str(ex):
            mod_path = mod_spec.name.replace(".", "/")
            rel_resolver_cache_path = f"{mod_path}/json_resolver_cache.py"
            # coverage: ignore
            pytest.fail(
                f"{rel_json_path} can't be parsed to JSON.\n"
                f"Maybe an entry is missing from the "
                f" `_class_resolver_dictionary` method in {rel_resolver_cache_path}?"
            )
        else:
            raise ValueError(f"deprecation: {deprecation_deadline} - got error: {ex}") from ex
    except AssertionError:  # coverage: ignore
        # coverage: ignore
        # Deprecation assertion failures propagate as-is rather than being
        # wrapped in IOError by the generic handler below.
        raise
    except Exception as ex:  # coverage: ignore
        # coverage: ignore
        raise IOError(f'Failed to parse test json data from {rel_json_path}.') from ex

    try:
        repr_obj = _eval_repr_data_file(repr_path, deprecation_deadline)
    except Exception as ex:  # coverage: ignore
        # coverage: ignore
        raise IOError(f'Failed to parse test repr data from {rel_repr_path}.') from ex

    assert proper_eq(json_obj, repr_obj), (
        f'The json data from {rel_json_path} did not parse '
        f'into an object equivalent to the repr data from {rel_repr_path}.\n'
        f'\n'
        f'json object: {json_obj!r}\n'
        f'repr object: {repr_obj!r}\n'
    )

    if not inward_only:
        json_from_cirq = cirq.to_json(repr_obj)
        json_from_cirq_obj = json.loads(json_from_cirq)
        json_from_file_obj = json.loads(json_from_file)

        assert json_from_cirq_obj == json_from_file_obj, (
            f'The json produced by cirq no longer agrees with the json in the '
            f'{rel_json_path} test data file.\n'
            f'\n'
            f'You must either fix the cirq code to continue to produce the '
            f'same output, or you must move the old test data to '
            f'{rel_json_path}_inward and create a fresh {rel_json_path} file.\n'
            f'\n'
            f'test data json:\n'
            f'{json_from_file}\n'
            f'\n'
            f'cirq produced json:\n'
            f'{json_from_cirq}\n'
        )


@pytest.mark.parametrize(
    'mod_spec, abs_path',
    [(m, abs_path) for m in MODULE_TEST_SPECS for abs_path in m.all_test_data_keys()],
)
def test_json_and_repr_data(mod_spec: ModuleJsonTestSpec, abs_path: str):
    # The same deadline applies to the regular and the inward data files.
    deprecation_deadline = mod_spec.deprecated.get(os.path.basename(abs_path))
    assert_repr_and_json_test_data_agree(
        mod_spec=mod_spec,
        repr_path=pathlib.Path(f'{abs_path}.repr'),
        json_path=pathlib.Path(f'{abs_path}.json'),
        inward_only=False,
        deprecation_deadline=deprecation_deadline,
    )
    assert_repr_and_json_test_data_agree(
        mod_spec=mod_spec,
        repr_path=pathlib.Path(f'{abs_path}.repr_inward'),
        json_path=pathlib.Path(f'{abs_path}.json_inward'),
        inward_only=True,
        deprecation_deadline=deprecation_deadline,
    )


def test_pathlib_paths(tmpdir):
    path = pathlib.Path(tmpdir) / 'op.json'
    cirq.to_json(cirq.X, path)
    assert cirq.read_json(path) == cirq.X

    gzip_path = pathlib.Path(tmpdir) / 'op.gz'
    cirq.to_json_gzip(cirq.X, gzip_path)
    assert cirq.read_json_gzip(gzip_path) == cirq.X


def test_json_serializable_dataclass():
    @cirq.json_serializable_dataclass
    class MyDC:
        q: cirq.LineQubit
        desc: str

    my_dc = MyDC(cirq.LineQubit(4), 'hi mom')

    def custom_resolver(name):
        if name == 'MyDC':
            return MyDC

    assert_json_roundtrip_works(
        my_dc,
        text_should_be="\n".join(
            [
                '{',
                '  "cirq_type": "MyDC",',
                '  "q": {',
                '    "cirq_type": "LineQubit",',
                '    "x": 4',
                '  },',
                '  "desc": "hi mom"',
                '}',
            ]
        ),
        resolvers=[custom_resolver] + cirq.DEFAULT_RESOLVERS,
    )


def test_json_serializable_dataclass_parenthesis():
    @cirq.json_serializable_dataclass()
    class MyDC:
        q: cirq.LineQubit
        desc: str

    def custom_resolver(name):
        if name == 'MyDC':
            return MyDC

    my_dc = MyDC(cirq.LineQubit(4), 'hi mom')

    assert_json_roundtrip_works(my_dc, resolvers=[custom_resolver] + cirq.DEFAULT_RESOLVERS)


def test_dataclass_json_dict():
    @dataclasses.dataclass(frozen=True)
    class MyDC:
        q: cirq.LineQubit
        desc: str

        def _json_dict_(self):
            return cirq.dataclass_json_dict(self)

    def custom_resolver(name):
        if name == 'MyDC':
            return MyDC

    my_dc = MyDC(cirq.LineQubit(4), 'hi mom')

    assert_json_roundtrip_works(my_dc, resolvers=[custom_resolver, *cirq.DEFAULT_RESOLVERS])


def test_json_serializable_dataclass_namespace():
    @cirq.json_serializable_dataclass(namespace='cirq.experiments')
    class QuantumVolumeParams:
        width: int
        depth: int
        circuit_i: int

    qvp = QuantumVolumeParams(width=5, depth=5, circuit_i=0)

    def custom_resolver(name):
        if name == 'cirq.experiments.QuantumVolumeParams':
            return QuantumVolumeParams

    assert_json_roundtrip_works(qvp, resolvers=[custom_resolver] + cirq.DEFAULT_RESOLVERS)


def test_numpy_values():
    assert (
        cirq.to_json({'value': np.array(1)})
        == """{
  "value": 1
}"""
    )