1# Copyright 2019 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Protocol for objects that are mixtures (probabilistic combinations)."""
15from typing import Any, Sequence, Tuple, Union
16
17import numpy as np
18from typing_extensions import Protocol
19
20from cirq._doc import doc_private
21from cirq.protocols.decompose_protocol import _try_decompose_into_operations_and_qubits
22from cirq.protocols.has_unitary_protocol import has_unitary
23from cirq.type_workarounds import NotImplementedType
24
# This is a special indicator value used by the `mixture` method to determine
# whether or not the caller provided a 'default' argument. It is compared by
# identity (`is`), so callers may pass any other value -- including None -- as
# a real default.
RaiseTypeErrorIfNotProvided = ((0.0, []),)  # type: Sequence[Tuple[float, Any]]
28
29
class SupportsMixture(Protocol):
    """An object that decomposes into a probability distribution of unitaries."""

    @doc_private
    def _mixture_(self) -> Union[Sequence[Tuple[float, Any]], NotImplementedType]:
        """Decompose into a probability distribution of unitaries.

        This method is used by the global `cirq.mixture` method.

        A mixture is described by an iterable of tuples of the form

            (probability of unitary, unitary as numpy array)

        The probability components of the tuples must sum to 1.0 and be between
        0 and 1 (inclusive).

        Returns:
            A list of (probability, unitary) pairs, or NotImplemented if this
            value has no mixture representation.
        """

    @doc_private
    def _has_mixture_(self) -> bool:
        """Whether this value has a mixture representation.

        This method is used by the global `cirq.has_mixture` method.  If this
        method is not present, or returns NotImplemented, `cirq.has_mixture`
        will fall back to checking whether the value has a unitary, then
        (when allowed) whether it decomposes into operations that all have
        mixtures, and finally whether `_mixture_` (or `_unitary_`) yields a
        result. If none of these apply, the answer is False.

        Returns:
          True if the value has a mixture representation, False otherwise.
        """
61
62
63# TODO(#3388) Add documentation for Raises.
64# pylint: disable=missing-raises-doc
def mixture(
    val: Any, default: Any = RaiseTypeErrorIfNotProvided
) -> Sequence[Tuple[float, np.ndarray]]:
    """Return a sequence of tuples representing a probabilistic unitary.

    A mixture is described by an iterable of tuples of the form

        (probability of unitary, unitary as numpy array)

    The probability components of the tuples must sum to 1.0 and be
    non-negative.

    The result is determined by trying, in order:

        1. `val._mixture_()`, returned as-is if it is present and its result
           is neither NotImplemented nor None.
        2. `val._unitary_()`, wrapped as the single-element mixture
           `((1.0, unitary),)` if it is present and its result is neither
           NotImplemented nor None.
        3. `default`, if the caller supplied one.

    Args:
        val: The value to decompose into a mixture of unitaries.
        default: A default value if val does not support mixture. If it is not
            supplied, a TypeError is raised instead.

    Returns:
        An iterable of tuples of size 2. The first element of the tuple is a
        probability (between 0 and 1) and the second is the object that occurs
        with that probability in the mixture. The probabilities will sum to 1.0.

    Raises:
        TypeError: If `val` has neither a `_mixture_` nor a `_unitary_` method,
            or if those methods exist but return NotImplemented (or None), and
            no `default` was provided.
    """
    mixture_getter = getattr(val, '_mixture_', None)
    result = NotImplemented if mixture_getter is None else mixture_getter()
    # None is treated like NotImplemented: it signals "no mixture available"
    # rather than being a mixture itself.
    if result is not NotImplemented and result is not None:
        return result

    unitary_getter = getattr(val, '_unitary_', None)
    result = NotImplemented if unitary_getter is None else unitary_getter()
    # A `_unitary_` result of None means "no unitary"; without this check it
    # would be wrapped as a bogus probability-1 "unitary" of None.
    if result is not NotImplemented and result is not None:
        return ((1.0, result),)

    # Identity check so that any caller-provided value (even None) is honored.
    if default is not RaiseTypeErrorIfNotProvided:
        return default

    if mixture_getter is None and unitary_getter is None:
        raise TypeError(f"object of type '{type(val)}' has no _mixture_ or _unitary_ method.")

    raise TypeError(
        f"object of type '{type(val)}' does have a _mixture_ or _unitary_ "
        "method, but it returned NotImplemented."
    )
107
108
109# pylint: enable=missing-raises-doc
def has_mixture(val: Any, *, allow_decompose: bool = True) -> bool:
    """Returns whether the value has a mixture representation.

    The following checks are made in order:

        1. If `val` has a `_has_mixture_` method whose result is not
           NotImplemented, that result is returned.
        2. If `cirq.has_unitary(val, allow_decompose=False)` is True, True is
           returned (a unitary is a single-element mixture).
        3. If `allow_decompose` is True and `val` decomposes into operations,
           the result is whether every one of those operations has a mixture.
        4. Otherwise, the result is whether `cirq.mixture(val, None)` returns
           something other than None (i.e. a usable `_mixture_` or
           `_unitary_` result exists).

    Args:
        val: The value to check.
        allow_decompose: Used by internal methods to stop redundant
            decompositions from being performed (e.g. there's no need to
            decompose an object to check if it is unitary as part of determining
            if the object is a quantum channel, when the quantum channel check
            will already be doing a more general decomposition check). Defaults
            to True. When false, the decomposition strategy for determining
            the result is skipped.

    Returns:
        True if `val` has a mixture representation according to the checks
        above, False otherwise.
    """
    has_mixture_getter = getattr(val, '_has_mixture_', None)
    result = NotImplemented if has_mixture_getter is None else has_mixture_getter()
    if result is not NotImplemented:
        return result

    # The unitary check is done without decomposing; the decomposition check
    # below is the more general one and covers that case.
    if has_unitary(val, allow_decompose=False):
        return True

    if allow_decompose:
        operations, _, _ = _try_decompose_into_operations_and_qubits(val)
        if operations is not None:
            return all(has_mixture(op) for op in operations)

    # No _has_mixture_, unitary, or decomposition: fall back to _mixture_
    # (and _unitary_) via the mixture protocol itself.
    return mixture(val, None) is not None
144
145
def validate_mixture(supports_mixture: SupportsMixture):
    """Validates that the mixture's tuples contain valid probabilities.

    Args:
        supports_mixture: The value whose mixture, as returned by
            `cirq.mixture(supports_mixture, None)`, is validated.

    Raises:
        TypeError: If `supports_mixture` has no mixture representation, i.e.
            `cirq.mixture(supports_mixture, None)` returns None.
        ValueError: If any probability is less than 0 or greater than 1, or
            if the probabilities do not sum to 1.0 (within `np.isclose`'s
            default tolerance).
    """
    mixture_tuple = mixture(supports_mixture, None)
    if mixture_tuple is None:
        raise TypeError(f'{supports_mixture} did not have a _mixture_ method')

    total = 0.0
    for p, val in mixture_tuple:
        # `!s` matches the original `str(val)` formatting exactly.
        p_str = f"{val!s}'s probability"
        if p < 0:
            raise ValueError(f'{p_str} was less than 0.')
        if p > 1:
            raise ValueError(f'{p_str} was greater than 1.')
        total += p
    # Tolerant comparison: floating-point probabilities rarely sum to exactly 1.
    if not np.isclose(total, 1.0):
        raise ValueError('Sum of probabilities of a mixture was not 1.0')
164