# Copyright 2019 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import (
    List,
    Union,
    Optional,
    Iterator,
    Iterable,
    Dict,
    FrozenSet,
)

import numpy as np
import sympy
from cirq_google.api import v2

# Function types (`ArgFunction.type` values) that may be used in arguments
# written in each `arg_function_language`.
SUPPORTED_FUNCTIONS_FOR_LANGUAGE: Dict[Optional[str], FrozenSet[str]] = {
    '': frozenset(),
    'linear': frozenset({'add', 'mul'}),
    'exp': frozenset({'add', 'mul', 'pow'}),
    # None means any. Is used when inferring the language during serialization.
    None: frozenset({'add', 'mul', 'pow'}),
}
MOST_PERMISSIVE_LANGUAGE = 'exp'

SUPPORTED_SYMPY_OPS = (sympy.Symbol, sympy.Add, sympy.Mul, sympy.Pow)

# Argument types for gates. `sympy.Pow` is included because 'pow' functions are
# both serialized and deserialized by this module.
ARG_LIKE = Union[int, float, List[bool], str, sympy.Symbol, sympy.Add, sympy.Mul, sympy.Pow]
FLOAT_ARG_LIKE = Union[float, sympy.Symbol, sympy.Add, sympy.Mul, sympy.Pow]

# Types for comparing floats
# Includes sympy types. Needed for arg parsing.
# numpy integer/floating scalars (e.g. np.int64, np.float32) are included so that
# they are written as plain float values instead of being rejected as
# unrecognized argument types. (np.bool_ is not an np.integer, so boolean
# handling is unaffected.)
FLOAT_TYPES = (
    float,
    int,
    np.integer,
    np.floating,
    sympy.Integer,
    sympy.Float,
    sympy.Rational,
    sympy.NumberSymbol,
)

# Supported function languages in order from least to most flexible.
# Clients should use the least flexible language they can, to make it easier
# to gradually roll out new capabilities to clients and servers.
LANGUAGE_ORDER = [
    '',
    'linear',
    'exp',
]

# Minimal `arg_function_language` needed to express each function type that can
# appear in `ArgFunction.type`.
_LANGUAGE_FOR_FUNCTION_TYPE: Dict[str, str] = {
    'add': 'linear',
    'mul': 'linear',
    'pow': 'exp',
}

# For each supported `ArgFunction.type`: the sympy class it corresponds to, and
# the argument description used in the error raised when one of that
# function's arguments is missing during deserialization.
_SYMPY_FOR_FUNCTION_TYPE = {
    'add': (sympy.Add, 'An addition argument'),
    'mul': (sympy.Mul, 'A multiplication argument'),
    'pow': (sympy.Pow, 'A power argument'),
}


def _max_lang(langs: Iterable[str]) -> str:
    """Returns the most flexible language in `langs`, ordered by `LANGUAGE_ORDER`.

    Returns '' (the least flexible language) when `langs` is empty. Raises
    ValueError if an element of `langs` is not in `LANGUAGE_ORDER`.
    """
    i = max((LANGUAGE_ORDER.index(e) for e in langs), default=0)
    return LANGUAGE_ORDER[i]


def _infer_function_language_from_circuit(value: v2.program_pb2.Circuit) -> str:
    """Returns the least flexible language able to express every operation arg.

    Scans the args of every operation in every moment of `value`. Returns ''
    when no function-valued args are present.
    """
    return _max_lang(
        {
            e
            for moment in value.moments
            for op in moment.operations
            for e in _function_languages_from_operation(op)
        }
    )


def _function_languages_from_operation(value: v2.program_pb2.Operation) -> Iterator[str]:
    """Yields the language required by each function used in the args of `value`."""
    for arg in value.args.values():
        yield from _function_languages_from_arg(arg)


def _function_languages_from_arg(arg_proto: v2.program_pb2.Arg) -> Iterator[str]:
    """Yields the languages required by the functions (if any) in `arg_proto`.

    For an add/mul/pow function, yields that function's language and then
    recurses into its arguments. Non-function args (values, symbols) and
    unrecognized function types yield nothing.
    """
    if arg_proto.WhichOneof('arg') != 'func':
        return
    lang = _LANGUAGE_FOR_FUNCTION_TYPE.get(arg_proto.func.type)
    if lang is None:
        return
    yield lang
    for a in arg_proto.func.args:
        yield from _function_languages_from_arg(a)


def float_arg_to_proto(
    value: ARG_LIKE,
    *,
    arg_function_language: Optional[str] = None,
    out: Optional[v2.program_pb2.FloatArg] = None,
) -> v2.program_pb2.FloatArg:
    """Writes an argument value into a FloatArg proto.

    Note that the FloatArg proto is a slimmed down form of the
    Arg proto, so this proto should only be used when the argument
    is known to be a float or expression that resolves to a float.

    Args:
        value: The value to encode. This must be a float or compatible
            sympy expression. Strings and repeated booleans are not allowed.
        arg_function_language: The language to use when encoding functions. If
            this is set to None, it will be set to the minimal language
            necessary to support the features that were actually used.
        out: The proto to write the result into. Defaults to a new instance.

    Returns:
        The proto that was written into.

    Raises:
        ValueError: If `arg_function_language` is unrecognized, if `value` uses
            a function that language does not support, or if `value` is
            neither a number nor a supported sympy expression.
    """
    msg = v2.program_pb2.FloatArg() if out is None else out

    if isinstance(value, FLOAT_TYPES):
        msg.float_value = float(value)
    else:
        _arg_func_to_proto(value, arg_function_language, msg)

    return msg


def arg_to_proto(
    value: ARG_LIKE,
    *,
    arg_function_language: Optional[str] = None,
    out: Optional[v2.program_pb2.Arg] = None,
) -> v2.program_pb2.Arg:
    """Writes an argument value into an Arg proto.

    Args:
        value: The value to encode: a number, a string, a list/tuple/ndarray
            whose elements are all booleans, or a supported sympy expression.
        arg_function_language: The language to use when encoding functions. If
            this is set to None, it will be set to the minimal language
            necessary to support the features that were actually used.
        out: The proto to write the result into. Defaults to a new instance.

    Returns:
        The proto that was written into.

    Raises:
        ValueError: If `arg_function_language` is unrecognized, if `value` uses
            a function that language does not support, or if `value` has an
            unrecognized type.
    """
    msg = v2.program_pb2.Arg() if out is None else out

    if isinstance(value, FLOAT_TYPES):
        msg.arg_value.float_value = float(value)
    elif isinstance(value, str):
        msg.arg_value.string_value = value
    elif isinstance(value, (list, tuple, np.ndarray)) and all(
        isinstance(x, (bool, np.bool_)) for x in value
    ):
        # Some protobuf / numpy combinations do not support np.bool_, so cast.
        msg.arg_value.bool_values.values.extend([bool(x) for x in value])
    else:
        _arg_func_to_proto(value, arg_function_language, msg)

    return msg


def _arg_func_to_proto(
    value: ARG_LIKE,
    arg_function_language: Optional[str],
    msg: Union[v2.program_pb2.Arg, v2.program_pb2.FloatArg],
) -> None:
    """Writes a sympy symbol or add/mul/pow expression into `msg`.

    Function arguments are written recursively with `arg_to_proto`, using the
    same `arg_function_language`.

    Raises:
        ValueError: If `arg_function_language` is unrecognized, if `value` is a
            function the language does not support, or if `value` is not a
            symbol or supported sympy expression.
    """
    if arg_function_language not in SUPPORTED_FUNCTIONS_FOR_LANGUAGE:
        raise ValueError(f'Unrecognized arg_function_language: {arg_function_language!r}')
    supported = SUPPORTED_FUNCTIONS_FOR_LANGUAGE[arg_function_language]

    if isinstance(value, sympy.Symbol):
        msg.symbol = str(value)
        return

    func_type = next(
        (
            name
            for name, (sympy_type, _) in _SYMPY_FOR_FUNCTION_TYPE.items()
            if isinstance(value, sympy_type)
        ),
        None,
    )
    if func_type is None:
        raise ValueError(f'Unrecognized arg type: {type(value)}')

    if func_type not in supported:
        lang = repr(arg_function_language) if arg_function_language is not None else '[any]'
        raise ValueError(
            f'Function type {func_type!r} not supported by arg_function_language {lang}'
        )

    msg.func.type = func_type
    for arg in value.args:
        arg_to_proto(arg, arg_function_language=arg_function_language, out=msg.func.args.add())


def _float_or_int(value: float) -> Union[int, float]:
    """Returns `value` as an int if it is a whole number, otherwise as a float.

    `float.is_integer()` is False for inf and NaN, so those are returned as
    floats. The previous `round(x) == x` / `math.ceil(x) == math.floor(x)`
    checks raised OverflowError/ValueError for them.
    """
    result = float(value)
    return int(result) if result.is_integer() else result


def float_arg_from_proto(
    arg_proto: v2.program_pb2.FloatArg,
    *,
    arg_function_language: str,
    required_arg_name: Optional[str] = None,
) -> Optional[FLOAT_ARG_LIKE]:
    """Extracts a python value from an argument value proto.

    This function handles `FloatArg` protos, that are required
    to be floats or symbolic expressions.

    Args:
        arg_proto: The proto containing a serialized value.
        arg_function_language: The `arg_function_language` field from
            `Program.Language`.
        required_arg_name: If set to `None`, the method will return `None` when
            given an unset proto value. If set to a string, the method will
            instead raise an error complaining that the value is missing in that
            situation.

    Returns:
        The deserialized value, or else None if there was no set value and
        `required_arg_name` was set to `None`. Whole-number floats are
        returned as ints.

    Raises:
        ValueError: If the value is unset (or a function that could not be
            processed) while `required_arg_name` is set, if the proto holds an
            unrecognized argument type, or if a function uses an unrecognized
            language or a function type unsupported by that language.
    """
    which = arg_proto.WhichOneof('arg')
    if which == 'float_value':
        return _float_or_int(arg_proto.float_value)
    if which == 'symbol':
        return sympy.Symbol(arg_proto.symbol)
    if which == 'func':
        func = _arg_func_from_proto(
            arg_proto.func,
            arg_function_language=arg_function_language,
            required_arg_name=required_arg_name,
        )
        if func is None and required_arg_name is not None:
            raise ValueError(
                f'Arg {arg_proto.func} could not be processed for {required_arg_name}.'
            )
        return func
    if which is None:
        if required_arg_name is not None:
            raise ValueError(f'Arg {required_arg_name} is missing.')
        return None
    raise ValueError(f'unrecognized argument type ({which}).')


def arg_from_proto(
    arg_proto: v2.program_pb2.Arg,
    *,
    arg_function_language: str,
    required_arg_name: Optional[str] = None,
) -> Optional[ARG_LIKE]:
    """Extracts a python value from an argument value proto.

    Args:
        arg_proto: The proto containing a serialized value.
        arg_function_language: The `arg_function_language` field from
            `Program.Language`.
        required_arg_name: If set to `None`, the method will return `None` when
            given an unset proto value. If set to a string, the method will
            instead raise an error complaining that the value is missing in that
            situation.

    Returns:
        The deserialized value, or else None if there was no set value and
        `required_arg_name` was set to `None`. Whole-number floats/doubles are
        returned as ints.

    Raises:
        ValueError: If `arg_value` holds an unrecognized value type, if the arg
            is unset or unrecognized while `required_arg_name` is set, or if a
            function uses an unrecognized language or a function type
            unsupported by that language.
    """
    which = arg_proto.WhichOneof('arg')
    if which == 'arg_value':
        arg_value = arg_proto.arg_value
        which_val = arg_value.WhichOneof('arg_value')
        if which_val == 'float_value':
            return _float_or_int(arg_value.float_value)
        if which_val == 'double_value':
            return _float_or_int(arg_value.double_value)
        if which_val == 'bool_values':
            return list(arg_value.bool_values.values)
        if which_val == 'string_value':
            return str(arg_value.string_value)
        raise ValueError(f'Unrecognized value type: {which_val!r}')

    if which == 'symbol':
        return sympy.Symbol(arg_proto.symbol)

    if which == 'func':
        func = _arg_func_from_proto(
            arg_proto.func,
            arg_function_language=arg_function_language,
            required_arg_name=required_arg_name,
        )
        if func is not None:
            return func

    if required_arg_name is not None:
        raise ValueError(
            f'{required_arg_name} is missing or has an unrecognized '
            f'argument type (WhichOneof("arg")={which!r}).'
        )

    return None


def _arg_func_from_proto(
    func: v2.program_pb2.ArgFunction,
    *,
    arg_function_language: str,
    required_arg_name: Optional[str] = None,
) -> Optional[ARG_LIKE]:
    """Rebuilds the sympy add/mul/pow expression serialized in `func`.

    Each function argument is deserialized with `arg_from_proto` and is
    required to be present. `required_arg_name` is accepted for interface
    compatibility but is not used here.

    Raises:
        ValueError: If `arg_function_language` is unrecognized, or if
            `func.type` is not supported by that language.
    """
    supported = SUPPORTED_FUNCTIONS_FOR_LANGUAGE.get(arg_function_language)
    if supported is None:
        raise ValueError(f'Unrecognized arg_function_language: {arg_function_language!r}')

    if func.type not in supported:
        raise ValueError(
            f'Unrecognized function type {func.type!r} '
            f'for arg_function_language={arg_function_language!r}'
        )

    entry = _SYMPY_FOR_FUNCTION_TYPE.get(func.type)
    if entry is None:
        return None
    sympy_type, arg_description = entry
    return sympy_type(
        *[
            arg_from_proto(
                a,
                arg_function_language=arg_function_language,
                required_arg_name=arg_description,
            )
            for a in func.args
        ]
    )