# Copyright 2019 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Dict, List

import cirq
from cirq_google.api.v2 import run_context_pb2


def sweep_to_proto(
    sweep: cirq.Sweep,
    *,
    out: Optional[run_context_pb2.Sweep] = None,
) -> run_context_pb2.Sweep:
    """Converts a Sweep to v2 protobuf message.

    Args:
        sweep: The sweep to convert.
        out: Optional message to be populated. If not given, a new message will
            be created.

    Returns:
        Populated sweep protobuf message.

    Raises:
        ValueError: If `sweep` (or any sweep nested inside a Product or Zip) is
            of a type that cannot be represented in the v2 proto, or if it is a
            `cirq.ListSweep` whose resolvers do not all assign the same set of
            parameter keys.
    """
    if out is None:
        out = run_context_pb2.Sweep()
    if sweep is cirq.UnitSweep:
        # The unit sweep is encoded by leaving the `sweep` oneof unset;
        # sweep_from_proto maps an unset oneof back to cirq.UnitSweep.
        pass
    elif isinstance(sweep, cirq.Product):
        out.sweep_function.function_type = run_context_pb2.SweepFunction.PRODUCT
        for factor in sweep.factors:
            sweep_to_proto(factor, out=out.sweep_function.sweeps.add())
    elif isinstance(sweep, cirq.Zip):
        out.sweep_function.function_type = run_context_pb2.SweepFunction.ZIP
        for s in sweep.sweeps:
            sweep_to_proto(s, out=out.sweep_function.sweeps.add())
    elif isinstance(sweep, cirq.Linspace):
        out.single_sweep.parameter_key = sweep.key
        out.single_sweep.linspace.first_point = sweep.start
        out.single_sweep.linspace.last_point = sweep.stop
        out.single_sweep.linspace.num_points = sweep.length
    elif isinstance(sweep, cirq.Points):
        out.single_sweep.parameter_key = sweep.key
        out.single_sweep.points.points.extend(sweep.points)
    elif isinstance(sweep, cirq.ListSweep):
        # A ListSweep is encoded as a ZIP of one Points sweep per parameter
        # key. Zip pairs the factors' points by position, so every resolver
        # must assign exactly the same keys: otherwise the per-key point lists
        # end up with different lengths and the decoded Zip would combine
        # values that came from different resolvers.
        sweep_dict: Dict[str, List[cirq.TParamVal]] = {}
        for index, param_resolver in enumerate(sweep):
            keys = list(param_resolver)
            if index > 0 and set(keys) != set(sweep_dict):
                raise ValueError(
                    'cannot convert to v2 Sweep proto: all resolvers in a ListSweep '
                    f'must assign the same keys, but resolver {index} has keys {keys} '
                    f'while resolver 0 has keys {list(sweep_dict)}'
                )
            for key in keys:
                sweep_dict.setdefault(key, []).append(param_resolver.value_of(key))
        out.sweep_function.function_type = run_context_pb2.SweepFunction.ZIP
        for key, values in sweep_dict.items():
            sweep_to_proto(cirq.Points(key, values), out=out.sweep_function.sweeps.add())
    else:
        raise ValueError(f'cannot convert to v2 Sweep proto: {sweep}')
    return out


def sweep_from_proto(msg: run_context_pb2.Sweep) -> cirq.Sweep:
    """Creates a Sweep from a v2 protobuf message.

    Args:
        msg: The v2 sweep message to decode.

    Returns:
        The decoded sweep. A message whose `sweep` oneof is unset decodes to
        `cirq.UnitSweep`; a `sweep_function` decodes to `cirq.Product` or
        `cirq.Zip` of its recursively decoded sub-sweeps; a `single_sweep`
        decodes to `cirq.Linspace` or `cirq.Points`.

    Raises:
        ValueError: If a `sweep_function` has a function type other than
            PRODUCT or ZIP, or if a `single_sweep` has neither `linspace` nor
            `points` set.
    """
    which = msg.WhichOneof('sweep')
    if which is None:
        return cirq.UnitSweep
    if which == 'sweep_function':
        factors = [sweep_from_proto(m) for m in msg.sweep_function.sweeps]
        func_type = msg.sweep_function.function_type
        if func_type == run_context_pb2.SweepFunction.PRODUCT:
            return cirq.Product(*factors)
        if func_type == run_context_pb2.SweepFunction.ZIP:
            return cirq.Zip(*factors)

        raise ValueError(f'invalid sweep function type: {func_type}')
    if which == 'single_sweep':
        single = msg.single_sweep
        key = single.parameter_key
        single_kind = single.WhichOneof('sweep')
        if single_kind == 'linspace':
            return cirq.Linspace(
                key=key,
                start=single.linspace.first_point,
                stop=single.linspace.last_point,
                length=single.linspace.num_points,
            )
        if single_kind == 'points':
            # Copy into a plain list so the returned sweep holds ordinary Python
            # values rather than the protobuf repeated container of `msg`.
            return cirq.Points(key=key, points=list(single.points.points))

        raise ValueError(f'single sweep type not set: {msg}')

    # coverage: ignore
    raise ValueError(f'sweep type not set: {msg}')