1# This file is part of QuTiP: Quantum Toolbox in Python.
2#
3#    Copyright (c) 2011 and later, Paul D. Nation and Robert J. Johansson.
4#    All rights reserved.
5#
6#    Redistribution and use in source and binary forms, with or without
7#    modification, are permitted provided that the following conditions are
8#    met:
9#
10#    1. Redistributions of source code must retain the above copyright notice,
11#       this list of conditions and the following disclaimer.
12#
13#    2. Redistributions in binary form must reproduce the above copyright
14#       notice, this list of conditions and the following disclaimer in the
15#       documentation and/or other materials provided with the distribution.
16#
17#    3. Neither the name of the QuTiP: Quantum Toolbox in Python nor the names
18#       of its contributors may be used to endorse or promote products derived
19#       from this software without specific prior written permission.
20#
21#    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22#    "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23#    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
24#    PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
25#    HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
26#    SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
27#    LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
28#    DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
29#    THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
30#    (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
31#    OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32###############################################################################
33
34import pytest
35import collections
36import functools
37import numpy as np
38import qutip
39
40# We want to test the broadcasting rules for `qutip.expect` for a whole bunch
41# of different systems, without having to repeatedly specify the systems over
42# and over again.  We first store a small number of test cases for known
# expectation values in the most bundled-up form, because it's easier to unroll
44# these by applying the expected broadcasting rules explicitly ourselves than
45# performing the inverse operation.
46#
47# We store a single test case in a record type, just to keep things neatly
48# together while we're munging them, so it's clear at all times what
49# constitutes a valid test case.
50
# Record for one test case: the operator(s), the state(s), and the reference
# expectation value(s) that `qutip.expect` should produce for them.
_Case = collections.namedtuple('_Case', 'operator state expected')
52
53
def _case_to_dm(case):
    """
    Return a copy of `case` whose states are in density-matrix form.

    Every element of `case.state` is replaced by the result of its `proj()`
    method; the operator and expected fields are carried over unchanged.
    """
    projected = []
    for ket in case.state:
        projected.append(ket.proj())
    return case._replace(state=projected)
56
57
def _unwrap(list_):
    """
    Descend through the first element of nested lists, and return the first
    object found that is not itself a list.
    """
    current = list_
    while True:
        if not isinstance(current, list):
            return current
        current = current[0]
64
65
def _case_id(case):
    """
    Build the pytest id string for `case`, in the form '<system>-<state>'.

    The system part is 'qubit' if the first operator has a leading dimension
    of 2 and 'basis' otherwise.  The state part is 'ket' if the first state
    has a leading column dimension of 1 and 'dm' otherwise.
    """
    if _unwrap(case.operator).dims[0][0] == 2:
        op_part = 'qubit'
    else:
        op_part = 'basis'
    if _unwrap(case.state).dims[1][0] == 1:
        state_part = 'ket'
    else:
        state_part = 'dm'
    return "-".join([op_part, state_part])
70
71
72# This is the minimal set of test cases, with a Fock system and a qubit system
73# both in ket form and dm form.  The reference expectations are a 2D array
74# which would be found by broadcasting `operator` against `state` and applying
75# `qutip.expect` to the pairs.
# Size of the Fock space used for the number-state reference case.
_dim = 5
_num = qutip.num(_dim)
_a = qutip.destroy(_dim)
_sx = qutip.sigmax()
_sz = qutip.sigmaz()
_sp = qutip.sigmap()

# In each reference array the rows follow the operators and the columns follow
# the states.
_known_fock = _Case(
    operator=[_num, _a],
    state=[qutip.fock(_dim, n) for n in range(_dim)],
    expected=np.array([np.arange(_dim), np.zeros(_dim)]),
)
_known_qubit = _Case(
    operator=[_sx, _sz, _sp],
    state=[qutip.basis(2, 0), qutip.basis(2, 1)],
    expected=np.array([[0, 0], [1, -1], [0, 0]]),
)

# Each system appears twice: once with ket states and once with the same
# states converted to density matrices.
_known_cases = [
    _known_fock, _case_to_dm(_known_fock),
    _known_qubit, _case_to_dm(_known_qubit),
]
87
88
class TestKnownExpectation:
    """
    Compare `qutip.expect` with the reference values stored in `_known_cases`
    for single and list forms of both the operator and the state argument.
    """

    def pytest_generate_tests(self, metafunc):
        """
        Parametrise each test in this class over the known cases.

        A test requests an unrolled argument by naming its parameter in the
        singular (`operator`, `state`), and the bundled list by naming it in
        the plural (`operators`, `states`).  For every singular parameter the
        broadcasting rule is applied here by hand: the case is split into one
        case per element, each paired with the matching slice of the
        reference array.
        """
        requested = metafunc.fixturenames
        cases = _known_cases

        if 'operator' in requested:
            op_name = 'operator'
            # Each row of the reference array belongs to one operator.
            per_operator = []
            for case in cases:
                for op, row in zip(case.operator, case.expected):
                    per_operator.append(_Case(op, case.state, row))
            cases = per_operator
        else:
            op_name = 'operators'

        if 'state' in requested:
            state_name = 'state'
            # States run along the last axis of the reference, so transpose
            # before pairing.  If the operators were unrolled above, the
            # reference is already 1D and the transpose leaves it unchanged.
            per_state = []
            for case in cases:
                for state, column in zip(case.state, case.expected.T):
                    per_state.append(_Case(case.operator, state, column))
            cases = per_state
        else:
            state_name = 'states'

        case_ids = [_case_id(case) for case in cases]
        metafunc.parametrize([op_name, state_name, 'expected'], cases,
                             ids=case_ids)

    def test_operator_by_basis(self, operator, state, expected):
        """A single operator and a single state give a Python scalar."""
        value = qutip.expect(operator, state)
        assert value == expected
        # Hermitian operators should produce a real scalar.
        scalar_type = float if operator.isherm else complex
        assert isinstance(value, scalar_type)

    def test_broadcast_operator_list(self, operators, state, expected):
        """A list of operators and a single state give one array."""
        values = qutip.expect(operators, state)
        # The array is only real if every operator in the list is Hermitian.
        if all(op.isherm for op in operators):
            wanted_dtype = np.float64
        else:
            wanted_dtype = np.complex128
        assert isinstance(values, np.ndarray)
        assert values.dtype == wanted_dtype
        assert list(values) == list(expected)

    def test_broadcast_state_list(self, operator, states, expected):
        """A single operator and a list of states give one array."""
        values = qutip.expect(operator, states)
        if operator.isherm:
            wanted_dtype = np.float64
        else:
            wanted_dtype = np.complex128
        assert isinstance(values, np.ndarray)
        assert values.dtype == wanted_dtype
        assert list(values) == list(expected)

    def test_broadcast_both_lists(self, operators, states, expected):
        """Lists of operators and of states give one array per operator."""
        per_operator = qutip.expect(operators, states)
        assert len(per_operator) == len(operators)
        triples = zip(per_operator, operators, expected)
        for values, operator, reference in triples:
            # The dtype is decided separately for each operator.
            if operator.isherm:
                wanted_dtype = np.float64
            else:
                wanted_dtype = np.complex128
            assert isinstance(values, np.ndarray)
            assert values.dtype == wanted_dtype
            assert list(values) == list(reference)
147
148
@pytest.mark.repeat(20)
@pytest.mark.parametrize("hermitian", [False, True], ids=['complex', 'real'])
def test_equivalent_to_matrix_element(hermitian):
    """
    `qutip.expect(op, ket)` should agree with the explicitly computed matrix
    element `ket.dag() * op * ket` for a random operator and a random ket.
    """
    dimension = 20
    ket = qutip.rand_ket(dimension, 0.3)
    operator = qutip.rand_herm(dimension, 0.2)
    if not hermitian:
        # Adding i times a second random Hermitian operator gives a
        # non-Hermitian test operator.
        imaginary_part = 1j*qutip.rand_herm(dimension, 0.1)
        operator = operator + imaginary_part
    sandwich = ket.dag() * operator * ket
    reference = sandwich.data[0, 0]
    difference = qutip.expect(operator, ket) - reference
    assert abs(difference) < 1e-14
159
160
@pytest.mark.parametrize("solve", [
    pytest.param(qutip.sesolve, id="sesolve"),
    pytest.param(functools.partial(qutip.mesolve, c_ops=[qutip.qzero(2)]),
                 id="mesolve"),
])
def test_compatibility_with_solver(solve):
    """
    The expectation values a solver returns for `e_ops` should match those
    found by applying `qutip.expect` to the states stored by the same run.
    """
    e_ops = []
    for suffix in 'xyzmp':
        constructor = getattr(qutip, 'sigma'+suffix)
        e_ops.append(constructor())
    hamiltonian = qutip.sigmax()
    initial = qutip.basis(2, 0)
    times = np.linspace(0, 10, 101)
    # The states must be stored so that they can be passed to `qutip.expect`.
    options = qutip.Options(store_states=True)
    result = solve(hamiltonian, initial, times, e_ops=e_ops, options=options)
    from_solver = result.expect
    from_states = qutip.expect(e_ops, result.states)
    assert len(from_solver) == len(from_states)
    for solver_values, state_values in zip(from_solver, from_states):
        assert len(solver_values) == len(state_values)
        assert isinstance(solver_values, np.ndarray)
        assert isinstance(state_values, np.ndarray)
        assert solver_values.dtype == state_values.dtype
        np.testing.assert_allclose(solver_values, state_values, atol=1e-12)
182