1# Copyright 2019 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15from typing import Optional, Dict, List
16
17import cirq
18from cirq_google.api.v2 import run_context_pb2
19
20
def sweep_to_proto(
    sweep: cirq.Sweep,
    *,
    out: Optional[run_context_pb2.Sweep] = None,
) -> run_context_pb2.Sweep:
    """Converts a Sweep to v2 protobuf message.

    `cirq.UnitSweep` leaves the message untouched. `cirq.Product` and
    `cirq.Zip` become a sweep function whose children are converted
    recursively. `cirq.Linspace` and `cirq.Points` become a single sweep.
    A `cirq.ListSweep` is regrouped into one `cirq.Points` per parameter key
    and written as a ZIP of those points.

    Args:
        sweep: The sweep to convert.
        out: Optional message to be populated. If not given, a new message will
            be created.

    Returns:
        Populated sweep protobuf message.

    Raises:
        ValueError: If `sweep` is not one of the supported sweep types, or if
            it is a `cirq.ListSweep` whose resolvers do not all define the
            same parameters.
    """
    if out is None:
        out = run_context_pb2.Sweep()
    if sweep is cirq.UnitSweep:
        # The unit sweep is represented by a message with no sweep set.
        pass
    elif isinstance(sweep, cirq.Product):
        out.sweep_function.function_type = run_context_pb2.SweepFunction.PRODUCT
        for factor in sweep.factors:
            sweep_to_proto(factor, out=out.sweep_function.sweeps.add())
    elif isinstance(sweep, cirq.Zip):
        out.sweep_function.function_type = run_context_pb2.SweepFunction.ZIP
        for s in sweep.sweeps:
            sweep_to_proto(s, out=out.sweep_function.sweeps.add())
    elif isinstance(sweep, cirq.Linspace):
        out.single_sweep.parameter_key = sweep.key
        out.single_sweep.linspace.first_point = sweep.start
        out.single_sweep.linspace.last_point = sweep.stop
        out.single_sweep.linspace.num_points = sweep.length
    elif isinstance(sweep, cirq.Points):
        out.single_sweep.parameter_key = sweep.key
        out.single_sweep.points.points.extend(sweep.points)
    elif isinstance(sweep, cirq.ListSweep):
        # Regroup the resolver-by-resolver values into one list per key.
        sweep_dict: Dict[str, List[cirq.TParamVal]] = {}
        num_resolvers = 0
        for param_resolver in sweep:
            num_resolvers += 1
            for key in param_resolver:
                sweep_dict.setdefault(key, []).append(param_resolver.value_of(key))
        # A key missing from some resolver would give that key a shorter list
        # than the others, so the zipped lists would no longer line up with
        # the original resolvers. Reject that before touching `out`.
        if any(len(values) != num_resolvers for values in sweep_dict.values()):
            raise ValueError(
                'cannot convert to v2 Sweep proto: ListSweep resolvers do not '
                f'all define the same parameters: {sweep}'
            )
        out.sweep_function.function_type = run_context_pb2.SweepFunction.ZIP
        for key, values in sweep_dict.items():
            sweep_to_proto(cirq.Points(key, values), out=out.sweep_function.sweeps.add())
    else:
        raise ValueError(f'cannot convert to v2 Sweep proto: {sweep}')
    return out
69
70
def sweep_from_proto(msg: run_context_pb2.Sweep) -> cirq.Sweep:
    """Builds the cirq sweep described by a v2 Sweep protobuf message.

    Args:
        msg: The sweep message to decode.

    Returns:
        `cirq.UnitSweep` when the message's `sweep` oneof is unset. Otherwise
        a `cirq.Product` or `cirq.Zip` of the recursively decoded child
        sweeps, or a `cirq.Linspace` / `cirq.Points` for a single sweep.

    Raises:
        ValueError: If the sweep function type is neither PRODUCT nor ZIP, or
            if a single sweep has neither `linspace` nor `points` set.
    """
    sweep_kind = msg.WhichOneof('sweep')
    if sweep_kind is None:
        return cirq.UnitSweep

    if sweep_kind == 'sweep_function':
        function = msg.sweep_function
        # Children are decoded first, so an invalid child is reported before
        # an invalid function type on the parent.
        children = [sweep_from_proto(child) for child in function.sweeps]
        func_type = function.function_type
        if func_type == run_context_pb2.SweepFunction.PRODUCT:
            return cirq.Product(*children)
        if func_type == run_context_pb2.SweepFunction.ZIP:
            return cirq.Zip(*children)

        raise ValueError(f'invalid sweep function type: {func_type}')

    if sweep_kind == 'single_sweep':
        single = msg.single_sweep
        key = single.parameter_key
        single_kind = single.WhichOneof('sweep')
        if single_kind == 'linspace':
            linspace = single.linspace
            return cirq.Linspace(
                key=key,
                start=linspace.first_point,
                stop=linspace.last_point,
                length=linspace.num_points,
            )
        if single_kind == 'points':
            return cirq.Points(key=key, points=single.points.points)

        raise ValueError(f'single sweep type not set: {msg}')

    # coverage: ignore
    raise ValueError(f'sweep type not set: {msg}')
101