1# Copyright 2018 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import numbers
16from typing import AbstractSet, Any, cast, TYPE_CHECKING, TypeVar
17
18import sympy
19from typing_extensions import Protocol
20
21from cirq import study
22from cirq._doc import doc_private
23
24if TYPE_CHECKING:
25    import cirq
26
27
# Generic type variable used in the signatures below (e.g. `_resolve_parameters_`
# and `resolve_parameters`) so type checkers see resolution as returning a
# value of the same type it was given.
T = TypeVar('T')
29
30
class SupportsParameterization(Protocol):
    """An object that can be parameterized by Symbols and resolved
    via a ParamResolver.

    The magic methods below are consulted by this module's
    `is_parameterized`, `parameter_names` and `resolve_parameters`
    functions respectively. Each may return NotImplemented to defer to the
    default behavior of the corresponding function.
    """

    @doc_private
    def _is_parameterized_(self: Any) -> bool:
        """Whether the object is parameterized by any Symbols that require
        resolution. Returns True if the object has any unresolved Symbols
        and False otherwise.

        If this returns NotImplemented, `is_parameterized` falls back to
        checking whether `parameter_names` reports any names.
        """

    @doc_private
    def _parameter_names_(self: Any) -> AbstractSet[str]:
        """Returns a collection of string names of parameters that require
        resolution. If _is_parameterized_ is False, the collection is empty.
        The converse is not necessarily true, because some objects may report
        that they are parameterized when they contain symbolic constants which
        need to be evaluated, but no free symbols.

        If this returns NotImplemented, `parameter_names` reports an empty set.
        """

    @doc_private
    def _resolve_parameters_(self: T, resolver: 'cirq.ParamResolver', recursive: bool) -> T:
        """Resolve the parameters in the effect.

        Args:
            resolver: The resolver supplying values for the object's symbols.
            recursive: If True, resolve parameters recursively over the
                resolver; otherwise perform a single resolution step. Note
                that `resolve_parameters` always passes this argument
                positionally, so implementations must accept it.

        Returns:
            An object of the same type with its parameters resolved, or
            NotImplemented, in which case `resolve_parameters` returns the
            original object unchanged.
        """
53
54
class ResolvableValue(Protocol):
    """A custom value type that supplies its own resolved value.

    NOTE(review): nothing in this module calls `_resolved_value_`; it is
    presumably consulted by the ParamResolver during resolution. Confirm
    against the resolver implementation.
    """

    @doc_private
    def _resolved_value_(self) -> Any:
        """Returns a resolved value during parameter resolution.

        Use this to mark a custom type as "resolved", instead of requiring
        further parsing like we do with Sympy symbols.
        """
63
64
def is_parameterized(val: Any) -> bool:
    """Returns whether the object is parameterized with any Symbols.

    A value is parameterized when it has an `_is_parameterized_` method and
    that method returns a truthy value, or if the value is an instance of
    sympy.Basic. A list or tuple is parameterized if any of its elements is.
    Numbers are never parameterized.

    Args:
        val: The object to check.

    Returns:
        True if the object has any unresolved Symbols and False otherwise.
        If `val` has no `_is_parameterized_` method, or that method returns
        NotImplemented, the result falls back to whether
        `parameter_names(val)` is non-empty.
    """
    # Every sympy expression counts as parameterized, even one with no free
    # symbols (e.g. a symbolic constant that still needs to be evaluated).
    if isinstance(val, sympy.Basic):
        return True
    if isinstance(val, numbers.Number):
        return False
    if isinstance(val, (list, tuple)):
        return any(is_parameterized(e) for e in val)

    getter = getattr(val, '_is_parameterized_', None)
    result = NotImplemented if getter is None else getter()

    if result is not NotImplemented:
        # Coerce so callers always receive a real bool, as the signature
        # promises, even if the magic method returns another truthy value.
        return bool(result)

    return bool(parameter_names(val))
92
93
def parameter_names(val: Any) -> AbstractSet[str]:
    """Returns parameter names for this object.

    Args:
        val: Object for which to find the parameter names.

    Returns:
        A set of parameter names if the object is parameterized. For a sympy
        expression these are the names of its free symbols; for a list or
        tuple, the union of the names of its elements; for a number, an
        empty set. Otherwise the object's `_parameter_names_` magic method
        is used. If the object does not implement that method, or the method
        returns NotImplemented, returns an empty set.
    """
    if isinstance(val, sympy.Basic):
        return {symbol.name for symbol in val.free_symbols}
    if isinstance(val, numbers.Number):
        return set()
    if isinstance(val, (list, tuple)):
        return {name for e in val for name in parameter_names(e)}

    getter = getattr(val, '_parameter_names_', None)
    result = NotImplemented if getter is None else getter()
    if result is not NotImplemented:
        return result

    # No usable magic method: treat the object as having no parameters.
    return set()
119
120
def parameter_symbols(val: Any) -> AbstractSet[sympy.Symbol]:
    """Returns parameter symbols for this object.

    This is derived from `parameter_names`: each reported name is wrapped in
    a new `sympy.Symbol` created with default assumptions, so assumptions
    attached to the original symbols (if any) are not carried over.

    Args:
        val: Object for which to find the parameter symbols.

    Returns:
        A set of parameter symbols if the object is parameterized. If the
        object reports no parameter names (for example because it does not
        implement the `_parameter_names_` magic method, or that method returns
        NotImplemented), returns an empty set.
    """
    return {sympy.Symbol(name) for name in parameter_names(val)}
133
134
def resolve_parameters(
    val: T, param_resolver: 'cirq.ParamResolverOrSimilarType', recursive: bool = True
) -> T:
    """Resolves symbol parameters in the effect using the param resolver.

    This function will use the `_resolve_parameters_` magic method
    of `val` to resolve any Symbols with concrete values from the given
    parameter resolver. Sympy expressions are resolved directly by the
    resolver, and lists/tuples are resolved element by element.

    Args:
        val: The object to resolve (e.g. the gate, operation, etc)
        param_resolver: the object to use for resolving all symbols
        recursive: if True, resolves parameters recursively over the
            resolver; otherwise performs a single resolution step.

    Returns:
        a gate or operation of the same type, but with all Symbols
        replaced with floats or terminal symbols according to the
        given ParamResolver. If `val` has no `_resolve_parameters_`
        method or if it returns NotImplemented, `val` itself is returned.
        Note that in some cases, such as when directly resolving a sympy
        Symbol, the return type could differ from the input type; however,
        for the much more common case of resolving parameters on cirq
        objects (or if resolving a Union[Symbol, float] instead of just a
        Symbol), the return type will be the same as val so we reflect
        that in the type signature of this protocol function.

    Raises:
        RecursionError: If the ParamResolver detects a loop in resolution.
        TypeError: If `val` defines a `_resolve_parameters_` method that does
            not accept the `recursive` argument, which is always passed
            positionally.
    """
    # A falsy resolver (e.g. None or an empty mapping) has nothing to
    # substitute, so skip all work and hand back the original object.
    if not param_resolver:
        return val

    # Ensure it is a dictionary wrapped in a ParamResolver.
    param_resolver = study.ParamResolver(param_resolver)

    # Handle special cases for sympy expressions and sequences.
    # These may not in fact preserve types, but we pretend they do by casting.
    if isinstance(val, sympy.Basic):
        return cast(T, param_resolver.value_of(val, recursive))
    if isinstance(val, (list, tuple)):
        return cast(T, type(val)(resolve_parameters(e, param_resolver, recursive) for e in val))

    getter = getattr(val, '_resolve_parameters_', None)
    result = NotImplemented if getter is None else getter(param_resolver, recursive)

    # Objects that don't support resolution are returned unchanged.
    return val if result is NotImplemented else result
190
191
def resolve_parameters_once(val: Any, param_resolver: 'cirq.ParamResolverOrSimilarType'):
    """Resolves `val` against `param_resolver` with a single resolution step.

    This is shorthand for calling `resolve_parameters` with recursion turned
    off. Values produced by the resolver are substituted as-is rather than
    being resolved again.

    Args:
        val: The object whose parameters should be resolved.
        param_resolver: The resolver (or resolver-like value) to apply.

    Returns:
        Whatever `resolve_parameters` returns for a non-recursive resolution.
    """
    return resolve_parameters(val, param_resolver, recursive=False)
195