# Copyright 2019 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Dict, List

import cirq
from cirq_google.api.v2 import run_context_pb2


def sweep_to_proto(
    sweep: cirq.Sweep,
    *,
    out: Optional[run_context_pb2.Sweep] = None,
) -> run_context_pb2.Sweep:
    """Converts a Sweep to v2 protobuf message.

    Args:
        sweep: The sweep to convert.
        out: Optional message to be populated. If not given, a new message will
            be created.

    Returns:
        Populated sweep protobuf message.

    Raises:
        ValueError: If the sweep type is not supported, or if `sweep` is a
            `cirq.ListSweep` whose resolvers do not all assign the same set of
            parameters (such a sweep cannot be encoded as a zip of per-parameter
            points).
    """
    if out is None:
        out = run_context_pb2.Sweep()
    if sweep is cirq.UnitSweep:
        # The unit sweep is represented by a message with no sweep oneof set.
        pass
    elif isinstance(sweep, cirq.Product):
        out.sweep_function.function_type = run_context_pb2.SweepFunction.PRODUCT
        for factor in sweep.factors:
            sweep_to_proto(factor, out=out.sweep_function.sweeps.add())
    elif isinstance(sweep, cirq.Zip):
        out.sweep_function.function_type = run_context_pb2.SweepFunction.ZIP
        for s in sweep.sweeps:
            sweep_to_proto(s, out=out.sweep_function.sweeps.add())
    elif isinstance(sweep, cirq.Linspace):
        out.single_sweep.parameter_key = sweep.key
        out.single_sweep.linspace.first_point = sweep.start
        out.single_sweep.linspace.last_point = sweep.stop
        out.single_sweep.linspace.num_points = sweep.length
    elif isinstance(sweep, cirq.Points):
        out.single_sweep.parameter_key = sweep.key
        out.single_sweep.points.points.extend(sweep.points)
    elif isinstance(sweep, cirq.ListSweep):
        # A ListSweep is encoded as a ZIP of one Points sweep per parameter, so
        # gather each parameter's values in resolver order.
        sweep_dict: Dict[str, List[cirq.TParamVal]] = {}
        num_resolvers = 0
        for param_resolver in sweep:
            num_resolvers += 1
            for key in param_resolver:
                sweep_dict.setdefault(key, []).append(param_resolver.value_of(key))
        # If a parameter is missing from some resolvers its value list is
        # shorter than the others; zipping such lists would pair values from
        # different resolvers (and cirq.Zip stops at its shortest factor), so
        # the serialized sweep would silently differ from the input.
        incomplete = [key for key, values in sweep_dict.items() if len(values) != num_resolvers]
        if incomplete:
            raise ValueError(
                'cannot convert to v2 Sweep proto: ListSweep resolvers do not all '
                f'assign parameters {incomplete}'
            )
        out.sweep_function.function_type = run_context_pb2.SweepFunction.ZIP
        for key, values in sweep_dict.items():
            sweep_to_proto(cirq.Points(key, values), out=out.sweep_function.sweeps.add())
    else:
        raise ValueError(f'cannot convert to v2 Sweep proto: {sweep}')
    return out


def sweep_from_proto(msg: run_context_pb2.Sweep) -> cirq.Sweep:
    """Creates a Sweep from a v2 protobuf message.

    Args:
        msg: The sweep message to deserialize.

    Returns:
        `cirq.UnitSweep` if no sweep is set; a `cirq.Product` or `cirq.Zip` of
        the recursively converted sub-sweeps for a sweep function; or a
        `cirq.Linspace` / `cirq.Points` for a single sweep.

    Raises:
        ValueError: If the sweep function type is not PRODUCT or ZIP, or if a
            single sweep has neither linspace nor points set.
    """
    which = msg.WhichOneof('sweep')
    if which is None:
        return cirq.UnitSweep
    if which == 'sweep_function':
        factors = [sweep_from_proto(m) for m in msg.sweep_function.sweeps]
        func_type = msg.sweep_function.function_type
        if func_type == run_context_pb2.SweepFunction.PRODUCT:
            return cirq.Product(*factors)
        if func_type == run_context_pb2.SweepFunction.ZIP:
            return cirq.Zip(*factors)

        raise ValueError(f'invalid sweep function type: {func_type}')
    if which == 'single_sweep':
        key = msg.single_sweep.parameter_key
        single_kind = msg.single_sweep.WhichOneof('sweep')
        if single_kind == 'linspace':
            return cirq.Linspace(
                key=key,
                start=msg.single_sweep.linspace.first_point,
                stop=msg.single_sweep.linspace.last_point,
                length=msg.single_sweep.linspace.num_points,
            )
        if single_kind == 'points':
            # Copy into a plain list so the returned sweep does not alias the
            # message's mutable repeated field.
            return cirq.Points(key=key, points=list(msg.single_sweep.points.points))

        raise ValueError(f'single sweep type not set: {msg}')

    # coverage: ignore
    raise ValueError(f'sweep type not set: {msg}')