1# Copyright 2018 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14from typing import cast
15
16import cirq
17from cirq.study import sweeps
18from cirq_google.api.v1 import params_pb2
19
20
def sweep_to_proto(sweep: cirq.Sweep, repetitions: int = 1) -> params_pb2.ParameterSweep:
    """Serializes a sweep into a v1 ``ParameterSweep`` protobuf message.

    Args:
        sweep: The sweep to serialize. ``cirq.UnitSweep`` is encoded by leaving
            the message's ``sweep`` field unset. Any other sweep must be
            expressible as a product of zips of ``cirq.Linspace`` /
            ``cirq.Points`` sweeps.
        repetitions: Value stored in the message's ``repetitions`` field.

    Returns:
        The populated ``ParameterSweep`` message.

    Raises:
        ValueError: If the sweep cannot be normalized into product-of-zips form,
            or if it contains a single-parameter sweep that is neither a
            ``cirq.Linspace`` nor a ``cirq.Points``.
    """
    if sweep == cirq.UnitSweep:
        # The unit sweep has nothing to encode, so only repetitions are stored.
        return params_pb2.ParameterSweep(repetitions=repetitions)

    zip_product = _to_zip_product(sweep)
    zip_messages = [_sweep_zip_to_proto(cast(cirq.Zip, z)) for z in zip_product.factors]
    product_message = params_pb2.ProductSweep(factors=zip_messages)
    return params_pb2.ParameterSweep(repetitions=repetitions, sweep=product_message)
31
32
def _to_zip_product(sweep: cirq.Sweep) -> cirq.Product:
    """Normalizes ``sweep`` into a product of zips of single-parameter sweeps.

    A sweep that is not already a ``cirq.Product`` is wrapped in a one-factor
    product, and every factor that is not already a ``cirq.Zip`` is wrapped in
    a one-term zip. If every factor was already a zip, the product is returned
    without being rebuilt.

    Raises:
        ValueError: If any zip term is not a ``SingleSweep``.
    """
    product = sweep if isinstance(sweep, cirq.Product) else cirq.Product(sweep)

    if any(not isinstance(f, cirq.Zip) for f in product.factors):
        product = cirq.Product(
            *(f if isinstance(f, cirq.Zip) else cirq.Zip(f) for f in product.factors)
        )

    # Every leaf of the product-of-zips tree must be a single-parameter sweep.
    leaves = (leaf for f in product.factors for leaf in cast(cirq.Zip, f).sweeps)
    if not all(isinstance(leaf, sweeps.SingleSweep) for leaf in leaves):
        raise ValueError(f'cannot convert to zip-product form: {product}')
    return product
45
46
def _sweep_zip_to_proto(sweep: cirq.Zip) -> params_pb2.ZipSweep:
    """Serializes a zip of single-parameter sweeps into a ``ZipSweep`` message.

    Each term is expected to be a ``SingleSweep``. ``_to_zip_product`` is the
    helper that enforces this before serialization.
    """
    terms = [cast(sweeps.SingleSweep, term) for term in sweep.sweeps]
    return params_pb2.ZipSweep(sweeps=[_single_param_sweep_to_proto(term) for term in terms])
50
51
def _single_param_sweep_to_proto(sweep: sweeps.SingleSweep) -> params_pb2.SingleSweep:
    """Serializes one single-parameter sweep into a ``SingleSweep`` message.

    ``cirq.Linspace`` maps to the message's ``linspace`` field: start, stop and
    length become first_point, last_point and num_points. ``cirq.Points`` maps
    to the ``points`` field. The sweep's key becomes ``parameter_key``.

    Raises:
        ValueError: For any other kind of single-parameter sweep.
    """
    if isinstance(sweep, cirq.Linspace):
        span = params_pb2.Linspace(
            first_point=sweep.start, last_point=sweep.stop, num_points=sweep.length
        )
        return params_pb2.SingleSweep(parameter_key=sweep.key, linspace=span)

    if isinstance(sweep, cirq.Points):
        values = params_pb2.Points(points=sweep.points)
        return params_pb2.SingleSweep(parameter_key=sweep.key, points=values)

    raise ValueError(f'invalid single-parameter sweep: {sweep}')
66
67
def sweep_from_proto(param_sweep: params_pb2.ParameterSweep) -> cirq.Sweep:
    """Deserializes a v1 ``ParameterSweep`` message into a cirq sweep.

    Returns ``cirq.UnitSweep`` when the message has no ``sweep`` field or its
    product has no factors. Otherwise it returns a ``cirq.Product`` with one
    deserialized factor per zip message. The message's ``repetitions`` field is
    not read here.
    """
    if not param_sweep.HasField('sweep'):
        return cirq.UnitSweep

    zip_messages = param_sweep.sweep.factors
    if not zip_messages:
        return cirq.UnitSweep

    return cirq.Product(*(_sweep_from_param_sweep_zip_proto(z) for z in zip_messages))
74
75
def _sweep_from_param_sweep_zip_proto(param_sweep_zip: params_pb2.ZipSweep) -> cirq.Sweep:
    """Deserializes a ``ZipSweep`` message into a ``cirq.Zip``.

    A zip message with no single-sweep entries becomes ``cirq.UnitSweep``.
    """
    members = [_sweep_from_single_param_sweep_proto(entry) for entry in param_sweep_zip.sweeps]
    return cirq.Zip(*members) if members else cirq.UnitSweep
82
83
def _sweep_from_single_param_sweep_proto(
    single_param_sweep: params_pb2.SingleSweep,
) -> cirq.Sweep:
    """Deserializes a ``SingleSweep`` message into ``cirq.Points`` or ``cirq.Linspace``.

    The ``points`` field is checked before the ``linspace`` field. The message's
    ``parameter_key`` is used as the sweep key.

    Raises:
        ValueError: If neither the ``points`` nor the ``linspace`` field is set.
    """
    param_key = single_param_sweep.parameter_key

    if single_param_sweep.HasField('points'):
        return cirq.Points(param_key, list(single_param_sweep.points.points))

    if single_param_sweep.HasField('linspace'):
        span = single_param_sweep.linspace
        return cirq.Linspace(param_key, span.first_point, span.last_point, span.num_points)

    raise ValueError('Single param sweep type undefined')
96