1# Copyright 2019 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Protocol for objects that are mixtures (probabilistic combinations)."""
15from typing import Any, Sequence, Tuple, Union
16
17import numpy as np
18from typing_extensions import Protocol
19
20from cirq._doc import doc_private
21from cirq.protocols.decompose_protocol import _try_decompose_into_operations_and_qubits
22from cirq.protocols.has_unitary_protocol import has_unitary
23from cirq.type_workarounds import NotImplementedType
24
# Special indicator value used by the `mixture` function to determine whether
# or not the caller provided a 'default' argument. It is compared by identity
# (`is not`), so only the object's identity matters, not its contents.
RaiseTypeErrorIfNotProvided: Sequence[Tuple[float, Any]] = ((0.0, []),)
28
29
class SupportsMixture(Protocol):
    """An object that decomposes into a probability distribution of unitaries."""

    @doc_private
    def _mixture_(self) -> Union[Sequence[Tuple[float, Any]], NotImplementedType]:
        """Decompose into a probability distribution of unitaries.

        This method is used by the global `cirq.mixture` method.

        A mixture is described by an iterable of tuples of the form

            (probability of unitary, unitary as numpy array)

        The probability components of the tuples must sum to 1.0 and be between
        0 and 1 (inclusive).

        Returns:
            A sequence of (probability, unitary) pairs, or NotImplemented if
            this value has no mixture representation.
        """

    @doc_private
    def _has_mixture_(self) -> bool:
        """Whether this value has a mixture representation.

        This method is used by the global `cirq.has_mixture` method. If this
        method is not present, or returns NotImplemented, `cirq.has_mixture`
        falls back to other strategies: checking for a unitary, optionally
        decomposing the value, and finally calling `cirq.mixture` (which tries
        `_mixture_` and then `_unitary_`) with a default value, answering
        False if none of those succeed.

        Returns:
            True if the value has a mixture representation, False otherwise.
        """
61
62
def mixture(
    val: Any, default: Any = RaiseTypeErrorIfNotProvided
) -> Sequence[Tuple[float, np.ndarray]]:
    """Return a sequence of tuples representing a probabilistic unitary.

    A mixture is described by an iterable of tuples of the form

        (probability of unitary, unitary as numpy array)

    The probability components of the tuples must sum to 1.0 and be
    non-negative.

    Lookup order:
        1. `val._mixture_()`, if the method exists and returns something other
           than NotImplemented or None.
        2. `val._unitary_()`, if the method exists and returns something other
           than NotImplemented or None. The unitary is wrapped as the
           single-element mixture `((1.0, unitary),)`.
        3. `default`, if the caller provided one.

    Args:
        val: The value to decompose into a mixture of unitaries.
        default: A default value if val does not support mixture.

    Returns:
        An iterable of tuples of size 2. The first element of the tuple is a
        probability (between 0 and 1) and the second is the object that occurs
        with that probability in the mixture. The probabilities will sum to 1.0.
        If no mixture can be determined and `default` was provided, `default`
        is returned instead.

    Raises:
        TypeError: If no `default` was provided and either `val` has neither a
            `_mixture_` nor a `_unitary_` method, or those methods returned
            NotImplemented or None.
    """
    mixture_getter = getattr(val, '_mixture_', None)
    result = NotImplemented if mixture_getter is None else mixture_getter()
    # None is never a valid mixture; treat it like NotImplemented rather than
    # handing it back to callers that expect a sequence of pairs.
    if result is not NotImplemented and result is not None:
        return result

    unitary_getter = getattr(val, '_unitary_', None)
    result = NotImplemented if unitary_getter is None else unitary_getter()
    # Without the None check, a `_unitary_` returning None (presumably meaning
    # "no unitary") would be wrapped into the bogus mixture ((1.0, None),).
    if result is not NotImplemented and result is not None:
        # A unitary is the degenerate mixture that applies it with certainty.
        return ((1.0, result),)

    if default is not RaiseTypeErrorIfNotProvided:
        return default

    if mixture_getter is None and unitary_getter is None:
        raise TypeError(f"object of type '{type(val)}' has no _mixture_ or _unitary_ method.")

    raise TypeError(
        f"object of type '{type(val)}' does have a _mixture_ or _unitary_ "
        "method, but it returned NotImplemented or None."
    )
105
106
def has_mixture(val: Any, *, allow_decompose: bool = True) -> bool:
    """Returns whether the value has a mixture representation.

    Args:
        val: The value to check.
        allow_decompose: Used by internal methods to stop redundant
            decompositions from being performed (e.g. there's no need to
            decompose an object to check if it is unitary as part of determining
            if the object is a quantum channel, when the quantum channel check
            will already be doing a more general decomposition check). Defaults
            to True. When false, the decomposition strategy for determining
            the result is skipped.

    Returns:
        The first strategy that gives an answer wins:

        1. If `val` has a `_has_mixture_` method whose result is not
           NotImplemented, that result.
        2. True if `val` has a unitary (checked without decomposition).
        3. If `allow_decompose` is True and `val` decomposes into operations,
           whether every one of those operations has a mixture.
        4. Otherwise, True if `mixture(val, None)` returns a non-None value,
           i.e. `_mixture_` or `_unitary_` produced a result, and False if not.
    """
    has_mixture_getter = getattr(val, '_has_mixture_', None)
    result = NotImplemented if has_mixture_getter is None else has_mixture_getter()
    if result is not NotImplemented:
        return result

    # Any unitary is trivially a mixture: it is applied with probability 1.
    if has_unitary(val, allow_decompose=False):
        return True

    if allow_decompose:
        operations, _, _ = _try_decompose_into_operations_and_qubits(val)
        if operations is not None:
            return all(has_mixture(op) for op in operations)

    # No _has_mixture_, no unitary and no decomposition: ask `mixture`
    # directly, passing a default so that it returns None instead of raising.
    return mixture(val, None) is not None
141
142
def validate_mixture(supports_mixture: SupportsMixture):
    """Checks that a value's mixture is a valid probability distribution.

    Every probability must lie in [0, 1], and together they must sum to 1.0
    (up to `np.isclose` tolerance).

    Args:
        supports_mixture: The value whose mixture is validated.

    Raises:
        TypeError: If `mixture` returns None for the value, i.e. no mixture
            could be obtained.
        ValueError: If any probability is negative or greater than 1, or if
            the probabilities do not sum to 1.0.
    """
    pairs = mixture(supports_mixture, None)
    if pairs is None:
        raise TypeError(f'{supports_mixture}_mixture did not have a _mixture_ method')

    running_sum = 0.0
    for prob, component in pairs:
        label = f"{str(component)}'s probability"
        if prob < 0:
            raise ValueError(f'{label} was less than 0.')
        if prob > 1:
            raise ValueError(f'{label} was greater than 1.')
        running_sum += prob

    if not np.isclose(running_sum, 1.0):
        raise ValueError('Sum of probabilities of a mixture was not 1.0')
161