# Copyright 2018 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import (
    Any,
    cast,
    Dict,
    Iterable,
    Iterator,
    List,
    overload,
    Sequence,
    TYPE_CHECKING,
    Tuple,
    Union,
)

import abc
import collections
import itertools
import sympy

from cirq._doc import document
from cirq.study import resolver

if TYPE_CHECKING:
    import cirq

# A single assignment: an iterable of (key, value) pairs.
Params = Iterable[Tuple['cirq.TParamKey', 'cirq.TParamVal']]
# Dict form accepted by `dict_to_product_sweep` / `dict_to_zip_sweep`: each key
# maps to either one value or a sequence of values.
ProductOrZipSweepLike = Dict['cirq.TParamKey', Union['cirq.TParamVal', Sequence['cirq.TParamVal']]]


def _check_duplicate_keys(sweeps):
    """Raises ValueError if any key is assigned by more than one of `sweeps`.

    Only keys shared *between* sweeps are detected: each sweep's keys are
    compared against the keys of the sweeps before it, then added to the set.
    """
    keys = set()
    for sweep in sweeps:
        if any(key in keys for key in sweep.keys):
            raise ValueError('duplicate keys')
        keys.update(sweep.keys)


class Sweep(metaclass=abc.ABCMeta):
    """A sweep is an iterator over ParamResolvers.

    A ParamResolver assigns values to Symbols. For sweeps, each ParamResolver
    must specify the same Symbols that are assigned. So a sweep is a way to
    iterate over a set of different values for a fixed set of Symbols. This is
    useful for a circuit, where there are a fixed set of Symbols, and you want
    to iterate over an assignment of all values to all symbols.

    For example, a sweep can explicitly assign a set of equally spaced points
    between two endpoints using a Linspace,
        sweep = Linspace("angle", start=0.0, stop=2.0, length=10)
    This can then be used with a circuit that has an 'angle' sympy.Symbol to
    run multiple simulations, one for each of the values in the sweep
        result = simulator.run_sweep(program=circuit, params=sweep)

    Sweeps support Cartesian and Zip products using the '*' and '+' operators,
    see the Product and Zip documentation.
    """

    def __mul__(self, other: 'Sweep') -> 'Sweep':
        """Returns the Cartesian product of this sweep and `other`.

        Existing Product operands are flattened, so `(a * b) * c` produces a
        single three-factor Product rather than a nested one.

        Raises:
            TypeError: if `other` is not a Sweep.
            ValueError: if the operands assign a common key.
        """
        factors: List[Sweep] = []
        if isinstance(self, Product):
            factors.extend(self.factors)
        else:
            factors.append(self)
        if isinstance(other, Product):
            factors.extend(other.factors)
        elif isinstance(other, Sweep):
            factors.append(other)
        else:
            raise TypeError(f'cannot multiply sweep and {type(other)}')
        return Product(*factors)

    def __add__(self, other: 'Sweep') -> 'Sweep':
        """Returns the Zip of this sweep and `other`.

        Existing Zip operands are flattened, so `(a + b) + c` produces a
        single three-sweep Zip rather than a nested one.

        Raises:
            TypeError: if `other` is not a Sweep.
            ValueError: if the operands assign a common key.
        """
        sweeps: List[Sweep] = []
        if isinstance(self, Zip):
            sweeps.extend(self.sweeps)
        else:
            sweeps.append(self)
        if isinstance(other, Zip):
            sweeps.extend(other.sweeps)
        elif isinstance(other, Sweep):
            sweeps.append(other)
        else:
            raise TypeError(f'cannot add sweep and {type(other)}')
        return Zip(*sweeps)

    @abc.abstractmethod
    def __eq__(self, other):
        pass

    def __ne__(self, other):
        return not self == other

    @property
    @abc.abstractmethod
    def keys(self) -> List['cirq.TParamKey']:
        """The keys for all of the sympy.Symbols that are resolved."""

    @abc.abstractmethod
    def __len__(self) -> int:
        pass

    def __iter__(self) -> Iterator[resolver.ParamResolver]:
        """Yields one ParamResolver per assignment from `param_tuples()`."""
        for params in self.param_tuples():
            yield resolver.ParamResolver(collections.OrderedDict(params))

    # pylint: disable=function-redefined
    @overload
    def __getitem__(self, val: int) -> resolver.ParamResolver:
        pass

    @overload
    def __getitem__(self, val: slice) -> 'Sweep':
        pass

    def __getitem__(self, val):
        """Returns one resolver (int index) or a ListSweep (slice) of this sweep.

        Args:
            val: An int index (negative values count from the end) or a slice.

        Raises:
            IndexError: if an int index is outside `[-len(self), len(self))`.
            TypeError: if `val` is neither an int nor a slice.
        """
        n = len(self)
        if isinstance(val, int):
            if val < -n or val >= n:
                raise IndexError(f'sweep index out of range: {val}')
            if val < 0:
                val += n
            return next(itertools.islice(self, val, val + 1))
        if not isinstance(val, slice):
            raise TypeError(f'Sweep indices must be either int or slices, not {type(val)}')

        selected = range(n)[val]
        # Map each selected sweep index to its position in the result, so that
        # slices with a negative step come out in the requested order.
        inds_map: Dict[int, int] = {
            sweep_i: slice_i for slice_i, sweep_i in enumerate(selected)
        }
        results = [resolver.ParamResolver()] * len(inds_map)
        # A sweep can only be iterated front to back, but nothing past the
        # largest selected index can contribute, so stop iterating there.
        stop = max(selected[0], selected[-1]) + 1 if selected else 0
        for i, item in enumerate(itertools.islice(self, stop)):
            if i in inds_map:
                results[inds_map[i]] = item

        return ListSweep(results)

    # pylint: enable=function-redefined

    @abc.abstractmethod
    def param_tuples(self) -> Iterator[Params]:
        """An iterator over (key, value) pairs assigning Symbol key to value."""

    def __str__(self) -> str:
        length = len(self)
        max_show = 10
        # Show a maximum of max_show entries with an ellipsis in the middle
        if length > max_show:
            beginning_len = max_show - max_show // 2
        else:
            beginning_len = max_show
        end_len = max_show - beginning_len
        lines = ['Sweep:']
        lines.extend(str(dict(r.param_dict)) for r in itertools.islice(self, beginning_len))
        if end_len > 0:
            lines.append('...')
            lines.extend(
                str(dict(r.param_dict)) for r in itertools.islice(self, length - end_len, length)
            )
        return '\n'.join(lines)


class _Unit(Sweep):
    """A sweep with a single element that assigns no parameter values.

    This is useful as a base sweep, instead of special casing None.
    """

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return NotImplemented
        return True

    def __hash__(self) -> int:
        # Every _Unit compares equal to every other, so hash on the class only.
        # Defining __eq__ alone would otherwise leave this class unhashable.
        return hash(self.__class__)

    @property
    def keys(self) -> List['cirq.TParamKey']:
        return []

    def __len__(self) -> int:
        return 1

    def param_tuples(self) -> Iterator[Params]:
        yield ()

    def __repr__(self) -> str:
        return 'cirq.UnitSweep'


UnitSweep = _Unit()
document(UnitSweep, """The singleton sweep with no parameters.""")


class Product(Sweep):
    """Cartesian product of one or more sweeps.

    If one sweep assigns 'a' to the values 0, 1, 2, and the second sweep
    assigns 'b' to the values 2, 3, then the product is a sweep that
    assigns the tuple ('a','b') to all possible combinations of these
    assignments: (0, 2), (0, 3), (1, 2), (1, 3), (2, 2), (2, 3).

    As in `itertools.product`, the last factor varies fastest. A Product with
    no factors contains exactly one (empty) assignment.
    """

    def __init__(self, *factors: Sweep) -> None:
        """Raises ValueError if two factors assign the same key."""
        _check_duplicate_keys(factors)
        self.factors = factors

    def __eq__(self, other):
        if not isinstance(other, Product):
            return NotImplemented
        return self.factors == other.factors

    def __hash__(self):
        return hash(tuple(self.factors))

    @property
    def keys(self) -> List['cirq.TParamKey']:
        return sum((factor.keys for factor in self.factors), [])

    def __len__(self) -> int:
        # The empty product has length 1, which agrees with `param_tuples`:
        # with no factors it yields a single empty assignment.
        length = 1
        for factor in self.factors:
            length *= len(factor)
        return length

    def param_tuples(self) -> Iterator[Params]:
        def _gen(factors):
            # Recursively pair each assignment of the first factor with every
            # assignment of the remaining factors.
            if not factors:
                yield ()
            else:
                first, rest = factors[0], factors[1:]
                for first_values in first.param_tuples():
                    for rest_values in _gen(rest):
                        yield first_values + rest_values

        return _gen(self.factors)

    def __repr__(self) -> str:
        factors_repr = ', '.join(repr(f) for f in self.factors)
        return f'cirq.Product({factors_repr})'

    def __str__(self) -> str:
        if not self.factors:
            return 'Product()'
        factor_strs = []
        for factor in self.factors:
            # '+' (Zip) binds looser than '*' (Product), so parenthesize Zips.
            if isinstance(factor, Zip):
                factor_strs.append('(' + str(factor) + ')')
            else:
                factor_strs.append(repr(factor))
        return ' * '.join(factor_strs)


class Zip(Sweep):
    """Zip product (direct sum) of one or more sweeps.

    If one sweep assigns 'a' to values 0, 1, 2, and the second sweep assigns 'b'
    to the values 3, 4, 5, then the zip is a sweep that assigns to the
    tuple ('a', 'b') the pair-wise matched values (0, 3), (1, 4), (2, 5).

    When iterating over a Zip, we iterate the individual sweeps in parallel,
    stopping when the first component sweep stops. For example if one sweep
    assigns 'a' to values 0, 1 and the second sweep assigns 'b' to the values
    3, 4, 5, then the zip is a sweep that assigns to the tuple ('a', 'b') the
    values (0, 3), (1, 4).
    """

    def __init__(self, *sweeps: Sweep) -> None:
        """Raises ValueError if two sweeps assign the same key."""
        _check_duplicate_keys(sweeps)
        self.sweeps = sweeps

    def __eq__(self, other):
        if not isinstance(other, Zip):
            return NotImplemented
        return self.sweeps == other.sweeps

    def __hash__(self) -> int:
        return hash(tuple(self.sweeps))

    @property
    def keys(self) -> List['cirq.TParamKey']:
        return sum((sweep.keys for sweep in self.sweeps), [])

    def __len__(self) -> int:
        # Zip stops at its shortest component; an empty Zip yields nothing.
        if not self.sweeps:
            return 0
        return min(len(sweep) for sweep in self.sweeps)

    def param_tuples(self) -> Iterator[Params]:
        iters = [sweep.param_tuples() for sweep in self.sweeps]
        for values in zip(*iters):
            # Concatenate the per-sweep assignments into one flat tuple.
            yield sum(values, ())

    def __repr__(self) -> str:
        sweeps_repr = ', '.join(repr(s) for s in self.sweeps)
        return f'cirq.Zip({sweeps_repr})'

    def __str__(self) -> str:
        if not self.sweeps:
            return 'Zip()'
        return ' + '.join(str(s) if isinstance(s, Product) else repr(s) for s in self.sweeps)


class SingleSweep(Sweep):
    """A simple sweep over one parameter with values from an iterator."""

    def __init__(self, key: 'cirq.TParamKey') -> None:
        # sympy.Symbol keys are stored by name; other keys are kept as given.
        if isinstance(key, sympy.Symbol):
            key = str(key)
        self.key = key

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return NotImplemented
        return self._tuple() == other._tuple()

    def __hash__(self) -> int:
        return hash((self.__class__, self._tuple()))

    @abc.abstractmethod
    def _tuple(self) -> Tuple[Any, ...]:
        """Returns the fields that define equality and hashing."""

    @property
    def keys(self) -> List['cirq.TParamKey']:
        return [self.key]

    def param_tuples(self) -> Iterator[Params]:
        for value in self._values():
            yield ((self.key, value),)

    @abc.abstractmethod
    def _values(self) -> Iterator[float]:
        """Returns an iterator over the values assigned to `self.key`."""


class Points(SingleSweep):
    """A simple sweep with explicitly supplied values."""

    def __init__(self, key: 'cirq.TParamKey', points: Sequence['cirq.TParamVal']) -> None:
        super().__init__(key)
        self.points = points

    def _tuple(self) -> Tuple[Union[str, sympy.Symbol], Sequence[float]]:
        # Convert to tuple so list-valued points are hashable.
        return self.key, tuple(self.points)

    def __len__(self) -> int:
        return len(self.points)

    def _values(self) -> Iterator[float]:
        return iter(self.points)

    def __repr__(self) -> str:
        return f'cirq.Points({self.key!r}, {self.points!r})'


class Linspace(SingleSweep):
    """A simple sweep over linearly-spaced values."""

    def __init__(self, key: 'cirq.TParamKey', start: float, stop: float, length: int) -> None:
        """Creates a linear-spaced sweep for a given key.

        For the given args, assigns to the list of values
            start, start + (stop - start) / (length - 1), ..., stop
        A length of 1 yields only `start`.
        """
        super().__init__(key)
        self.start = start
        self.stop = stop
        self.length = length

    def _tuple(self) -> Tuple[Union[str, sympy.Symbol], float, float, int]:
        return (self.key, self.start, self.stop, self.length)

    def __len__(self) -> int:
        return self.length

    def _values(self) -> Iterator[float]:
        if self.length == 1:
            # Avoid dividing by zero below.
            yield self.start
        else:
            for i in range(self.length):
                p = i / (self.length - 1)
                # Interpolating as a weighted sum makes the final point exactly
                # `stop` (p == 1.0) instead of accumulating step error.
                yield self.start * (1 - p) + self.stop * p

    def __repr__(self) -> str:
        return (
            f'cirq.Linspace({self.key!r}, start={self.start!r}, '
            f'stop={self.stop!r}, length={self.length!r})'
        )


class ListSweep(Sweep):
    """A wrapper around a list of `ParamResolver`s."""

    def __init__(self, resolver_list: Iterable[resolver.ParamResolverOrSimilarType]):
        """Creates a `Sweep` over a list of `ParamResolver`s.

        Args:
            resolver_list: The list of parameter resolvers to use in the sweep.
                All resolvers must resolve the same set of parameters. This is
                not validated here; `keys` is read from the first resolver.

        Raises:
            TypeError: if an element is neither a dict nor a ParamResolver.
        """
        self.resolver_list: List[resolver.ParamResolver] = []
        for r in resolver_list:
            if not isinstance(r, (dict, resolver.ParamResolver)):
                raise TypeError(f'Not a ParamResolver or dict: <{r!r}>')
            self.resolver_list.append(resolver.ParamResolver(r))

    def __eq__(self, other):
        if not isinstance(other, type(self)):
            return NotImplemented
        return self.resolver_list == other.resolver_list

    def __ne__(self, other):
        return not self == other

    @property
    def keys(self) -> List['cirq.TParamKey']:
        if not self.resolver_list:
            return []
        return list(map(str, self.resolver_list[0].param_dict))

    def __len__(self) -> int:
        return len(self.resolver_list)

    def param_tuples(self) -> Iterator[Params]:
        for r in self.resolver_list:
            yield tuple(_params_without_symbols(r))

    def __repr__(self) -> str:
        return f'cirq.ListSweep({self.resolver_list!r})'


def _params_without_symbols(resolver: resolver.ParamResolver) -> Params:
    """Yields (key, value) pairs of `resolver`, with sympy.Symbol keys by name.

    Keys that are not sympy.Symbol are passed through unchanged; the casts are
    for the type checker only. Note the parameter shadows the `resolver`
    module inside this function.
    """
    for sym, val in resolver.param_dict.items():
        if isinstance(sym, sympy.Symbol):
            sym = sym.name
        yield cast(str, sym), cast(float, val)


def _points_from_dict(factor_dict: ProductOrZipSweepLike) -> Iterator[Points]:
    """Yields one Points sweep per dict entry, wrapping non-Sequence values."""
    for k, v in factor_dict.items():
        yield Points(k, v if isinstance(v, Sequence) else [v])


def dict_to_product_sweep(factor_dict: ProductOrZipSweepLike) -> Product:
    """Cartesian product of sweeps from a dictionary.

    Each entry in the dictionary specifies a sweep as a mapping from the
    parameter to a value or sequence of values. The Cartesian product of these
    sweeps is returned.

    Args:
        factor_dict: The dictionary containing the sweeps.

    Returns:
        Cartesian product of the sweeps.
    """
    return Product(*_points_from_dict(factor_dict))


def dict_to_zip_sweep(factor_dict: ProductOrZipSweepLike) -> Zip:
    """Zip product of sweeps from a dictionary.

    Each entry in the dictionary specifies a sweep as a mapping from the
    parameter to a value or sequence of values. The zip product of these
    sweeps is returned.

    Args:
        factor_dict: The dictionary containing the sweeps.

    Returns:
        Zip product of the sweeps.
    """
    return Zip(*_points_from_dict(factor_dict))