1# Copyright 2019 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14import math
15from typing import (
16    List,
17    Union,
18    Optional,
19    Iterator,
20    Iterable,
21    Dict,
22    FrozenSet,
23)
24
25import numpy as np
26import sympy
27from cirq_google.api import v2
28
SUPPORTED_FUNCTIONS_FOR_LANGUAGE: Dict[Optional[str], FrozenSet[str]] = {
    '': frozenset(),
    'linear': frozenset({'add', 'mul'}),
    'exp': frozenset({'add', 'mul', 'pow'}),
    # None means "any language": it is the entry used when serializing with
    # arg_function_language=None, and it permits every supported function.
    None: frozenset({'add', 'mul', 'pow'}),
}
MOST_PERMISSIVE_LANGUAGE = 'exp'

SUPPORTED_SYMPY_OPS = (sympy.Symbol, sympy.Add, sympy.Mul, sympy.Pow)

# Argument types for gates.
# sympy.Pow is listed because 'pow' functions are both written and read by
# this module's serializers.
ARG_LIKE = Union[int, float, List[bool], str, sympy.Symbol, sympy.Add, sympy.Mul, sympy.Pow]
FLOAT_ARG_LIKE = Union[float, sympy.Symbol, sympy.Add, sympy.Mul, sympy.Pow]

# Types for comparing floats
# Includes sympy types.  Needed for arg parsing.
# numpy scalar types are included because e.g. np.float32 and np.int64 are not
# subclasses of float / int, yet should still be serialized as plain numbers.
FLOAT_TYPES = (
    float,
    int,
    np.integer,
    np.floating,
    sympy.Integer,
    sympy.Float,
    sympy.Rational,
    sympy.NumberSymbol,
)

# Supported function languages in order from least to most flexible.
# Clients should use the least flexible language they can, to make it easier
# to gradually roll out new capabilities to clients and servers.
LANGUAGE_ORDER = [
    '',
    'linear',
    'exp',
]
56
57
def _max_lang(langs: Iterable[str]) -> str:
    """Returns the most flexible of ``langs`` according to LANGUAGE_ORDER.

    An empty iterable yields the least flexible language, ``LANGUAGE_ORDER[0]``.
    A language that is not listed in LANGUAGE_ORDER raises ValueError (from
    ``list.index``).
    """
    highest = 0
    for lang in langs:
        highest = max(highest, LANGUAGE_ORDER.index(lang))
    return LANGUAGE_ORDER[highest]
61
62
def _infer_function_language_from_circuit(value: v2.program_pb2.Circuit) -> str:
    """Returns the least flexible language able to express every arg of a circuit.

    Gathers the languages required by the args of every operation in every
    moment, then picks the most flexible of them via ``_max_lang``. A circuit
    without any function-valued args yields ``LANGUAGE_ORDER[0]``.
    """
    required_languages = set()
    for moment in value.moments:
        for operation in moment.operations:
            required_languages.update(_function_languages_from_operation(operation))
    return _max_lang(required_languages)
72
73
def _function_languages_from_operation(value: v2.program_pb2.Operation) -> Iterator[str]:
    """Yields the function languages needed by each arg of an operation.

    Values are produced lazily, one arg at a time, and may repeat.
    """
    for arg_proto in value.args.values():
        for language in _function_languages_from_arg(arg_proto):
            yield language
77
78
def _function_languages_from_arg(arg_proto: v2.program_pb2.Arg) -> Iterator[str]:
    """Yields the languages required to express a single Arg proto.

    Only function-valued args contribute: 'add'/'mul' yield 'linear' and
    'pow' yields 'exp', followed by the languages of the function's own args
    (recursively). Plain values, symbols and unknown function types yield
    nothing.
    """
    if arg_proto.WhichOneof('arg') != 'func':
        return
    func_type = arg_proto.func.type
    if func_type in ('add', 'mul'):
        language = 'linear'
    elif func_type == 'pow':
        language = 'exp'
    else:
        return
    yield language
    for sub_arg in arg_proto.func.args:
        yield from _function_languages_from_arg(sub_arg)
91
92
def float_arg_to_proto(
    value: ARG_LIKE,
    *,
    arg_function_language: Optional[str] = None,
    out: Optional[v2.program_pb2.FloatArg] = None,
) -> v2.program_pb2.FloatArg:
    """Writes an argument value into an FloatArg proto.

    Note that the FloatArg proto is a slimmed down form of the
    Arg proto, so this proto should only be used when the argument
    is known to be a float or expression that resolves to a float.

    Args:
        value: The value to encode.  This must be a float or compatible
            sympy expression. Strings and repeated booleans are not allowed.
        arg_function_language: The language whose function types may be used
            when encoding symbolic expressions. If None, every function type
            listed under the `None` key of SUPPORTED_FUNCTIONS_FOR_LANGUAGE is
            allowed. No language is chosen or returned here; callers that need
            the minimal language must infer it separately.
        out: The proto to write the result into. Defaults to a new instance.

    Returns:
        The proto that was written into.

    Raises:
        ValueError: For non-numeric values, if `arg_function_language` is
            unrecognized, if `value` uses a function type that language does
            not support, or if `value` is not a supported sympy symbol or
            expression.
    """
    msg = v2.program_pb2.FloatArg() if out is None else out

    if isinstance(value, FLOAT_TYPES):
        msg.float_value = float(value)
    else:
        # Symbols and Add/Mul/Pow expressions; also validates the language.
        _arg_func_to_proto(value, arg_function_language, msg)

    return msg
124
125
def arg_to_proto(
    value: ARG_LIKE,
    *,
    arg_function_language: Optional[str] = None,
    out: Optional[v2.program_pb2.Arg] = None,
) -> v2.program_pb2.Arg:
    """Writes an argument value into an Arg proto.

    Args:
        value: The value to encode: a number, a string, a sequence (list,
            tuple or numpy array) of booleans, or a supported sympy symbol or
            expression.
        arg_function_language: The language whose function types may be used
            when encoding symbolic expressions. If None, every function type
            listed under the `None` key of SUPPORTED_FUNCTIONS_FOR_LANGUAGE is
            allowed. No language is chosen or returned here; callers that need
            the minimal language must infer it separately.
        out: The proto to write the result into. Defaults to a new instance.

    Returns:
        The proto that was written into.

    Raises:
        ValueError: For values that are not numbers, strings or boolean
            sequences, if `arg_function_language` is unrecognized, if `value`
            uses a function type that language does not support, or if `value`
            is not a supported sympy symbol or expression.
    """
    msg = v2.program_pb2.Arg() if out is None else out

    if isinstance(value, FLOAT_TYPES):
        msg.arg_value.float_value = float(value)
    elif isinstance(value, str):
        msg.arg_value.string_value = value
    elif isinstance(value, (list, tuple, np.ndarray)) and all(
        isinstance(x, (bool, np.bool_)) for x in value
    ):
        bool_values = msg.arg_value.bool_values
        # Explicitly mark the list as present: extending with zero elements is
        # not guaranteed to set the oneof, which would make an empty sequence
        # serialize as an unset Arg instead of an empty bool list.
        bool_values.SetInParent()
        # Some protobuf / numpy combinations do not support np.bool_, so cast.
        bool_values.values.extend([bool(x) for x in value])
    else:
        _arg_func_to_proto(value, arg_function_language, msg)

    return msg
160
161
def _arg_func_to_proto(
    value: ARG_LIKE,
    arg_function_language: Optional[str],
    msg: Union[v2.program_pb2.Arg, v2.program_pb2.FloatArg],
) -> None:
    """Writes a sympy symbol or Add/Mul/Pow expression into `msg` in place.

    Function operands are serialized recursively with `arg_to_proto` into
    `msg.func.args`, in the order given by `value.args`.

    Args:
        value: A sympy.Symbol, sympy.Add, sympy.Mul or sympy.Pow.
        arg_function_language: Key into SUPPORTED_FUNCTIONS_FOR_LANGUAGE that
            restricts which function types may be written; None allows all
            of 'add', 'mul' and 'pow'.
        msg: The Arg or FloatArg proto to write into.

    Raises:
        ValueError: If `arg_function_language` is unrecognized, if the
            function type is not supported by it, or if `value` is not one of
            the types listed above.
    """
    if arg_function_language not in SUPPORTED_FUNCTIONS_FOR_LANGUAGE:
        raise ValueError(f'Unrecognized arg_function_language: {arg_function_language!r}')
    supported = SUPPORTED_FUNCTIONS_FOR_LANGUAGE[arg_function_language]

    if isinstance(value, sympy.Symbol):
        # A Symbol's only free symbol is itself, so this is its printed name.
        msg.symbol = str(value)
        return

    # sympy expression classes this serializer understands, paired with the
    # proto function type each one is written as.
    func_types = ((sympy.Add, 'add'), (sympy.Mul, 'mul'), (sympy.Pow, 'pow'))
    func_type = next((name for cls, name in func_types if isinstance(value, cls)), None)
    if func_type is None:
        raise ValueError(f'Unrecognized arg type: {type(value)}')
    if func_type not in supported:
        lang = repr(arg_function_language) if arg_function_language is not None else '[any]'
        raise ValueError(
            f'Function type {func_type!r} not supported by arg_function_language {lang}'
        )

    msg.func.type = func_type
    for arg in value.args:
        arg_to_proto(arg, arg_function_language=arg_function_language, out=msg.func.args.add())
195
196
197# TODO(#3388) Add documentation for Raises.
198# pylint: disable=missing-raises-doc
def float_arg_from_proto(
    arg_proto: v2.program_pb2.FloatArg,
    *,
    arg_function_language: str,
    required_arg_name: Optional[str] = None,
) -> Optional[FLOAT_ARG_LIKE]:
    """Extracts a python value from an argument value proto.

    This function handles `FloatArg` protos, that are required
    to be floats or symbolic expressions.

    Args:
        arg_proto: The proto containing a serialized value.
        arg_function_language: The `arg_function_language` field from
            `Program.Language`.
        required_arg_name: If set to `None`, the method will return `None` when
            given an unset proto value. If set to a string, the method will
            instead raise an error complaining that the value is missing in that
            situation.

    Returns:
        The deserialized value, or else None if there was no set value and
        `required_arg_name` was set to `None`. Finite integral float values are
        returned as `int`; non-finite values (inf, nan) stay `float`.

    Raises:
        ValueError: If the value is unset, or a function could not be
            processed, while `required_arg_name` is set; if the oneof case is
            unrecognized; or if a function uses an unrecognized language or
            function type.
    """
    which = arg_proto.WhichOneof('arg')
    if which == 'float_value':
        result = float(arg_proto.float_value)
        # round() raises OverflowError on +/-inf and ValueError on nan, so only
        # finite values are tested for integrality.
        if math.isfinite(result) and round(result) == result:
            result = int(result)
        return result
    elif which == 'symbol':
        return sympy.Symbol(arg_proto.symbol)
    elif which == 'func':
        func = _arg_func_from_proto(
            arg_proto.func,
            arg_function_language=arg_function_language,
            required_arg_name=required_arg_name,
        )
        if func is None and required_arg_name is not None:
            raise ValueError(
                f'Arg {arg_proto.func} could not be processed for {required_arg_name}.'
            )
        return func
    elif which is None:
        if required_arg_name is not None:
            raise ValueError(f'Arg {required_arg_name} is missing.')
        return None
    else:
        raise ValueError(f'unrecognized argument type ({which}).')
248
249
250# TODO(#3388) Add documentation for Raises.
def arg_from_proto(
    arg_proto: v2.program_pb2.Arg,
    *,
    arg_function_language: str,
    required_arg_name: Optional[str] = None,
) -> Optional[ARG_LIKE]:
    """Extracts a python value from an argument value proto.

    Args:
        arg_proto: The proto containing a serialized value.
        arg_function_language: The `arg_function_language` field from
            `Program.Language`.
        required_arg_name: If set to `None`, the method will return `None` when
            given an unset proto value. If set to a string, the method will
            instead raise an error complaining that the value is missing in that
            situation.

    Returns:
        The deserialized value, or else None if there was no set value and
        `required_arg_name` was set to `None`. Finite integral float/double
        values are returned as `int`; boolean lists are returned as `list`.

    Raises:
        ValueError: If `arg_value` holds an unrecognized value type; if the
            value is missing or of an unrecognized type while
            `required_arg_name` is set; or if a function uses an unrecognized
            language or function type.
    """

    which = arg_proto.WhichOneof('arg')
    if which == 'arg_value':
        arg_value = arg_proto.arg_value
        which_val = arg_value.WhichOneof('arg_value')
        if which_val in ('float_value', 'double_value'):
            if which_val == 'double_value':
                result = float(arg_value.double_value)
            else:
                result = float(arg_value.float_value)
            # math.ceil/math.floor raise OverflowError on +/-inf and ValueError
            # on nan, so only finite values are tested for integrality.
            if math.isfinite(result) and math.ceil(result) == math.floor(result):
                result = int(result)
            return result
        if which_val == 'bool_values':
            return list(arg_value.bool_values.values)
        if which_val == 'string_value':
            return str(arg_value.string_value)
        raise ValueError(f'Unrecognized value type: {which_val!r}')

    if which == 'symbol':
        return sympy.Symbol(arg_proto.symbol)

    if which == 'func':
        func = _arg_func_from_proto(
            arg_proto.func,
            arg_function_language=arg_function_language,
            required_arg_name=required_arg_name,
        )
        if func is not None:
            return func

    # Reached for an unset oneof, an unknown oneof case, or a function that
    # could not be rebuilt.
    if required_arg_name is not None:
        raise ValueError(
            f'{required_arg_name} is missing or has an unrecognized '
            f'argument type (WhichOneof("arg")={which!r}).'
        )

    return None
310
311
312# pylint: enable=missing-raises-doc
def _arg_func_from_proto(
    func: v2.program_pb2.ArgFunction,
    *,
    arg_function_language: str,
    required_arg_name: Optional[str] = None,
) -> Optional[ARG_LIKE]:
    """Rebuilds a sympy Add, Mul or Pow expression from an ArgFunction proto.

    Args:
        func: The serialized function. Each entry of `func.args` is decoded
            with `arg_from_proto`, and the results are passed, in order, to the
            matching sympy constructor.
        arg_function_language: Key into SUPPORTED_FUNCTIONS_FOR_LANGUAGE that
            limits which function types are accepted.
        required_arg_name: Not consulted by this function; every operand is
            always decoded as required, under its own per-function label.

    Returns:
        The sympy expression, or None if `func.type` is allowed by the
        language but is not one of 'add', 'mul' or 'pow'.

    Raises:
        ValueError: If the language or function type is unrecognized, or if
            an operand is missing.
    """
    allowed = SUPPORTED_FUNCTIONS_FOR_LANGUAGE.get(arg_function_language)
    if allowed is None:
        raise ValueError(f'Unrecognized arg_function_language: {arg_function_language!r}')
    if func.type not in allowed:
        raise ValueError(
            f'Unrecognized function type {func.type!r} '
            f'for arg_function_language={arg_function_language!r}'
        )

    # Function type -> (sympy constructor, label reported for a missing operand).
    builders = {
        'add': (sympy.Add, 'An addition argument'),
        'mul': (sympy.Mul, 'A multiplication argument'),
        'pow': (sympy.Pow, 'A power argument'),
    }
    if func.type not in builders:
        return None
    build, operand_label = builders[func.type]
    operands = [
        arg_from_proto(
            operand,
            arg_function_language=arg_function_language,
            required_arg_name=operand_label,
        )
        for operand in func.args
    ]
    return build(*operands)
365