1# Copyright 2019 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15from typing import Optional, Dict, List
16
17import cirq
18from cirq_google.api.v2 import run_context_pb2
19
20
21# TODO(#3388) Add documentation for Raises.
22# pylint: disable=missing-raises-doc
def sweep_to_proto(
    sweep: cirq.Sweep,
    *,
    out: Optional[run_context_pb2.Sweep] = None,
) -> run_context_pb2.Sweep:
    """Converts a Sweep to v2 protobuf message.

    Args:
        sweep: The sweep to convert.
        out: Optional message to be populated. If not given, a new message will
            be created.

    Returns:
        Populated sweep protobuf message.

    Raises:
        ValueError: If `sweep` is of a type that has no v2 proto
            representation, or if it is a `cirq.ListSweep` whose resolvers do
            not all assign the same set of parameter keys (such a sweep cannot
            be encoded as a zip of per-key point lists without misaligning
            values).
    """
    if out is None:
        out = run_context_pb2.Sweep()
    if sweep is cirq.UnitSweep:
        # The unit sweep is encoded by leaving the `sweep` oneof unset.
        pass
    elif isinstance(sweep, cirq.Product):
        out.sweep_function.function_type = run_context_pb2.SweepFunction.PRODUCT
        for factor in sweep.factors:
            sweep_to_proto(factor, out=out.sweep_function.sweeps.add())
    elif isinstance(sweep, cirq.Zip):
        out.sweep_function.function_type = run_context_pb2.SweepFunction.ZIP
        for s in sweep.sweeps:
            sweep_to_proto(s, out=out.sweep_function.sweeps.add())
    elif isinstance(sweep, cirq.Linspace):
        out.single_sweep.parameter_key = sweep.key
        out.single_sweep.linspace.first_point = sweep.start
        out.single_sweep.linspace.last_point = sweep.stop
        out.single_sweep.linspace.num_points = sweep.length
    elif isinstance(sweep, cirq.Points):
        out.single_sweep.parameter_key = sweep.key
        out.single_sweep.points.points.extend(sweep.points)
    elif isinstance(sweep, cirq.ListSweep):
        # A list of resolvers is encoded as a ZIP of one Points sweep per key,
        # where the i-th point of every key comes from the i-th resolver.
        sweep_dict: Dict[str, List[cirq.TParamVal]] = {}
        num_resolvers = 0
        for param_resolver in sweep:
            num_resolvers += 1
            for key in param_resolver:
                sweep_dict.setdefault(key, []).append(param_resolver.value_of(key))
        # If a key is missing from some resolvers, its point list is shorter
        # than the others and zipping would pair values from different
        # resolvers (and drop trailing rows), so refuse to encode it.
        for key, values in sweep_dict.items():
            if len(values) != num_resolvers:
                raise ValueError(
                    f'cannot convert to v2 Sweep proto: parameter {key!r} is not '
                    f'set by every resolver in {sweep}'
                )
        out.sweep_function.function_type = run_context_pb2.SweepFunction.ZIP
        for key, values in sweep_dict.items():
            sweep_to_proto(cirq.Points(key, values), out=out.sweep_function.sweeps.add())
    else:
        raise ValueError(f'cannot convert to v2 Sweep proto: {sweep}')
    return out
71
72
73# pylint: enable=missing-raises-doc
def sweep_from_proto(msg: run_context_pb2.Sweep) -> cirq.Sweep:
    """Creates a Sweep from a v2 protobuf message.

    The mapping is:
      * no `sweep` oneof set -> `cirq.UnitSweep`;
      * `sweep_function` -> `cirq.Product` or `cirq.Zip` of the recursively
        converted sub-sweeps, depending on `function_type`;
      * `single_sweep` -> `cirq.Linspace` or `cirq.Points`, depending on which
        of `linspace` / `points` is set.

    Args:
        msg: The sweep protobuf message to convert.

    Returns:
        The equivalent `cirq.Sweep`.

    Raises:
        ValueError: If the sweep function type is neither PRODUCT nor ZIP, if a
            single sweep has neither `linspace` nor `points` set, or if the
            `sweep` oneof holds an unrecognized field.
    """
    which = msg.WhichOneof('sweep')
    if which is None:
        return cirq.UnitSweep
    if which == 'sweep_function':
        factors = [sweep_from_proto(m) for m in msg.sweep_function.sweeps]
        func_type = msg.sweep_function.function_type
        if func_type == run_context_pb2.SweepFunction.PRODUCT:
            return cirq.Product(*factors)
        if func_type == run_context_pb2.SweepFunction.ZIP:
            return cirq.Zip(*factors)

        raise ValueError(f'invalid sweep function type: {func_type}')
    if which == 'single_sweep':
        single = msg.single_sweep
        key = single.parameter_key
        single_kind = single.WhichOneof('sweep')
        if single_kind == 'linspace':
            return cirq.Linspace(
                key=key,
                start=single.linspace.first_point,
                stop=single.linspace.last_point,
                length=single.linspace.num_points,
            )
        if single_kind == 'points':
            # Copy into a plain list so the returned sweep does not alias `msg`
            # and holds an ordinary Python list rather than a protobuf
            # repeated-field container.
            return cirq.Points(key=key, points=list(single.points.points))

        raise ValueError(f'single sweep type not set: {msg}')

    # coverage: ignore
    raise ValueError(f'sweep type not set: {msg}')
104