1# Copyright 2018 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14from typing import (
15    Any,
16    cast,
17    Dict,
18    Iterable,
19    Iterator,
20    List,
21    overload,
22    Sequence,
23    TYPE_CHECKING,
24    Tuple,
25    Union,
26)
27
28import abc
29import collections
30import itertools
31import sympy
32
33from cirq._doc import document
34from cirq.study import resolver
35
36if TYPE_CHECKING:
37    import cirq
38
# One point of a sweep: an iterable of (parameter key, parameter value) pairs.
Params = Iterable[Tuple['cirq.TParamKey', 'cirq.TParamVal']]
# Dictionary form accepted by dict_to_product_sweep and dict_to_zip_sweep: each
# parameter key maps to either a single value or a sequence of values.
ProductOrZipSweepLike = Dict['cirq.TParamKey', Union['cirq.TParamVal', Sequence['cirq.TParamVal']]]
41
42
def _check_duplicate_keys(sweeps):
    """Raises ValueError if a key of one sweep already appeared in an earlier sweep.

    Only keys shared *between* different sweeps are rejected; each sweep's keys
    are compared against the keys gathered from the sweeps that precede it.

    Args:
        sweeps: Iterable of objects exposing a `keys` attribute.

    Raises:
        ValueError: If a sweep reuses a key seen in a previous sweep.
    """
    seen = set()
    for sweep in sweeps:
        if not seen.isdisjoint(sweep.keys):
            raise ValueError('duplicate keys')
        seen.update(sweep.keys)
49
50
class Sweep(metaclass=abc.ABCMeta):
    """A sweep is an iterator over ParamResolvers.

    A ParamResolver assigns values to Symbols. For sweeps, each ParamResolver
    must specify the same Symbols that are assigned.  So a sweep is a way to
    iterate over a set of different values for a fixed set of Symbols. This is
    useful for a circuit, where there are a fixed set of Symbols, and you want
    to iterate over an assignment of all values to all symbols.

    For example, a sweep can explicitly assign a set of equally spaced points
    between two endpoints using a Linspace,
        sweep = Linspace("angle", start=0.0, stop=2.0, length=10)
    This can then be used with a circuit that has an 'angle' sympy.Symbol to
    run multiple simulations, one for each of the values in the sweep
        result = simulator.run_sweep(program=circuit, params=sweep)

    Sweeps support Cartesian and Zip products using the '*' and '+' operators,
    see the Product and Zip documentation.
    """

    def __mul__(self, other: 'Sweep') -> 'Sweep':
        """Returns the Cartesian product of this sweep and `other`.

        Operands that are already a Product contribute their factors directly,
        so the result is a single flat Product.

        Raises:
            TypeError: If `other` is not a Sweep.
        """
        factors: List[Sweep] = []
        if isinstance(self, Product):
            factors.extend(self.factors)
        else:
            factors.append(self)
        if isinstance(other, Product):
            factors.extend(other.factors)
        elif isinstance(other, Sweep):
            factors.append(other)
        else:
            raise TypeError(f'cannot multiply sweep and {type(other)}')
        return Product(*factors)

    def __add__(self, other: 'Sweep') -> 'Sweep':
        """Returns the Zip of this sweep and `other`.

        Operands that are already a Zip contribute their component sweeps
        directly, so the result is a single flat Zip.

        Raises:
            TypeError: If `other` is not a Sweep.
        """
        sweeps: List[Sweep] = []
        if isinstance(self, Zip):
            sweeps.extend(self.sweeps)
        else:
            sweeps.append(self)
        if isinstance(other, Zip):
            sweeps.extend(other.sweeps)
        elif isinstance(other, Sweep):
            sweeps.append(other)
        else:
            raise TypeError(f'cannot add sweep and {type(other)}')
        return Zip(*sweeps)

    @abc.abstractmethod
    def __eq__(self, other):
        pass

    def __ne__(self, other):
        return not self == other

    @property
    @abc.abstractmethod
    def keys(self) -> List['cirq.TParamKey']:
        """The keys for all of the sympy.Symbols that are resolved."""

    @abc.abstractmethod
    def __len__(self) -> int:
        pass

    def __iter__(self) -> Iterator[resolver.ParamResolver]:
        """Yields one ParamResolver per entry produced by `param_tuples`."""
        for params in self.param_tuples():
            yield resolver.ParamResolver(collections.OrderedDict(params))

    # pylint: disable=function-redefined
    @overload
    def __getitem__(self, val: int) -> resolver.ParamResolver:
        pass

    @overload
    def __getitem__(self, val: slice) -> 'Sweep':
        pass

    def __getitem__(self, val):
        """Returns one resolver (int index) or a ListSweep (slice).

        Args:
            val: An int, where negative values count from the end, or a slice.

        Returns:
            The ParamResolver at position `val`, or a ListSweep holding the
            resolvers selected by the slice, in slice order.

        Raises:
            IndexError: If an int index is outside [-len(self), len(self)).
            TypeError: If `val` is neither an int nor a slice.
        """
        n = len(self)
        if isinstance(val, int):
            if val < -n or val >= n:
                raise IndexError(f'sweep index out of range: {val}')
            if val < 0:
                val += n
            return next(itertools.islice(self, val, val + 1))
        if not isinstance(val, slice):
            raise TypeError(f'Sweep indices must be either int or slices, not {type(val)}')

        # Let range() resolve the slice against the sweep length, then map each
        # selected sweep index to its position in the resulting list.
        selected = range(n)[val]
        inds_map: Dict[int, int] = {sweep_i: slice_i for slice_i, sweep_i in enumerate(selected)}
        results = [resolver.ParamResolver()] * len(inds_map)
        if inds_map:
            # A range is monotonic, so its largest element is one of its ends.
            # Stop iterating there instead of walking the rest of the sweep.
            last_needed = max(selected[0], selected[-1])
            for i, item in enumerate(itertools.islice(self, last_needed + 1)):
                if i in inds_map:
                    results[inds_map[i]] = item

        return ListSweep(results)

    # pylint: enable=function-redefined

    @abc.abstractmethod
    def param_tuples(self) -> Iterator[Params]:
        """An iterator over (key, value) pairs assigning Symbol key to value."""

    def __str__(self) -> str:
        """Returns a multi-line listing of the sweep's parameter assignments.

        At most 10 entries are shown. Longer sweeps show the first 5 and the
        last 5 entries with a '...' line between them.
        """
        length = len(self)
        max_show = 10
        # Show a maximum of max_show entries with an ellipsis in the middle
        if length > max_show:
            beginning_len = max_show - max_show // 2
        else:
            beginning_len = max_show
        end_len = max_show - beginning_len
        lines = ['Sweep:']
        lines.extend(str(dict(r.param_dict)) for r in itertools.islice(self, beginning_len))
        if end_len > 0:
            lines.append('...')
            lines.extend(
                str(dict(r.param_dict)) for r in itertools.islice(self, length - end_len, length)
            )
        return '\n'.join(lines)
173
174
class _Unit(Sweep):
    """A sweep with a single element that assigns no parameter values.

    This is useful as a base sweep, instead of special casing None.
    """

    def __eq__(self, other):
        """All instances of this class compare equal to each other."""
        if not isinstance(other, self.__class__):
            return NotImplemented
        return True

    def __hash__(self) -> int:
        # Defining __eq__ alone would leave instances unhashable. Since all
        # instances compare equal, they must also share one hash value.
        return hash(self.__class__)

    @property
    def keys(self) -> List['cirq.TParamKey']:
        """An empty list: the unit sweep resolves no parameters."""
        return []

    def __len__(self) -> int:
        return 1

    def param_tuples(self) -> Iterator[Params]:
        """Yields a single empty assignment."""
        yield ()

    def __repr__(self) -> str:
        return 'cirq.UnitSweep'
198
199
# Module-level instance of _Unit: one element, no parameter assignments.
UnitSweep = _Unit()
document(UnitSweep, """The singleton sweep with no parameters.""")
202
203
class Product(Sweep):
    """Cartesian product of one or more sweeps.

    If one sweep assigns 'a' to the values 0, 1, 2, and the second sweep
    assigns 'b' to the values 2, 3, then the product is a sweep that
    assigns the tuple ('a','b') to all possible combinations of these
    assignments: (0, 2), (1, 2), (2, 2), (0, 3), (1, 3), (2, 3).

    When iterating, the first factor varies slowest and the last factor
    varies fastest. A product of zero factors contains exactly one element,
    the empty assignment.
    """

    def __init__(self, *factors: Sweep) -> None:
        """Creates the product of the given sweeps.

        Args:
            *factors: The sweeps to combine.

        Raises:
            ValueError: If a key appears in more than one factor.
        """
        _check_duplicate_keys(factors)
        self.factors = factors

    def __eq__(self, other):
        if not isinstance(other, Product):
            return NotImplemented
        return self.factors == other.factors

    def __hash__(self):
        return hash(tuple(self.factors))

    @property
    def keys(self) -> List['cirq.TParamKey']:
        """The keys of all factors, concatenated in factor order."""
        return list(itertools.chain.from_iterable(factor.keys for factor in self.factors))

    def __len__(self) -> int:
        """The product of the factor lengths (1 when there are no factors).

        This matches `param_tuples`, which yields a single empty assignment
        for a product without factors.
        """
        length = 1
        for factor in self.factors:
            length *= len(factor)
        return length

    def param_tuples(self) -> Iterator[Params]:
        """Yields every combination of the factors' assignments.

        Each yielded value is the concatenation of one assignment from each
        factor, in factor order.
        """

        def _gen(factors):
            if not factors:
                # Base case: the empty product has one (empty) assignment.
                yield ()
            else:
                first, rest = factors[0], factors[1:]
                for first_values in first.param_tuples():
                    # The remaining factors are re-iterated for each value of
                    # the first factor, since their iterators are single-use.
                    for rest_values in _gen(rest):
                        yield first_values + rest_values

        return _gen(self.factors)

    def __repr__(self) -> str:
        factors_repr = ', '.join(repr(f) for f in self.factors)
        return f'cirq.Product({factors_repr})'

    def __str__(self) -> str:
        """Returns the factors joined by ' * ', parenthesizing Zip factors."""
        if not self.factors:
            return 'Product()'
        factor_strs = []
        for factor in self.factors:
            factor_str = repr(factor)
            if isinstance(factor, Zip):
                factor_str = '(' + str(factor) + ')'
            factor_strs.append(factor_str)
        return ' * '.join(factor_strs)
263
264
class Zip(Sweep):
    """Zip product (direct sum) of one or more sweeps.

    If one sweep assigns 'a' to values 0, 1, 2, and the second sweep assigns 'b'
    to the values 3, 4, 5, then the zip is a sweep that assigns to the
    tuple ('a', 'b') the pair-wise matched values (0, 3), (1, 4), (2, 5).

    The component sweeps are iterated in parallel and iteration ends as soon
    as the shortest component is exhausted. For example if one sweep assigns
    'a' to values 0, 1 and the second sweep assigns 'b' to the values 3, 4, 5,
    then the zip is a sweep that assigns to the tuple ('a', 'b') the values
    (0, 3), (1, 4).
    """

    def __init__(self, *sweeps: Sweep) -> None:
        """Creates the zip of the given sweeps.

        Raises:
            ValueError: If a key appears in more than one component sweep.
        """
        _check_duplicate_keys(sweeps)
        self.sweeps = sweeps

    def __eq__(self, other):
        if isinstance(other, Zip):
            return self.sweeps == other.sweeps
        return NotImplemented

    def __hash__(self) -> int:
        return hash(tuple(self.sweeps))

    @property
    def keys(self) -> List['cirq.TParamKey']:
        """The keys of all component sweeps, concatenated in order."""
        return [key for component in self.sweeps for key in component.keys]

    def __len__(self) -> int:
        """The length of the shortest component, or 0 with no components."""
        return min((len(component) for component in self.sweeps), default=0)

    def param_tuples(self) -> Iterator[Params]:
        """Yields the component assignments at each position, concatenated."""
        component_iters = [component.param_tuples() for component in self.sweeps]
        for row in zip(*component_iters):
            merged = ()
            for part in row:
                merged = merged + part
            yield merged

    def __repr__(self) -> str:
        return 'cirq.Zip(' + ', '.join(map(repr, self.sweeps)) + ')'

    def __str__(self) -> str:
        """Returns the components joined by ' + '."""
        if not self.sweeps:
            return 'Zip()'
        parts = []
        for component in self.sweeps:
            # Product components use their str() form, all others their repr().
            parts.append(str(component) if isinstance(component, Product) else repr(component))
        return ' + '.join(parts)
313
314
class SingleSweep(Sweep):
    """A simple sweep over one parameter with values from an iterator."""

    def __init__(self, key: 'cirq.TParamKey') -> None:
        """Stores the swept key, converting a sympy.Symbol to its string form."""
        self.key = str(key) if isinstance(key, sympy.Symbol) else key

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            return self._tuple() == other._tuple()
        return NotImplemented

    def __hash__(self) -> int:
        # Include the class so different sweep types with equal tuples differ.
        return hash((self.__class__, self._tuple()))

    @abc.abstractmethod
    def _tuple(self) -> Tuple[Any, ...]:
        """The values that determine equality and hashing of this sweep."""

    @property
    def keys(self) -> List['cirq.TParamKey']:
        """A one-element list holding this sweep's key."""
        return [self.key]

    def param_tuples(self) -> Iterator[Params]:
        """Yields a one-pair assignment ((key, value),) for each swept value."""
        key = self.key
        for value in self._values():
            assignment = ((key, value),)
            yield assignment

    @abc.abstractmethod
    def _values(self) -> Iterator[float]:
        """An iterator over the values assigned to the key."""
346
347
class Points(SingleSweep):
    """A simple sweep with explicitly supplied values."""

    def __init__(self, key: 'cirq.TParamKey', points: Sequence['cirq.TParamVal']) -> None:
        """Creates a sweep assigning `key` to each entry of `points` in turn.

        Args:
            key: The parameter key being swept.
            points: The values to assign; stored as given, without copying.
        """
        super().__init__(key)
        self.points = points

    def _tuple(self) -> Tuple[Union[str, sympy.Symbol], Sequence[float]]:
        frozen_points = tuple(self.points)
        return self.key, frozen_points

    def __len__(self) -> int:
        return len(self.points)

    def _values(self) -> Iterator[float]:
        return iter(self.points)

    def __repr__(self) -> str:
        return 'cirq.Points({!r}, {!r})'.format(self.key, self.points)
366
367
class Linspace(SingleSweep):
    """A simple sweep over linearly-spaced values."""

    def __init__(self, key: 'cirq.TParamKey', start: float, stop: float, length: int) -> None:
        """Creates a linear-spaced sweep for a given key.

        For the given args, assigns to the list of values
            start, start + (stop - start) / (length - 1), ..., stop

        Args:
            key: The parameter key being swept.
            start: The first value of the sweep.
            stop: The last value of the sweep (when length is greater than 1).
            length: The number of values in the sweep.
        """
        super().__init__(key)
        self.start, self.stop, self.length = start, stop, length

    def _tuple(self) -> Tuple[Union[str, sympy.Symbol], float, float, int]:
        return (self.key, self.start, self.stop, self.length)

    def __len__(self) -> int:
        return self.length

    def _values(self) -> Iterator[float]:
        """Yields `length` values interpolated linearly from start to stop."""
        if self.length == 1:
            # A single point cannot be interpolated; it is just the start.
            yield self.start
            return
        last_index = self.length - 1
        for index in range(self.length):
            p = index / last_index
            yield self.start * (1 - p) + self.stop * p

    def __repr__(self) -> str:
        return 'cirq.Linspace({!r}, start={!r}, stop={!r}, length={!r})'.format(
            self.key, self.start, self.stop, self.length
        )
401
402
class ListSweep(Sweep):
    """A wrapper around a list of `ParamResolver`s."""

    def __init__(self, resolver_list: Iterable[resolver.ParamResolverOrSimilarType]):
        """Creates a `Sweep` over a list of `ParamResolver`s.

        Args:
            resolver_list: The list of parameter resolvers to use in the sweep.
                All resolvers must resolve the same set of parameters.

        Raises:
            TypeError: If an entry is neither a dict nor a ParamResolver.
        """
        self.resolver_list: List[resolver.ParamResolver] = []
        accepted_types = (dict, resolver.ParamResolver)
        for entry in resolver_list:
            if isinstance(entry, accepted_types):
                self.resolver_list.append(resolver.ParamResolver(entry))
            else:
                raise TypeError(f'Not a ParamResolver or dict: <{entry!r}>')

    def __eq__(self, other):
        if isinstance(other, type(self)):
            return self.resolver_list == other.resolver_list
        return NotImplemented

    def __ne__(self, other):
        return not (self == other)

    @property
    def keys(self) -> List['cirq.TParamKey']:
        """The stringified keys of the first resolver, or [] if there is none."""
        if self.resolver_list:
            first = self.resolver_list[0]
            return [str(key) for key in first.param_dict]
        return []

    def __len__(self) -> int:
        return len(self.resolver_list)

    def param_tuples(self) -> Iterator[Params]:
        """Yields each resolver's assignments as a tuple of (key, value) pairs."""
        for entry in self.resolver_list:
            pairs = _params_without_symbols(entry)
            yield tuple(pairs)

    def __repr__(self) -> str:
        return 'cirq.ListSweep({!r})'.format(self.resolver_list)
442
443
def _params_without_symbols(resolver: resolver.ParamResolver) -> Params:
    """Yields the resolver's (key, value) pairs, replacing Symbol keys by name.

    Keys that are `sympy.Symbol` instances are yielded as their `name`; all
    other keys are yielded unchanged.
    """
    for key, value in resolver.param_dict.items():
        plain_key = key.name if isinstance(key, sympy.Symbol) else key
        yield cast(str, plain_key), cast(float, value)
449
450
def dict_to_product_sweep(factor_dict: ProductOrZipSweepLike) -> Product:
    """Cartesian product of sweeps from a dictionary.

    Every dictionary entry becomes a `Points` sweep over its key. A value that
    is a sequence supplies the points directly; any other value is treated as
    a single point. The Cartesian product of these sweeps is returned.

    Args:
        factor_dict: The dictionary containing the sweeps.

    Returns:
        Cartesian product of the sweeps.
    """
    factors = []
    for key, values in factor_dict.items():
        if not isinstance(values, Sequence):
            # Wrap a lone value so it forms a one-point sweep.
            values = [values]
        factors.append(Points(key, values))
    return Product(*factors)
467
468
def dict_to_zip_sweep(factor_dict: ProductOrZipSweepLike) -> Zip:
    """Zip product of sweeps from a dictionary.

    Every dictionary entry becomes a `Points` sweep over its key. A value that
    is a sequence supplies the points directly; any other value is treated as
    a single point. The zip product of these sweeps is returned.

    Args:
        factor_dict: The dictionary containing the sweeps.

    Returns:
        Zip product of the sweeps.
    """
    components = []
    for key, values in factor_dict.items():
        if not isinstance(values, Sequence):
            # Wrap a lone value so it forms a one-point sweep.
            values = [values]
        components.append(Points(key, values))
    return Zip(*components)
483