1# Copyright 2018 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Defines which types are Sweepable."""
16
17from typing import Iterable, Iterator, List, Sequence, Union, cast
18import warnings
19from typing_extensions import Protocol
20
21from cirq._doc import document
22from cirq.study.resolver import ParamResolver, ParamResolverOrSimilarType
23from cirq.study.sweeps import ListSweep, Points, Sweep, UnitSweep, Zip, dict_to_product_sweep
24
# Either a single resolver-like value or an already constructed Sweep.
SweepLike = Union[ParamResolverOrSimilarType, Sweep]
document(
    SweepLike,
    """An object similar to an iterable of parameter resolvers.""",
)
27
28
class _Sweepable(Protocol):
    """Structural type for an iterable whose items are themselves sweepable.

    This helper exists only so that `Sweepable` can be defined recursively,
    because recursive union definitions are not yet supported in mypy.
    """

    def __iter__(self) -> Iterator[Union[SweepLike, '_Sweepable']]:
        ...
35
36
# A SweepLike value, or an iterable (possibly nested) of such values.
Sweepable = Union[SweepLike, _Sweepable]
document(Sweepable, """An object or collection of objects representing a parameter sweep.""")
42
43
def to_resolvers(sweepable: Sweepable) -> Iterator[ParamResolver]:
    """Lazily yields the ParamResolvers described by a Sweepable.

    This is a generator: the sweepable is first flattened with `to_sweeps`
    (on the first request for a value), and the resolvers of each resulting
    sweep are then yielded one sweep after another, in order.

    Args:
        sweepable: Any value accepted by `to_sweeps`.

    Yields:
        Each item produced by iterating over the sweeps, in order.
    """
    flattened_sweeps = to_sweeps(sweepable)
    for current_sweep in flattened_sweeps:
        # Delegate to the sweep's own iteration.
        yield from current_sweep
48
49
def to_sweeps(sweepable: Sweepable) -> List[Sweep]:
    """Flattens a Sweepable into a list of Sweeps.

    Args:
        sweepable: The value to convert. It is handled according to its type,
            checked in this order:

            * `None` gives a list holding only `UnitSweep`.
            * A `ParamResolver` is converted with `_resolver_to_sweep`.
            * A `Sweep` is returned unchanged inside a one-element list.
            * A `dict` is expanded with `dict_to_product_sweep`, and every
              resolver of that product becomes its own sweep. If any value
              of the dict is a `Sequence`, a `DeprecationWarning` is issued.
            * Any other iterable that is not a `str` is converted item by
              item (recursively) and the results are concatenated in order.

    Returns:
        The list of sweeps obtained from `sweepable`.

    Raises:
        TypeError: If `sweepable`, or an item nested inside it, is a `str` or
            any other type not listed above.
    """
    if sweepable is None:
        return [UnitSweep]
    if isinstance(sweepable, ParamResolver):
        return [_resolver_to_sweep(sweepable)]
    if isinstance(sweepable, Sweep):
        return [sweepable]

    if isinstance(sweepable, dict):
        # Warn once, as soon as the first sequence-valued entry is found.
        for entry_value in sweepable.values():
            if isinstance(entry_value, Sequence):
                warnings.warn(
                    'Implicit expansion of a dictionary into a Cartesian product '
                    'of sweeps is deprecated and will be removed in cirq 0.10. '
                    'Instead, expand the sweep explicitly using '
                    '`cirq.dict_to_product_sweep`.',
                    DeprecationWarning,
                    stacklevel=2,
                )
                break
        expanded = dict_to_product_sweep(sweepable)
        return list(map(_resolver_to_sweep, expanded))

    # Strings are iterable but are never a valid collection of sweepables.
    if isinstance(sweepable, str) or not isinstance(sweepable, Iterable):
        raise TypeError(f'Unrecognized sweepable type: {type(sweepable)}.\nsweepable: {sweepable}')

    collected: List[Sweep] = []
    for element in sweepable:
        collected.extend(to_sweeps(element))
    return collected
73
74
75# TODO(#3388) Add documentation for Raises.
76# pylint: disable=missing-raises-doc
def to_sweep(
    sweep_or_resolver_list: Union[
        'Sweep', ParamResolverOrSimilarType, Iterable[ParamResolverOrSimilarType]
    ]
) -> 'Sweep':
    """Turns the argument into a single ``cirq.Sweep``.

    Args:
        sweep_or_resolver_list: The object to convert. May be a
            ``cirq.Sweep`` (returned as is), a single ``cirq.ParamResolver``
            or ``dict`` (wrapped in a one-element ``ListSweep``), or any
            other iterable, which is handed to ``ListSweep`` directly.

    Returns:
        A sweep equal to or containing the argument.

    Raises:
        TypeError: If the argument is not a sweep, a resolver, a dict, or an
            iterable.
    """
    # Already a sweep: nothing to convert.
    if isinstance(sweep_or_resolver_list, Sweep):
        return sweep_or_resolver_list

    # A lone resolver (or resolver-like dict) becomes a one-element list sweep.
    if isinstance(sweep_or_resolver_list, (ParamResolver, dict)):
        return ListSweep([cast(ParamResolverOrSimilarType, sweep_or_resolver_list)])

    if not isinstance(sweep_or_resolver_list, Iterable):
        raise TypeError(f'Unexpected sweep-like value: {sweep_or_resolver_list}')

    # The casts only inform the type checker; they do nothing at runtime.
    return ListSweep(cast(Iterable[ParamResolverOrSimilarType], sweep_or_resolver_list))
101
102
103# pylint: enable=missing-raises-doc
def _resolver_to_sweep(resolver: ParamResolver) -> Sweep:
    """Converts a single ParamResolver into a Sweep.

    Args:
        resolver: The resolver whose `param_dict` entries are converted.

    Returns:
        `UnitSweep` when the resolver's `param_dict` is empty; otherwise a
        `Zip` of one `Points` per entry, each holding that entry's one value.
    """
    assignments = resolver.param_dict
    if assignments:
        single_value_sweeps = []
        for name, assigned in assignments.items():
            # `cast` is for the type checker only; the value is not converted.
            single_value_sweeps.append(Points(name, [cast(float, assigned)]))
        return Zip(*single_value_sweeps)
    return UnitSweep
109