1# Copyright 2019 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14import contextlib
15import dataclasses
16import datetime
17import importlib
18import io
19import json
20import os
21import pathlib
22import sys
23import warnings
24from typing import Dict, List, Optional, Tuple
25from unittest import mock
26
27import numpy as np
28import pandas as pd
29import pytest
30import sympy
31
32import cirq
33from cirq._compat import proper_eq
34from cirq.protocols import json_serialization
35from cirq.testing.json import ModuleJsonTestSpec, spec_for, assert_json_roundtrip_works
36
# Four directory levels above this test file; presumably the repository root
# (TODO confirm against the repo layout). Test-data paths are printed relative
# to it in failure messages (see the `relative_to(REPO_ROOT)` calls below).
REPO_ROOT = pathlib.Path(__file__).parent.parent.parent.parent
38
39
@dataclasses.dataclass
class _ModuleDeprecation:
    """Deprecation settings for a module that used to be importable under another name.

    Attributes:
        old_name: The deprecated import path (e.g. ``"cirq.aqt"``). Repr test
            data whose text mentions this string is evaluated under a
            deprecation assertion.
        deadline: Deadline passed to ``cirq.testing.assert_deprecated``.
    """

    old_name: str
    deadline: str

    @property
    def deprecation_assertion(self) -> contextlib.AbstractContextManager:
        """Return a NEW context manager asserting that ``old_name`` is deprecated.

        A fresh object is built on every access on purpose. Context managers
        produced by ``contextlib.contextmanager`` can be entered only once
        (``assert_deprecated`` is presumably one of those; verify). Storing a
        single module-level instance therefore breaks as soon as a second repr
        file references ``old_name``.
        """
        return cirq.testing.assert_deprecated(self.old_name, deadline=self.deadline, count=None)


# Tested modules mapped to their deprecation settings, or None when no
# deprecation handling is needed for that module.
TESTED_MODULES: Dict[str, Optional[_ModuleDeprecation]] = {
    'cirq_aqt': _ModuleDeprecation(old_name="cirq.aqt", deadline="v0.14"),
    'cirq_ionq': _ModuleDeprecation(old_name="cirq.ionq", deadline="v0.14"),
    'cirq_google': _ModuleDeprecation(old_name="cirq.google", deadline="v0.14"),
    'cirq_pasqal': _ModuleDeprecation(old_name="cirq.pasqal", deadline="v0.14"),
    'cirq_rigetti': None,
    'cirq.protocols': None,
    # Modules that cannot be imported are skipped when test specs are collected.
    'non_existent_should_be_fine': None,
}


# The cirq_rigetti module needs pyQuil 3.0, which requires Python >= 3.7.
if sys.version_info < (3, 7):  # pragma: no cover
    del TESTED_MODULES['cirq_rigetti']
82
83
def _get_testspecs_for_modules() -> List[ModuleJsonTestSpec]:
    """Return the JSON test spec of every importable module in TESTED_MODULES.

    Modules that raise ModuleNotFoundError (optional dependencies or
    placeholder entries) are skipped rather than treated as errors.
    """
    specs: List[ModuleJsonTestSpec] = []
    for module_name in TESTED_MODULES:
        # Optional modules may legitimately be absent.
        with contextlib.suppress(ModuleNotFoundError):
            specs.append(spec_for(module_name))
    return specs


MODULE_TEST_SPECS = _get_testspecs_for_modules()
96
97
def test_line_qubit_roundtrip():
    # Exact pretty-printed JSON expected for a LineQubit.
    expected = '{\n  "cirq_type": "LineQubit",\n  "x": 12\n}'
    assert_json_roundtrip_works(cirq.LineQubit(12), text_should_be=expected)
107
108
def test_gridqubit_roundtrip():
    # Exact pretty-printed JSON expected for a GridQubit.
    expected = '{\n  "cirq_type": "GridQubit",\n  "row": 15,\n  "col": 18\n}'
    assert_json_roundtrip_works(cirq.GridQubit(15, 18), text_should_be=expected)
119
120
def test_op_roundtrip():
    op = cirq.rx(0.123).on(cirq.LineQubit(5))
    # Exact pretty-printed JSON expected for a single-qubit gate operation.
    expected = "\n".join(
        [
            '{',
            '  "cirq_type": "GateOperation",',
            '  "gate": {',
            '    "cirq_type": "Rx",',
            '    "rads": 0.123',
            '  },',
            '  "qubits": [',
            '    {',
            '      "cirq_type": "LineQubit",',
            '      "x": 5',
            '    }',
            '  ]',
            '}',
        ]
    )
    assert_json_roundtrip_works(op, text_should_be=expected)
140
141
def test_op_roundtrip_filename(tmpdir):
    op = cirq.rx(0.123).on(cirq.LineQubit(5))
    # (file name, writer, reader): plain JSON first, then gzip-compressed JSON.
    cases = [
        ('op.json', cirq.to_json, cirq.read_json),
        ('op.gz', cirq.to_json_gzip, cirq.read_json_gzip),
    ]
    for basename, write, read in cases:
        path = f'{tmpdir}/{basename}'
        write(op, path)
        assert os.path.exists(path)
        assert op == read(path)
156
157
def test_op_roundtrip_file_obj(tmpdir):
    op = cirq.rx(0.123).on(cirq.LineQubit(5))
    # (file name, open mode, writer, reader): text JSON, then binary gzip.
    cases = [
        ('op.json', 'w+', cirq.to_json, cirq.read_json),
        ('op.gz', 'w+b', cirq.to_json_gzip, cirq.read_json_gzip),
    ]
    for basename, mode, write, read in cases:
        path = f'{tmpdir}/{basename}'
        with open(path, mode) as handle:
            write(op, handle)
            assert os.path.exists(path)
            # Rewind so the reader sees what was just written through the same handle.
            handle.seek(0)
            assert op == read(handle)
176
177
def test_fail_to_resolve():
    # A cirq_type that no resolver knows about must be rejected.
    payload = """
    {
      "cirq_type": "MyCustomClass",
      "data": [1, 2, 3]
    }
    """
    with pytest.raises(
        ValueError, match="Could not resolve type 'MyCustomClass' during deserialization"
    ):
        cirq.read_json(io.StringIO(payload))
193
194
# Five line qubits, also unpacked into individual module-level names.
Q0, Q1, Q2, Q3, Q4 = QUBITS = cirq.LineQubit.range(5)
197
198# TODO: Include cirq.rx in the Circuit test case file.
199# Github issue: https://github.com/quantumlib/Cirq/issues/2014
200# Note that even the following doesn't work because theta gets
201# multiplied by 1/pi:
202#   cirq.Circuit(cirq.rx(sympy.Symbol('theta')).on(Q0)),
203
204### MODULE CONSISTENCY tests
205
206
@pytest.mark.parametrize('mod_spec', MODULE_TEST_SPECS, ids=repr)
# During test setup deprecated submodules are inspected and trigger the
# deprecation error in testing. It is cleaner to just turn it off than to assert
# deprecation for each submodule. `clear` is a boolean flag: the whole of
# os.environ (CIRQ_TESTING included) is emptied for the duration of the test
# and restored afterwards.
@mock.patch.dict(os.environ, clear=True)
def test_shouldnt_be_serialized_no_superfluous(mod_spec: ModuleJsonTestSpec):
    """Every `should_not_be_serialized` entry must be one of the module's names."""
    names = set(mod_spec.get_all_names())
    missing_names = set(mod_spec.should_not_be_serialized) - names
    assert not missing_names, (
        f"Defined as \"shouldn't be serialized\", "
        f"but missing from {mod_spec}: \n"
        f"{missing_names}"
    )
221
222
@pytest.mark.parametrize('mod_spec', MODULE_TEST_SPECS, ids=repr)
# Deprecated submodules are inspected during setup and would otherwise trip the
# deprecation check used in testing; `clear` empties os.environ (CIRQ_TESTING
# included) while the test runs instead of asserting each deprecation.
@mock.patch.dict(os.environ, clear=True)
def test_not_yet_serializable_no_superfluous(mod_spec: ModuleJsonTestSpec):
    """Every `not_yet_serializable` entry must be one of the module's names."""
    known = set(mod_spec.get_all_names())
    unknown = set(mod_spec.not_yet_serializable) - known
    assert not unknown, (
        f"Defined as Not yet serializable, but missing from {mod_spec}: \n{unknown}"
    )
235
236
@pytest.mark.parametrize('mod_spec', MODULE_TEST_SPECS, ids=repr)
def test_mutually_exclusive_blacklist(mod_spec: ModuleJsonTestSpec):
    """No name may be both 'not yet serializable' and 'should not be serialized'."""
    common = set(mod_spec.should_not_be_serialized) & set(mod_spec.not_yet_serializable)
    assert not common, (
        f"Defined in both {mod_spec.name} 'Not yet serializable' "
        f"and 'Should not be serialized' lists: {common}"
    )
244
245
@pytest.mark.parametrize('mod_spec', MODULE_TEST_SPECS, ids=repr)
def test_resolver_cache_vs_should_not_serialize(mod_spec: ModuleJsonTestSpec):
    """Names in the JSON resolver cache must not be listed as 'should not be serialized'."""
    resolver_cache_types = {name for (name, _) in mod_spec.get_resolver_cache_types()}
    common = set(mod_spec.should_not_be_serialized) & resolver_cache_types

    assert not common, (
        f"Defined in both {mod_spec.name} Resolver "
        f"Cache and should not be serialized: "
        f"{common}"
    )
256
257
@pytest.mark.parametrize('mod_spec', MODULE_TEST_SPECS, ids=repr)
def test_resolver_cache_vs_not_yet_serializable(mod_spec: ModuleJsonTestSpec):
    """A name in the JSON resolver cache must not also be 'not yet serializable'."""
    cached_names = {cached_name for cached_name, _ in mod_spec.get_resolver_cache_types()}
    overlap = set(mod_spec.not_yet_serializable) & cached_names

    assert not overlap, (
        f"Issue with the JSON config of {mod_spec.name}.\n"
        f"Types are listed in both"
        f" {mod_spec.name}.json_resolver_cache.py and in the 'not_yet_serializable' list in"
        f" {mod_spec.test_data_path}/spec.py: "
        f"\n {overlap}"
    )
270
271
def test_builtins():
    # Plain Python values that must survive a JSON round trip unchanged.
    samples = [
        True,
        1,
        1 + 2j,
        {
            'test': [123, 5.5],
            'key2': 'asdf',
            '3': None,
            '0.0': [],
        },
    ]
    for sample in samples:
        assert_json_roundtrip_works(sample)
284
285
def test_numpy():
    one = np.ones(1)[0]
    # Every numpy scalar type that must round-trip.
    scalar_dtypes = (
        np.int8,
        np.int16,
        np.int32,
        np.int64,
        np.uint8,
        np.uint16,
        np.uint32,
        np.uint64,
        np.float32,
        np.float64,
        np.complex64,
        np.complex128,
    )
    for dtype in scalar_dtypes:
        assert_json_roundtrip_works(one.astype(dtype))

    # Multi-dimensional and 1-D arrays.
    for array in (np.ones((11, 5)), np.arange(3)):
        assert_json_roundtrip_works(array)
304
305
def test_pandas():
    frame_data = [[11, 21.0], [12, 22.0], [13, 23.0]]
    cases = [
        pd.DataFrame(data=[[1, 2, 3], [4, 5, 6]], columns=['x', 'y', 'z'], index=[2, 5]),
        pd.Index([1, 2, 3], name='test'),
        pd.MultiIndex.from_tuples([(1, 2), (3, 4), (5, 6)], names=['alice', 'bob']),
        # Frame with a named index.
        pd.DataFrame(
            index=pd.Index([1, 2, 3], name='test'),
            data=frame_data,
            columns=['a', 'b'],
        ),
        # Frame with a multi-index and named columns.
        pd.DataFrame(
            index=pd.MultiIndex.from_tuples([(1, 2), (2, 3), (3, 4)], names=['x', 'y']),
            data=frame_data,
            columns=pd.Index(['a', 'b'], name='c'),
        ),
    ]
    for case in cases:
        assert_json_roundtrip_works(case)
329
330
def test_sympy():
    s, t = sympy.Symbol('s'), sympy.Symbol('t')
    expressions = [
        # Raw values.
        sympy.Symbol('theta'),
        sympy.Integer(5),
        sympy.Rational(2, 3),
        sympy.Float(1.1),
        # Basic operations.
        t + s,
        t * s,
        t / s,
        t - s,
        t ** s,
        # Linear combinations.
        t * 2,
        4 * t + 3 * s + 2,
        # Named constants.
        sympy.pi,
        sympy.E,
        sympy.EulerGamma,
    ]
    for expr in expressions:
        assert_json_roundtrip_works(expr)
354
355
class SBKImpl(cirq.SerializableByKey):
    """A minimal SerializableByKey used to exercise context serialization.

    Each `data_*` container defaults to an empty container of its type and may
    hold other SBKImpl instances, which produces shared references in the
    serialized output.
    """

    def __init__(
        self,
        name: str,
        data_list: Optional[List] = None,
        data_tuple: Optional[Tuple] = None,
        data_dict: Optional[Dict] = None,
    ):
        self.name = name
        # Falsy arguments (None or an empty container) become fresh empty containers.
        self.data_list = data_list if data_list else []
        self.data_tuple = data_tuple if data_tuple else ()
        self.data_dict = data_dict if data_dict else {}

    def __eq__(self, other):
        if not isinstance(other, SBKImpl):
            return False
        for attr in ('name', 'data_list', 'data_tuple', 'data_dict'):
            if not getattr(self, attr) == getattr(other, attr):
                return False
        return True

    def _json_dict_(self):
        return dict(
            cirq_type="SBKImpl",
            name=self.name,
            data_list=self.data_list,
            data_tuple=self.data_tuple,
            data_dict=self.data_dict,
        )

    @classmethod
    def _from_json_dict_(cls, name, data_list, data_tuple, data_dict, **kwargs):
        # JSON has no tuple type, so the tuple comes back as a list; convert it.
        return cls(name, data_list, tuple(data_tuple), data_dict)
393
394
def test_context_serialization():
    def custom_resolver(name):
        if name == 'SBKImpl':
            return SBKImpl

    resolvers = [custom_resolver] + cirq.DEFAULT_RESOLVERS

    def roundtrip(obj):
        assert_json_roundtrip_works(obj, resolvers=resolvers)

    # Build progressively nested objects that share references.
    sbki_empty = SBKImpl('sbki_empty')
    roundtrip(sbki_empty)

    sbki_list = SBKImpl('sbki_list', data_list=[sbki_empty, sbki_empty])
    roundtrip(sbki_list)

    sbki_tuple = SBKImpl('sbki_tuple', data_tuple=(sbki_list, sbki_list))
    roundtrip(sbki_tuple)

    sbki_dict = SBKImpl('sbki_dict', data_dict={'a': sbki_tuple, 'b': sbki_tuple})
    roundtrip(sbki_dict)

    serialized = str(cirq.to_json(sbki_dict))
    # Exactly one context item per distinct SBKImpl.
    assert serialized.count('"cirq_type": "_SerializedContext"') == 4
    # Two key items for each of sbki_(empty|list|tuple), plus one for the
    # top-level sbki_dict.
    assert serialized.count('"cirq_type": "_SerializedKey"') == 7
    # The final JSON object in the text is the _SerializedKey for sbki_dict.
    start = serialized.rfind('{')
    end = serialized.find('}', start) + 1
    expected_tail = '{\n      "cirq_type": "_SerializedKey",\n      "key": 4\n    }'
    assert serialized[start:end] == expected_tail

    # Containers holding SerializableByKey objects also round-trip.
    roundtrip([sbki_dict])
    roundtrip({'a': sbki_dict})

    assert sbki_list != json_serialization._SerializedKey(sbki_list)

    # Serialization keys have unique suffixes.
    roundtrip(SBKImpl('sbki_list', data_list=[sbki_list]))
442
443
def test_internal_serializer_types():
    sbki = SBKImpl('test_key')
    # Each internal wrapper can emit a JSON dict but refuses to be rebuilt
    # directly from one via `_from_json_dict_`.
    wrappers = [
        json_serialization._SerializedKey(1),
        json_serialization._SerializedContext(sbki, 1),
        json_serialization._ContextualSerialization(sbki),
    ]
    for wrapper in wrappers:
        as_json = wrapper._json_dict_()
        with pytest.raises(TypeError, match='_from_json_dict_'):
            _ = type(wrapper)._from_json_dict_(**as_json)
462
463
# Deprecated submodules are inspected while classes are collected (this runs at
# import time, inside the parametrize call below) and would otherwise trip the
# deprecation check used in testing. `clear` is a boolean flag: all of
# os.environ (CIRQ_TESTING included) is emptied while this function runs and
# restored afterwards.
@mock.patch.dict(os.environ, clear=True)
def _list_public_classes_for_tested_modules():
    """Return `(mod_spec, name, cls)` for every class that should serialize."""
    triples = []
    # Silence DeprecationWarning noise during test collection.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        for mod_spec in MODULE_TEST_SPECS:
            for name, cls in mod_spec.find_classes_that_should_serialize():
                triples.append((mod_spec, name, cls))
    return triples
477
478
@pytest.mark.parametrize(
    'mod_spec,cirq_obj_name,cls',
    _list_public_classes_for_tested_modules(),
)
def test_json_test_data_coverage(mod_spec: ModuleJsonTestSpec, cirq_obj_name: str, cls):
    """Every public serializable class must have JSON test data.

    Also checks that every value in an existing `.repr` file has exactly type
    `cls`.
    """
    if cirq_obj_name in mod_spec.tested_elsewhere:
        pytest.skip("Tested elsewhere.")

    if cirq_obj_name in mod_spec.not_yet_serializable:
        # The `return` matters: under `--runxfail` pytest.xfail is a no-op, and
        # the checks below must still be skipped for these names.
        return pytest.xfail(reason="Not serializable (yet)")

    test_data_path = mod_spec.test_data_path
    rel_path = test_data_path.relative_to(REPO_ROOT)
    mod_path = mod_spec.name.replace(".", "/")
    rel_resolver_cache_path = f"{mod_path}/json_resolver_cache.py"
    json_path = test_data_path / f'{cirq_obj_name}.json'
    json_inward_path = test_data_path / f'{cirq_obj_name}.json_inward'
    deprecation_deadline = mod_spec.deprecated.get(cirq_obj_name)

    if not json_path.exists() and not json_inward_path.exists():
        # coverage: ignore
        pytest.fail(
            f"Hello intrepid developer. There is a new public or "
            f"serializable object named '{cirq_obj_name}' in the module '{mod_spec.name}' "
            f"that does not have associated test data.\n"
            f"\n"
            f"You must create the file\n"
            f"    {rel_path}/{cirq_obj_name}.json\n"
            f"and the file\n"
            f"    {rel_path}/{cirq_obj_name}.repr\n"
            f"in order to guarantee this public object is, and will "
            f"remain, serializable.\n"
            f"\n"
            f"The content of the .repr file should be the string returned "
            f"by `repr(obj)` where `obj` is a test {cirq_obj_name} value "
            f"or list of such values. To get this to work you may need to "
            f"implement a __repr__ method for {cirq_obj_name}. The repr "
            f"must be a parsable python expression that evaluates to "
            f"something equal to `obj`.\n"
            f"\n"
            f"The content of the .json file should be the string returned "
            f"by `cirq.to_json(obj)` where `obj` is the same object or "
            f"list of test objects.\n"
            f"To get this to work you likely need "
            f"to add {cirq_obj_name} to the "
            f"`_class_resolver_dictionary` method in "
            f"the {rel_resolver_cache_path} source file. "
            f"You may also need to add a _json_dict_ method to "
            f"{cirq_obj_name}. In some cases you will also need to add a "
            f"_from_json_dict_ class method to the {cirq_obj_name} class.\n"
            f"\n"
            f"For more information on JSON serialization, please read the "
            f"docstring for cirq.protocols.SupportsJSON. If this object or "
            f"class is not appropriate for serialization, add its name to "
            f"the `should_not_be_serialized` list in the TestSpec defined in the "
            f"{rel_path}/spec.py source file."
        )

    repr_file = test_data_path / f'{cirq_obj_name}.repr'
    if repr_file.exists() and cls is not None:
        objs = _eval_repr_data_file(repr_file, deprecation_deadline=deprecation_deadline)
        if not isinstance(objs, list):
            objs = [objs]

        for obj in objs:
            # Exact type match on purpose: subclasses belong in `.repr_inward`.
            assert type(obj) is cls, (
                f"Value in {test_data_path}/{cirq_obj_name}.repr must be of "
                f"exact type {cls}, or a list of instances of that type. But "
                f"the value (or one of the list entries) had type "
                f"{type(obj)}.\n"
                f"\n"
                f"If using a value of the wrong type is intended, move the "
                f"value to {test_data_path}/{cirq_obj_name}.repr_inward\n"
                f"\n"
                f"Value with wrong type:\n{obj!r}."
            )
555
556
def test_to_from_strings():
    expected = '{\n  "cirq_type": "_PauliX",\n  "exponent": 1.0,\n  "global_shift": 0.0\n}'
    assert cirq.to_json(cirq.X) == expected
    assert cirq.read_json(json_text=expected) == cirq.X

    # Supplying both a file and json_text must be rejected.
    with pytest.raises(ValueError, match='specify ONE'):
        cirq.read_json(io.StringIO(), json_text=expected)
568
569
def test_to_from_json_gzip():
    q0, q1 = cirq.LineQubit.range(2)
    circuit = cirq.Circuit(cirq.H(q0), cirq.CX(q0, q1))
    compressed = cirq.to_json_gzip(circuit)
    assert circuit == cirq.read_json_gzip(gzip_raw=compressed)

    # Passing both a file and raw bytes, or neither, is an error.
    with pytest.raises(ValueError):
        _ = cirq.read_json_gzip(io.StringIO(), gzip_raw=compressed)
    with pytest.raises(ValueError):
        _ = cirq.read_json_gzip()
581
582
def _eval_repr_data_file(path: pathlib.Path, deprecation_deadline: Optional[str]):
    """Evaluate the Python expression stored in a `.repr` test-data file.

    Args:
        path: The `.repr` (or `.repr_inward`) file to evaluate.
        deprecation_deadline: If truthy, evaluation is wrapped in
            `cirq.testing.assert_deprecated` with this deadline.

    Returns:
        The object the file's expression evaluates to.
    """
    content = path.read_text()

    if deprecation_deadline:
        # we ignore coverage here, because sometimes there are no deprecations at all in any of the
        # modules
        # coverage: ignore
        guards: List[contextlib.AbstractContextManager] = [
            cirq.testing.assert_deprecated(deadline=deprecation_deadline, count=None)
        ]
    else:
        guards = [contextlib.suppress()]

    # Content that mentions a deprecated module path must also be evaluated
    # under that module's deprecation assertion.
    guards.extend(
        deprecation.deprecation_assertion
        for deprecation in TESTED_MODULES.values()
        if deprecation is not None and deprecation.old_name in content
    )

    namespace = {'cirq': cirq, 'pd': pd, 'sympy': sympy, 'np': np, 'datetime': datetime}
    for module_name in TESTED_MODULES:
        try:
            namespace[module_name] = importlib.import_module(module_name)
        except ImportError:
            pass  # optional or placeholder modules may be missing

    # NOTE: eval is acceptable only because .repr files are trusted,
    # repository-controlled test data; never point this at external input.
    with contextlib.ExitStack() as stack:
        for guard in guards:
            stack.enter_context(guard)
        return eval(content, namespace, {})
619
620
def assert_repr_and_json_test_data_agree(
    mod_spec: ModuleJsonTestSpec,
    repr_path: pathlib.Path,
    json_path: pathlib.Path,
    inward_only: bool,
    deprecation_deadline: Optional[str],
):
    """Check that a JSON test-data file and its repr counterpart agree.

    Both files are parsed and the resulting objects compared with `proper_eq`.
    Unless `inward_only` is set, cirq's current JSON output for the repr object
    must also equal the JSON stored in the file.

    Args:
        mod_spec: Spec of the module the test data belongs to.
        repr_path: Path of the `.repr` / `.repr_inward` file.
        json_path: Path of the `.json` / `.json_inward` file.
        inward_only: If True, skip the check that cirq still emits this JSON.
        deprecation_deadline: If truthy, parsing is wrapped in
            `cirq.testing.assert_deprecated` with this deadline.

    Raises:
        IOError: If either file cannot be read or parsed.
        ValueError: If the JSON fails to load for a reason other than an
            unresolvable type.
        AssertionError: If the data disagree, or a deprecation assertion fails.
    """
    if not repr_path.exists() and not json_path.exists():
        return

    rel_repr_path = f'{repr_path.relative_to(REPO_ROOT)}'
    rel_json_path = f'{json_path.relative_to(REPO_ROOT)}'

    try:
        json_from_file = json_path.read_text()
        ctx_manager = (
            cirq.testing.assert_deprecated(deadline=deprecation_deadline, count=None)
            if deprecation_deadline
            else contextlib.suppress()
        )
        with ctx_manager:
            json_obj = cirq.read_json(json_text=json_from_file)
    except ValueError as ex:  # coverage: ignore
        # coverage: ignore
        if "Could not resolve type" in str(ex):
            mod_path = mod_spec.name.replace(".", "/")
            rel_resolver_cache_path = f"{mod_path}/json_resolver_cache.py"
            # coverage: ignore
            pytest.fail(
                f"{rel_json_path} can't be parsed to JSON.\n"
                f"Maybe an entry is missing from the "
                f"`_class_resolver_dictionary` method in {rel_resolver_cache_path}?"
            )
        else:
            raise ValueError(f"deprecation: {deprecation_deadline} - got error: {ex}") from ex
    except AssertionError:  # coverage: ignore
        # coverage: ignore
        # Deprecation-assertion failures must surface as-is, not as parse errors.
        raise
    except Exception as ex:  # coverage: ignore
        # coverage: ignore
        raise IOError(f'Failed to parse test json data from {rel_json_path}.') from ex

    try:
        repr_obj = _eval_repr_data_file(repr_path, deprecation_deadline)
    except AssertionError:  # coverage: ignore
        # coverage: ignore
        # Same as above: surface deprecation-assertion failures unchanged.
        raise
    except Exception as ex:  # coverage: ignore
        # coverage: ignore
        raise IOError(f'Failed to parse test repr data from {rel_repr_path}.') from ex

    assert proper_eq(json_obj, repr_obj), (
        f'The json data from {rel_json_path} did not parse '
        f'into an object equivalent to the repr data from {rel_repr_path}.\n'
        f'\n'
        f'json object: {json_obj!r}\n'
        f'repr object: {repr_obj!r}\n'
    )

    if not inward_only:
        # Compare parsed JSON rather than text so formatting differences don't matter.
        json_from_cirq = cirq.to_json(repr_obj)
        json_from_cirq_obj = json.loads(json_from_cirq)
        json_from_file_obj = json.loads(json_from_file)

        assert json_from_cirq_obj == json_from_file_obj, (
            f'The json produced by cirq no longer agrees with the json in the '
            f'{rel_json_path} test data file.\n'
            f'\n'
            f'You must either fix the cirq code to continue to produce the '
            f'same output, or you must move the old test data to '
            f'{rel_json_path}_inward and create a fresh {rel_json_path} file.\n'
            f'\n'
            f'test data json:\n'
            f'{json_from_file}\n'
            f'\n'
            f'cirq produced json:\n'
            f'{json_from_cirq}\n'
        )
696
697
@pytest.mark.parametrize(
    'mod_spec, abs_path',
    [(spec, abs_path) for spec in MODULE_TEST_SPECS for abs_path in spec.all_test_data_keys()],
)
def test_json_and_repr_data(mod_spec: ModuleJsonTestSpec, abs_path: str):
    """Check both the regular and the `_inward` test-data pair for `abs_path`."""
    deadline = mod_spec.deprecated.get(os.path.basename(abs_path))
    # Regular files must also match cirq's current JSON output; `_inward`
    # files only need to deserialize to the repr object.
    for suffix, inward_only in (('', False), ('_inward', True)):
        assert_repr_and_json_test_data_agree(
            mod_spec=mod_spec,
            repr_path=pathlib.Path(f'{abs_path}.repr{suffix}'),
            json_path=pathlib.Path(f'{abs_path}.json{suffix}'),
            inward_only=inward_only,
            deprecation_deadline=deadline,
        )
717
718
def test_pathlib_paths(tmpdir):
    directory = pathlib.Path(tmpdir)
    # Both the plain and gzip APIs accept pathlib.Path targets.
    for basename, write, read in (
        ('op.json', cirq.to_json, cirq.read_json),
        ('op.gz', cirq.to_json_gzip, cirq.read_json_gzip),
    ):
        target = directory / basename
        write(cirq.X, target)
        assert read(target) == cirq.X
727
728
def test_json_serializable_dataclass():
    @cirq.json_serializable_dataclass
    class MyDC:
        q: cirq.LineQubit
        desc: str

    def custom_resolver(name):
        if name == 'MyDC':
            return MyDC

    expected_text = (
        '{\n'
        '  "cirq_type": "MyDC",\n'
        '  "q": {\n'
        '    "cirq_type": "LineQubit",\n'
        '    "x": 4\n'
        '  },\n'
        '  "desc": "hi mom"\n'
        '}'
    )
    assert_json_roundtrip_works(
        MyDC(cirq.LineQubit(4), 'hi mom'),
        text_should_be=expected_text,
        resolvers=[custom_resolver] + cirq.DEFAULT_RESOLVERS,
    )
757
758
def test_json_serializable_dataclass_parenthesis():
    # The decorator must also work when called with empty parentheses.
    @cirq.json_serializable_dataclass()
    class MyDC:
        q: cirq.LineQubit
        desc: str

    resolvers = [lambda name: MyDC if name == 'MyDC' else None] + cirq.DEFAULT_RESOLVERS
    assert_json_roundtrip_works(MyDC(cirq.LineQubit(4), 'hi mom'), resolvers=resolvers)
772
773
def test_dataclass_json_dict():
    # A plain dataclass can opt in by delegating to cirq.dataclass_json_dict.
    @dataclasses.dataclass(frozen=True)
    class MyDC:
        q: cirq.LineQubit
        desc: str

        def _json_dict_(self):
            return cirq.dataclass_json_dict(self)

    def resolve(name):
        return MyDC if name == 'MyDC' else None

    assert_json_roundtrip_works(
        MyDC(cirq.LineQubit(4), 'hi mom'),
        resolvers=[resolve, *cirq.DEFAULT_RESOLVERS],
    )
790
791
def test_json_serializable_dataclass_namespace():
    @cirq.json_serializable_dataclass(namespace='cirq.experiments')
    class QuantumVolumeParams:
        width: int
        depth: int
        circuit_i: int

    def custom_resolver(name):
        # The resolver is keyed on the namespace-qualified type name.
        if name == 'cirq.experiments.QuantumVolumeParams':
            return QuantumVolumeParams

    assert_json_roundtrip_works(
        QuantumVolumeParams(width=5, depth=5, circuit_i=0),
        resolvers=[custom_resolver] + cirq.DEFAULT_RESOLVERS,
    )
806
807
def test_numpy_values():
    # A 0-d numpy array nested in a dict serializes as a plain JSON scalar.
    serialized = cirq.to_json({'value': np.array(1)})
    assert serialized == '{\n  "value": 1\n}'
815