1# This file is part of QuTiP: Quantum Toolbox in Python.
2#
3#    Copyright (c) 2011 and later, The QuTiP Project.
4#    All rights reserved.
5#
6#    Redistribution and use in source and binary forms, with or without
7#    modification, are permitted provided that the following conditions are
8#    met:
9#
10#    1. Redistributions of source code must retain the above copyright notice,
11#       this list of conditions and the following disclaimer.
12#
13#    2. Redistributions in binary form must reproduce the above copyright
14#       notice, this list of conditions and the following disclaimer in the
15#       documentation and/or other materials provided with the distribution.
16#
17#    3. Neither the name of the QuTiP: Quantum Toolbox in Python nor the names
18#       of its contributors may be used to endorse or promote products derived
19#       from this software without specific prior written permission.
20#
21#    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22#    "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23#    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
24#    PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
25#    HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
26#    SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
27#    LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
28#    DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
29#    THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
30#    (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
31#    OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32###############################################################################
33import numpy as np
34from numpy.testing import (run_module_suite, assert_,
35                        assert_equal, assert_almost_equal)
36import scipy.sparse as sp
37
38from qutip.fastsparse import fast_csr_matrix, fast_identity
39from qutip.random_objects import (rand_dm, rand_herm,
40                                  rand_ket, rand_unitary)
41from qutip.cy.spmath import (zcsr_kron, zcsr_transpose, zcsr_adjoint,
42                            zcsr_isherm)
43
44
def test_csr_kron():
    "spmath: zcsr_kron"
    # Pairs of (left-factor, right-factor) random-object generators.  Each
    # generator is called with a random dimension and a density of 0.5, and
    # the Kronecker product from zcsr_kron is checked against scipy's.
    generator_pairs = (
        (rand_herm, rand_herm),
        (rand_ket, rand_herm),
        (rand_dm, rand_herm),
        (rand_ket, rand_ket),
    )
    trials = 5
    for make_left, make_right in generator_pairs:
        for _ in range(trials):
            dim_left = np.random.randint(2, 100)
            dim_right = np.random.randint(2, 100)
            left = make_left(dim_left, 0.5).data
            right = make_right(dim_right, 0.5).data
            # Reference result computed by scipy on the tocsr(1) conversions.
            expected = sp.kron(left.tocsr(1), right.tocsr(1), format='csr')
            actual = zcsr_kron(left, right)
            # Values may differ by rounding; the sparsity structure must not.
            assert_almost_equal(expected.data, actual.data)
            assert_equal(expected.indices, actual.indices)
            assert_equal(expected.indptr, actual.indptr)
99
100
def test_zcsr_transpose():
    "spmath: zcsr_transpose"
    # Each factory builds a qutip data matrix from a random dimension ``ra``.
    # The Hermitian case uses a fixed 5x5 size with density 1/ra.
    factories = (
        lambda ra: rand_ket(ra, 0.5).data,
        lambda ra: rand_herm(5, 1.0/ra).data,
        lambda ra: rand_dm(ra, 1.0/ra).data,
        lambda ra: rand_unitary(ra, 1.0/ra).data,
    )
    for factory in factories:
        for _ in range(50):
            ra = np.random.randint(2, 100)
            A = factory(ra)
            # Transpose via ``.T`` (converted to CSR) is compared with the
            # result of ``trans()``.
            B = A.T.tocsr()
            C = A.trans()
            # assert_equal also checks array shapes.  A bare
            # np.all(a == b) would broadcast its operands, so a length-1
            # array could compare equal to one of a different length
            # (including an empty one), and it reports nothing on failure.
            assert_equal(C.data, B.data)
            assert_equal(C.indices, B.indices)
            assert_equal(C.indptr, B.indptr)
142
143
def test_zcsr_adjoint():
    "spmath: zcsr_adjoint"
    # Each factory builds a qutip data matrix from a random dimension ``ra``.
    # The Hermitian case uses a fixed 5x5 size with density 1/ra.
    factories = (
        lambda ra: rand_ket(ra, 0.5).data,
        lambda ra: rand_herm(5, 1.0/ra).data,
        lambda ra: rand_dm(ra, 1.0/ra).data,
        lambda ra: rand_unitary(ra, 1.0/ra).data,
    )
    for factory in factories:
        for _ in range(50):
            ra = np.random.randint(2, 100)
            A = factory(ra)
            # Conjugate transpose via ``conj().T`` (converted to CSR) is
            # compared with the result of ``adjoint()``.
            B = A.conj().T.tocsr()
            C = A.adjoint()
            # assert_equal also checks array shapes.  A bare
            # np.all(a == b) would broadcast its operands, so a length-1
            # array could compare equal to one of a different length
            # (including an empty one), and it reports nothing on failure.
            assert_equal(C.data, B.data)
            assert_equal(C.indices, B.indices)
            assert_equal(C.indptr, B.indptr)
185
186
def test_zcsr_mult():
    "spmath: zcsr_mult"

    def check_product(left, right):
        # ``left * right`` on qutip data matrices is compared with the same
        # product of their ``tocsr(1)`` conversions (indices sorted).  The
        # sparsity structure must match exactly, including array lengths.
        # Values are compared approximately, as in test_csr_kron, because
        # two implementations may accumulate each sum in a different order.
        ans1 = left * right
        ans2 = left.tocsr(1) * right.tocsr(1)
        ans2.sort_indices()
        assert_equal(ans1.indptr, ans2.indptr)
        assert_equal(ans1.indices, ans2.indices)
        assert_almost_equal(ans1.data, ans2.data)

    for _ in range(50):
        A = rand_ket(10, 0.5).data
        B = rand_herm(10, 0.5).data
        check_product(B, A)

    for _ in range(50):
        A = rand_ket(10, 0.5).data
        B = rand_ket(10, 0.5).dag().data
        # B is the adjoint (``dag()``) of a random ket, so check both
        # product orders.
        check_product(B, A)
        check_product(A, B)

    for _ in range(50):
        A = rand_dm(10, 0.5).data
        B = rand_dm(10, 0.5).data
        check_product(B, A)

    for _ in range(50):
        A = rand_dm(10, 0.5).data
        B = rand_herm(10, 0.5).data
        check_product(B, A)
256
257
def test_zcsr_isherm():
    "spmath: zcsr_isherm"
    N = 100
    for _ in range(100):
        # A random Hermitian matrix must be reported as Hermitian.
        A = rand_herm(N, 0.1)
        # H1 + 1j*H2 with H1, H2 Hermitian is Hermitian only if H2 == 0,
        # which a random H2 of this density (almost surely) is not.
        B = rand_herm(N, 0.05) + 1j*rand_herm(N, 0.05)
        assert_(zcsr_isherm(A.data))
        # Test the predicate's truthiness rather than comparing it to 0.
        assert_(not zcsr_isherm(B.data))
266
267
def test_zcsr_isherm_compare_implicit_zero():
    """
    Regression test for gh-1350, comparing explicitly stored values in the
    matrix (but below the tolerance for allowable Hermiticity) to implicit
    zeros.
    """
    # Absolute tolerance passed to every zcsr_isherm call below.
    tol = 1e-12
    n = 10

    # Explicitly stored off-diagonal value far below ``tol`` at (0, 1), with
    # an implicit zero at the mirrored position (1, 0).
    base = sp.csr_matrix(np.array([[1, tol * 1e-3j], [0, 1]]))
    base = fast_csr_matrix((base.data, base.indices, base.indptr), base.shape)
    # If this first assertion fails, the zero has been stored explicitly and
    # so the test is invalid.
    assert base.data.size == 3
    assert zcsr_isherm(base, tol=tol)
    # Check the transpose too, so both orientations of the asymmetric
    # sparsity pattern are exercised.
    assert zcsr_isherm(base.T, tol=tol)

    # The same asymmetric sparsity pattern, but with an off-diagonal value
    # well above the tolerance, so the matrix is not Hermitian.
    base = sp.csr_matrix(np.array([[1, 1j], [0, 1]]))
    base = fast_csr_matrix((base.data, base.indices, base.indptr), base.shape)
    assert base.data.size == 3
    assert not zcsr_isherm(base, tol=tol)
    assert not zcsr_isherm(base.T, tol=tol)

    # Catch a possible edge case where the matrix shouldn't be Hermitian, but
    # faulty loop logic doesn't fully compare all rows: the only stored
    # element sits at (2, 1), with empty rows before and after it.
    base = sp.csr_matrix(np.array([
        [0, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 0, 0],
    ], dtype=np.complex128))
    base = fast_csr_matrix((base.data, base.indices, base.indptr), base.shape)
    assert base.data.size == 1
    assert not zcsr_isherm(base, tol=tol)
    assert not zcsr_isherm(base.T, tol=tol)

    # Pure diagonal matrix with random real entries: Hermitian, while the
    # same matrix multiplied by 1j (purely imaginary diagonal) is not.
    base = fast_identity(n)
    base.data *= np.random.rand(n)
    assert zcsr_isherm(base, tol=tol)
    assert not zcsr_isherm(base * 1j, tol=tol)

    # Larger matrices where all off-diagonal elements are below the absolute
    # tolerance, so everything should always appear Hermitian, but with random
    # patterns of non-zero elements.  It doesn't matter that the matrix would
    # not be Hermitian if scaled up; everything is below the absolute
    # tolerance, so it should appear Hermitian.  We also set the diagonal to
    # be larger than the tolerance to ensure isherm can't just compare
    # everything to zero.
    for density in np.linspace(0.2, 1, 21):
        base = tol * 1e-2 * (np.random.rand(n, n) + 1j*np.random.rand(n, n))
        # Mask some values out to zero.
        base[np.random.rand(n, n) > density] = 0
        np.fill_diagonal(base, tol * 1000)
        nnz = np.count_nonzero(base)
        base = sp.csr_matrix(base)
        base = fast_csr_matrix((base.data, base.indices, base.indptr), (n, n))
        # Every non-zero must be stored, with no explicit zeros, for the
        # implicit-zero comparison to be exercised.
        assert base.data.size == nnz
        assert zcsr_isherm(base, tol=tol)
        assert zcsr_isherm(base.T, tol=tol)

        # Similar test when it must be non-Hermitian.  We set the diagonal to
        # be real because we want to test off-diagonal implicit zeros, and
        # having an imaginary first element would automatically fail.
        nnz = 0
        # Resample until at least one off-diagonal element survives masking
        # (the diagonal alone accounts for n non-zeros).
        while nnz <= n:
            # Ensure that we don't just have the real diagonal.
            base = tol * 1000j*np.random.rand(n, n)
            # Mask some values out to zero.
            base[np.random.rand(n, n) > density] = 0
            np.fill_diagonal(base, tol * 1000)
            nnz = np.count_nonzero(base)
        base = sp.csr_matrix(base)
        base = fast_csr_matrix((base.data, base.indices, base.indptr), (n, n))
        assert base.data.size == nnz
        assert not zcsr_isherm(base, tol=tol)
        assert not zcsr_isherm(base.T, tol=tol)
345
346
if __name__ == "__main__":
    # Run every test in this module when the file is executed directly.
    # NOTE(review): run_module_suite is NumPy's nose-based runner and is
    # deprecated in recent NumPy releases; confirm the project's preferred
    # test runner (e.g. pytest) before relying on this entry point.
    run_module_suite()
349