1# Copyright 2019 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Utilities for handling probabilities."""
16
17from typing import TYPE_CHECKING
18
19import numpy as np
20from cirq.qis import to_valid_state_vector
21
22if TYPE_CHECKING:
23    import cirq
24
25
def validate_probability(p: float, p_str: str) -> float:
    """Validates that a probability is between 0 and 1 inclusively.

    Args:
        p: The value to validate.
        p_str: What to call the probability in error messages.

    Returns:
        The probability p if the probability is valid.

    Raises:
        ValueError: If the probability is less than 0, greater than 1,
            or NaN.
    """
    if p < 0:
        raise ValueError(f'{p_str} was less than 0.')
    elif p > 1:
        raise ValueError(f'{p_str} was greater than 1.')
    elif p != p:
        # NaN compares False against every bound above, so it would otherwise
        # be returned as a "valid" probability. NaN is the only value that is
        # not equal to itself, which detects it without requiring a float type.
        raise ValueError(f'{p_str} was NaN.')
    return p
44
45
def state_vector_to_probabilities(state_vector: 'cirq.STATE_VECTOR_LIKE') -> np.ndarray:
    """Converts a state-vector-like object into an array of probabilities.

    Args:
        state_vector: Any representation accepted by `to_valid_state_vector`.
            It is validated and converted to an ndarray of amplitudes first.

    Returns:
        The elementwise squared magnitudes of the validated amplitudes, in
        the same order as the amplitudes.
    """
    amplitudes = to_valid_state_vector(state_vector)
    magnitudes = np.abs(amplitudes)
    # |amplitude|^2 is the probability weight of each basis state.
    return magnitudes**2
49