# Copyright 2018 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numbers
from typing import AbstractSet, Any, cast, TYPE_CHECKING, TypeVar

import sympy
from typing_extensions import Protocol

from cirq import study
from cirq._doc import doc_private

if TYPE_CHECKING:
    import cirq


T = TypeVar('T')


class SupportsParameterization(Protocol):
    """An object that can be parameterized by Symbols and resolved
    via a ParamResolver."""

    @doc_private
    def _is_parameterized_(self: Any) -> bool:
        """Whether the object is parameterized by any Symbols that require
        resolution. Returns True if the object has any unresolved Symbols
        and False otherwise."""

    @doc_private
    def _parameter_names_(self: Any) -> AbstractSet[str]:
        """Returns a collection of string names of parameters that require
        resolution. If _is_parameterized_ is False, the collection is empty.
        The converse is not necessarily true, because some objects may report
        that they are parameterized when they contain symbolic constants which
        need to be evaluated, but no free symbols.
        """

    @doc_private
    def _resolve_parameters_(self: T, resolver: 'cirq.ParamResolver', recursive: bool) -> T:
        """Resolve the parameters in the effect."""


class ResolvableValue(Protocol):
    """A custom value type that supplies its own resolved value."""

    @doc_private
    def _resolved_value_(self) -> Any:
        """Returns a resolved value during parameter resolution.

        Use this to mark a custom type as "resolved", instead of requiring
        further parsing like we do with Sympy symbols.
        """


def is_parameterized(val: Any) -> bool:
    """Returns whether the object is parameterized with any Symbols.

    A value is parameterized when any of the following holds:

    * it is an instance of `sympy.Basic`;
    * it is a list or tuple with at least one parameterized element;
    * it has an `_is_parameterized_` method that returns a truthy value.

    If the `_is_parameterized_` method is missing or returns NotImplemented,
    the value counts as parameterized exactly when `parameter_names(val)` is
    non-empty. Python numbers are never parameterized.

    Args:
        val: The object to inspect.

    Returns:
        True if the object has any unresolved Symbols and False otherwise.
    """
    if isinstance(val, sympy.Basic):
        return True
    if isinstance(val, numbers.Number):
        return False
    if isinstance(val, (list, tuple)):
        return any(is_parameterized(e) for e in val)

    getter = getattr(val, '_is_parameterized_', None)
    result = NotImplemented if getter is None else getter()
    if result is not NotImplemented:
        # The magic method may return any truthy/falsy value; normalize it so
        # callers always receive the `bool` promised by the signature.
        return bool(result)

    # No usable magic method: fall back to checking for parameter names.
    return bool(parameter_names(val))


def parameter_names(val: Any) -> AbstractSet[str]:
    """Returns parameter names for this object.

    Args:
        val: Object for which to find the parameter names.

    Returns:
        A set of parameter names if the object is parameterized. The set is
        built as follows:

        * for a `sympy.Basic`, it holds the names of its free symbols;
        * for a list or tuple, it is the union over its elements;
        * for any other object, it is whatever `_parameter_names_` returns.

        If the object does not implement the `_parameter_names_` magic method,
        or that method returns NotImplemented, returns an empty set.
    """
    if isinstance(val, sympy.Basic):
        return {symbol.name for symbol in val.free_symbols}
    if isinstance(val, numbers.Number):
        return set()
    if isinstance(val, (list, tuple)):
        return {name for e in val for name in parameter_names(e)}

    getter = getattr(val, '_parameter_names_', None)
    result = NotImplemented if getter is None else getter()
    if result is not NotImplemented:
        return result

    return set()


def parameter_symbols(val: Any) -> AbstractSet[sympy.Symbol]:
    """Returns parameter symbols for this object.

    The symbols are derived from `parameter_names(val)`: each name becomes a
    new `sympy.Symbol` with that name.

    Args:
        val: Object for which to find the parameter symbols.

    Returns:
        A set of parameter symbols if the object is parameterized. If
        `parameter_names(val)` is empty (for example, because the object does
        not implement `_parameter_names_`), returns an empty set.
    """
    return {sympy.Symbol(name) for name in parameter_names(val)}


def resolve_parameters(
    val: T, param_resolver: 'cirq.ParamResolverOrSimilarType', recursive: bool = True
) -> T:
    """Resolves symbol parameters in the effect using the param resolver.

    This function will use the `_resolve_parameters_` magic method
    of `val` to resolve any Symbols with concrete values from the given
    parameter resolver.

    Args:
        val: The object to resolve (e.g. the gate, operation, etc).
        param_resolver: The object to use for resolving all symbols. If it is
            falsy (e.g. None or an empty mapping), `val` is returned unchanged.
        recursive: If True, resolves parameters recursively over the
            resolver; otherwise performs a single resolution step.

    Returns:
        A gate or operation of the same type, but with all Symbols
        replaced with floats or terminal symbols according to the
        given ParamResolver. If `val` has no `_resolve_parameters_`
        method or if it returns NotImplemented, `val` itself is returned.
        Note that in some cases, such as when directly resolving a sympy
        Symbol, the return type could differ from the input type; however,
        for the much more common case of resolving parameters on cirq
        objects (or if resolving a Union[Symbol, float] instead of just a
        Symbol), the return type will be the same as val so we reflect
        that in the type signature of this protocol function.

    Raises:
        RecursionError: If the ParamResolver detects a loop in resolution.
        TypeError: If an external `_resolve_parameters_` method does not
            accept the `recursive` argument (the method is always called as
            `_resolve_parameters_(resolver, recursive)`).
    """
    if not param_resolver:
        return val

    # Ensure it is a dictionary wrapped in a ParamResolver.
    param_resolver = study.ParamResolver(param_resolver)

    # Handle special cases for sympy expressions and sequences.
    # These may not in fact preserve types, but we pretend they do by casting.
    if isinstance(val, sympy.Basic):
        return cast(T, param_resolver.value_of(val, recursive))
    if isinstance(val, (list, tuple)):
        return cast(T, type(val)(resolve_parameters(e, param_resolver, recursive) for e in val))

    getter = getattr(val, '_resolve_parameters_', None)
    result = NotImplemented if getter is None else getter(param_resolver, recursive)
    if result is not NotImplemented:
        return result

    return val


def resolve_parameters_once(val: Any, param_resolver: 'cirq.ParamResolverOrSimilarType') -> Any:
    """Performs a single parameter resolution step using the param resolver.

    Equivalent to `resolve_parameters(val, param_resolver, recursive=False)`.

    Args:
        val: The object to resolve.
        param_resolver: The object to use for resolving symbols.

    Returns:
        The result of one non-recursive resolution step on `val`.
    """
    return resolve_parameters(val, param_resolver, False)