1#!python
2#cython: language_level=3
3# This file is part of QuTiP: Quantum Toolbox in Python.
4#
5#    Copyright (c) 2011 and later, The QuTiP Project.
6#    All rights reserved.
7#
8#    Redistribution and use in source and binary forms, with or without
9#    modification, are permitted provided that the following conditions are
10#    met:
11#
12#    1. Redistributions of source code must retain the above copyright notice,
13#       this list of conditions and the following disclaimer.
14#
15#    2. Redistributions in binary form must reproduce the above copyright
16#       notice, this list of conditions and the following disclaimer in the
17#       documentation and/or other materials provided with the distribution.
18#
19#    3. Neither the name of the QuTiP: Quantum Toolbox in Python nor the names
20#       of its contributors may be used to endorse or promote products derived
21#       from this software without specific prior written permission.
22#
23#    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
24#    "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
25#    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
26#    PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
27#    HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
28#    SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
29#    LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
30#    DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
31#    THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
32#    (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
33#    OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
34###############################################################################
35import numpy as np
36cimport numpy as cnp
37cimport cython
38from libcpp cimport bool
39from libc.string cimport memset
40
41cdef extern from "<complex>" namespace "std" nogil:
42    double complex conj(double complex x)
43    double         real(double complex)
44    double         imag(double complex)
45    double         abs(double complex)
46
47include "sparse_routines.pxi"
48
@cython.boundscheck(False)
@cython.wraparound(False)
def zcsr_add(complex[::1] dataA, int[::1] indsA, int[::1] indptrA,
             complex[::1] dataB, int[::1] indsB, int[::1] indptrB,
             int nrows, int ncols,
             int Annz, int Bnnz,
             double complex alpha = 1):

    """
    Adds two sparse CSR matrices, returning ``A + alpha*B``.

    Like SciPy, the output is allocated for the worst-case fill
    ``A.nnz + B.nnz`` and shortened afterwards if fewer entries
    were actually written.

    Parameters
    ----------
    dataA, indsA, indptrA : CSR data, indices and indptr arrays of A.
    dataB, indsB, indptrB : CSR data, indices and indptr arrays of B.
    nrows, ncols : int
        Shape shared by both operands and the result.
    Annz, Bnnz : int
        Number of stored entries in A and B.
    alpha : complex, optional
        Scalar factor applied to B (default 1).

    Notes
    -----
    If only B is empty the arrays of A are handed to the result
    constructor as-is; if only A is empty the result is built from
    ``alpha * dataB``.
    """
    cdef int worse_fill = Annz + Bnnz
    cdef int nnz
    cdef CSR_Matrix out

    # Both matrices are zero mats.  Build the empty result from the shape
    # alone (the same constructor form zcsr_mult uses) rather than from
    # three empty lists, whose indptr would have no entries at all.
    if Annz == 0 and Bnnz == 0:
        return fast_csr_matrix(shape=(nrows,ncols))
    # A is the zero matrix
    elif Annz == 0:
        return fast_csr_matrix((alpha*np.asarray(dataB), indsB, indptrB),
                            shape=(nrows,ncols))
    # B is the zero matrix
    elif Bnnz == 0:
        return fast_csr_matrix((dataA, indsA, indptrA),
                            shape=(nrows,ncols))

    # Both operands hold entries: merge them row by row into `out`.
    init_CSR(&out, worse_fill, nrows, ncols, worse_fill)

    nnz = _zcsr_add_core(&dataA[0], &indsA[0], &indptrA[0],
                     &dataB[0], &indsB[0], &indptrB[0],
                     alpha,
                     &out,
                     nrows, ncols)
    # Shorten data and indices if fewer entries than allocated were written
    if out.nnz > nnz:
        shorten_CSR(&out, nnz)
    return CSR_to_scipy(&out)
87
88
@cython.boundscheck(False)
@cython.wraparound(False)
cdef void _zcsr_add(CSR_Matrix * A, CSR_Matrix * B, CSR_Matrix * C, double complex alpha):
    """
    C-level sum ``C = A + alpha*B`` of two CSR matrices.

    C is initialised here for the worst-case fill ``A.nnz + B.nnz``
    (the same assumption SciPy makes) and then trimmed to the number
    of entries the merge actually produced.
    """
    cdef int n_rows = A.nrows
    cdef int n_cols = A.ncols
    cdef int max_fill = A.nnz + B.nnz
    cdef int written

    init_CSR(C, max_fill, n_rows, n_cols, max_fill)

    written = _zcsr_add_core(A.data, A.indices, A.indptr,
                             B.data, B.indices, B.indptr,
                             alpha, C, n_rows, n_cols)

    # Trim the output arrays when the merge used fewer slots than allocated
    if written < C.nnz:
        shorten_CSR(C, written)
108
109
110
@cython.boundscheck(False)
@cython.wraparound(False)
cdef int _zcsr_add_core(double complex * Adata, int * Aind, int * Aptr,
                     double complex * Bdata, int * Bind, int * Bptr,
                     double complex alpha,
                     CSR_Matrix * C,
                     int nrows, int ncols) nogil:
    """
    Row-by-row merge computing ``A + alpha*B`` into the arrays of C.

    Each row of A and B is walked with one cursor per operand, always
    emitting the entry with the smaller column index.  When both cursors
    sit on the same column the two values are summed and the entry is
    kept only if the sum is non-zero.  The merge assumes the column
    indices within each row are in increasing order.

    Returns the number of entries written to C.
    """
    cdef int colA, colB, out_pos = 0
    cdef int posA, posB, endA, endB
    # Column value used for an exhausted row; larger than any valid index
    cdef int past_end = ncols + 1
    cdef size_t row
    cdef double complex total

    C.indptr[0] = 0
    if alpha != 1:
        for row in range(nrows):
            posA = Aptr[row]
            endA = Aptr[row+1]
            posB = Bptr[row]
            endB = Bptr[row+1]
            while (posA < endA) or (posB < endB):
                colA = Aind[posA] if posA < endA else past_end
                colB = Bind[posB] if posB < endB else past_end

                if colA < colB:
                    # Entry present only in A
                    C.data[out_pos] = Adata[posA]
                    C.indices[out_pos] = colA
                    out_pos += 1
                    posA += 1
                elif colB < colA:
                    # Entry present only in B
                    C.data[out_pos] = alpha*Bdata[posB]
                    C.indices[out_pos] = colB
                    out_pos += 1
                    posB += 1
                else:
                    # Same column in both: drop exact cancellations
                    total = Adata[posA] + alpha*Bdata[posB]
                    if total != 0:
                        C.data[out_pos] = total
                        C.indices[out_pos] = colA
                        out_pos += 1
                    posA += 1
                    posB += 1

            C.indptr[row+1] = out_pos
    else:
        # alpha == 1: identical merge without the scalar multiplication
        for row in range(nrows):
            posA = Aptr[row]
            endA = Aptr[row+1]
            posB = Bptr[row]
            endB = Bptr[row+1]
            while (posA < endA) or (posB < endB):
                colA = Aind[posA] if posA < endA else past_end
                colB = Bind[posB] if posB < endB else past_end

                if colA < colB:
                    C.data[out_pos] = Adata[posA]
                    C.indices[out_pos] = colA
                    out_pos += 1
                    posA += 1
                elif colB < colA:
                    C.data[out_pos] = Bdata[posB]
                    C.indices[out_pos] = colB
                    out_pos += 1
                    posB += 1
                else:
                    total = Adata[posA] + Bdata[posB]
                    if total != 0:
                        C.data[out_pos] = total
                        C.indices[out_pos] = colA
                        out_pos += 1
                    posA += 1
                    posB += 1

            C.indptr[row+1] = out_pos
    return out_pos
199
200
201
@cython.boundscheck(False)
@cython.wraparound(False)
def zcsr_mult(object A, object B, int sorted = 1):
    """
    Matrix product ``A * B`` of two complex sparse CSR matrices.

    The product is computed in two passes: the first counts the entries
    of the result, the second fills an output allocated to that size.

    Parameters
    ----------
    A, B : CSR matrix objects
        Operands exposing ``data``, ``indices``, ``indptr``, ``nnz`` and
        ``shape``.
    sorted : int, optional
        If non-zero (default) the column indices of the result are
        sorted before returning.

    Returns
    -------
    The product as returned by ``CSR_to_scipy``, or an empty matrix of
    shape ``(A.shape[0], B.shape[1])`` if either operand or the product
    has no stored entries.

    Raises
    ------
    ValueError
        If both operands hold entries and ``A.shape[1] != B.shape[0]``.
    """
    cdef complex [::1] dataA = A.data
    cdef int[::1] indsA = A.indices
    cdef int[::1] indptrA = A.indptr
    cdef int Annz = A.nnz

    cdef complex [::1] dataB = B.data
    cdef int[::1] indsB = B.indices
    cdef int[::1] indptrB = B.indptr
    cdef int Bnnz = B.nnz

    cdef int nrows = A.shape[0]
    cdef int ncols = B.shape[1]

    # Either operand is the zero matrix
    if Annz == 0 or Bnnz == 0:
        return fast_csr_matrix(shape=(nrows,ncols))

    # Column indices of A are used to index B.indptr with bounds checking
    # disabled, so a mismatched inner dimension must be rejected up front.
    if A.shape[1] != B.shape[0]:
        raise ValueError("dimension mismatch: A.shape[1] != B.shape[0]")

    cdef int nnz
    cdef CSR_Matrix out

    nnz = _zcsr_mult_pass1(&dataA[0], &indsA[0], &indptrA[0],
                     &dataB[0], &indsB[0], &indptrB[0],
                     nrows, ncols)

    if nnz == 0:
        return fast_csr_matrix(shape=(nrows,ncols))

    init_CSR(&out, nnz, nrows, ncols)
    _zcsr_mult_pass2(&dataA[0], &indsA[0], &indptrA[0],
                     &dataB[0], &indsB[0], &indptrB[0],
                     &out,
                     nrows, ncols)

    # Pass 2 drops entries that sum to exactly zero, so the output may
    # hold fewer entries than pass 1 counted.
    if out.nnz > out.indptr[out.nrows]:
        shorten_CSR(&out, out.indptr[out.nrows])

    if sorted:
        sort_indices(&out)
    return CSR_to_scipy(&out)
246
247
248
@cython.boundscheck(False)
@cython.wraparound(False)
cdef void _zcsr_mult(CSR_Matrix * A, CSR_Matrix * B, CSR_Matrix * C):
    """
    C-level matrix product ``C = A * B`` of two CSR matrices.

    C is initialised here with the entry count from the first pass,
    filled by the second pass, trimmed if entries cancelled to zero,
    and finally has its column indices sorted.
    """
    # Typed explicitly, matching the declaration used in zcsr_mult
    cdef int nnz

    nnz = _zcsr_mult_pass1(A.data, A.indices, A.indptr,
                 B.data, B.indices, B.indptr,
                 A.nrows, B.ncols)

    init_CSR(C, nnz, A.nrows, B.ncols)
    _zcsr_mult_pass2(A.data, A.indices, A.indptr,
                 B.data, B.indices, B.indptr,
                 C,
                 A.nrows, B.ncols)

    # Shorten data and indices if pass 2 wrote fewer entries than counted
    if C.nnz > C.indptr[C.nrows]:
        shorten_CSR(C, C.indptr[C.nrows])
    sort_indices(C)
267
268
@cython.boundscheck(False)
@cython.wraparound(False)
cdef int _zcsr_mult_pass1(double complex * Adata, int * Aind, int * Aptr,
                     double complex * Bdata, int * Bind, int * Bptr,
                     int nrows, int ncols) nogil:
    """
    First pass of the sparse product: count the entries of ``A * B``.

    Only the sparsity structure is inspected; the data arrays are not
    read.  A scratch array records, for each output column, the last
    output row in which it was seen, so each (row, column) pair is
    counted once.
    """
    cdef int inner, col, count = 0
    cdef size_t row, posA, posB, idx
    # Scratch array: last_row[c] is the most recent row that touched column c
    cdef int * last_row = <int *>PyDataMem_NEW(ncols*sizeof(int))

    for idx in range(ncols):
        last_row[idx] = -1

    for row in range(nrows):
        for posA in range(Aptr[row], Aptr[row+1]):
            inner = Aind[posA]
            for posB in range(Bptr[inner], Bptr[inner+1]):
                col = Bind[posB]
                if last_row[col] != row:
                    # First time this column appears in the current row
                    last_row[col] = row
                    count += 1

    PyDataMem_FREE(last_row)
    return count
292
293
@cython.boundscheck(False)
@cython.wraparound(False)
cdef void _zcsr_mult_pass2(double complex * Adata, int * Aind, int * Aptr,
                     double complex * Bdata, int * Bind, int * Bptr,
                     CSR_Matrix * C,
                     int nrows, int ncols) nogil:
    """
    Second pass of the sparse product: fill the arrays of C.

    For every output row the partial products are accumulated per
    column in a dense scratch array, while the columns touched are
    chained together through a second scratch array.  The chain is
    then walked to emit the non-zero sums and to reset both scratch
    arrays for the next row.  Columns are emitted in chain order, so
    the indices within a row are not sorted.
    """
    cdef int head, pending, done, inner, col, filled = 0
    cdef size_t row, posA, posB, idx
    cdef double complex factor
    cdef double complex * acc = <double complex *>PyDataMem_NEW(ncols * sizeof(double complex))
    cdef int * link = <int *>PyDataMem_NEW(ncols*sizeof(int))

    memset(acc, 0, ncols * sizeof(double complex))
    for idx in range(ncols):
        link[idx] = -1          # -1 marks "column not in the chain"

    C.indptr[0] = 0
    for row in range(nrows):
        head = -2               # chain terminator, distinct from the -1 marker
        pending = 0
        for posA in range(Aptr[row], Aptr[row+1]):
            inner = Aind[posA]
            factor = Adata[posA]
            for posB in range(Bptr[inner], Bptr[inner+1]):
                col = Bind[posB]
                acc[col] += factor*Bdata[posB]
                if link[col] == -1:
                    # Push a newly touched column on the front of the chain
                    link[col] = head
                    head = col
                    pending += 1

        # Drain the chain: emit non-zero sums, then clear the scratch slots
        while pending > 0:
            if acc[head] != 0:
                C.indices[filled] = head
                C.data[filled] = acc[head]
                filled += 1
            done = head
            head = link[done]
            link[done] = -1
            acc[done] = 0
            pending -= 1

        C.indptr[row+1] = filled

    # Free temp arrays
    PyDataMem_FREE(acc)
    PyDataMem_FREE(link)
341
342
@cython.boundscheck(False)
@cython.wraparound(False)
def zcsr_kron(object A, object B):
    """
    Computes the kronecker product between two complex
    sparse matrices in CSR format.

    Parameters
    ----------
    A, B : CSR matrix objects
        Operands exposing ``data``, ``indices``, ``indptr`` and ``shape``.

    Returns
    -------
    The product of shape ``(A.shape[0]*B.shape[0], A.shape[1]*B.shape[1])``
    as returned by ``CSR_to_scipy``.

    Raises
    ------
    OverflowError
        If the number of stored entries, rows or columns of the result
        does not fit in a C int.
    """
    cdef complex[::1] dataA = A.data
    cdef int[::1] indsA = A.indices
    cdef int[::1] indptrA = A.indptr
    cdef int rowsA = A.shape[0]
    cdef int colsA = A.shape[1]

    cdef complex[::1] dataB = B.data
    cdef int[::1] indsB = B.indices
    cdef int[::1] indptrB = B.indptr
    cdef int rowsB = B.shape[0]
    cdef int colsB = B.shape[1]

    # Every size of the result is a product of two ints; check all three,
    # since the dimensions can overflow even when the entry count does not.
    cdef int out_nnz = _safe_multiply(dataA.shape[0], dataB.shape[0])
    cdef int rows_out = _safe_multiply(rowsA, rowsB)
    cdef int cols_out = _safe_multiply(colsA, colsB)

    cdef CSR_Matrix out
    init_CSR(&out, out_nnz, rows_out, cols_out)

    _zcsr_kron_core(&dataA[0], &indsA[0], &indptrA[0],
                    &dataB[0], &indsB[0], &indptrB[0],
                    &out,
                    rowsA, rowsB, colsB)
    return CSR_to_scipy(&out)
374
375
@cython.boundscheck(False)
@cython.wraparound(False)
cdef void _zcsr_kron(CSR_Matrix * A, CSR_Matrix * B, CSR_Matrix * C):
    """
    Computes the kronecker product between two complex
    sparse matrices in CSR format, writing the result to C.

    C is initialised here with ``A.nnz * B.nnz`` entries and shape
    ``(A.nrows * B.nrows, A.ncols * B.ncols)``.  All three products go
    through the overflow-checked multiply.
    """

    # The dimensions can overflow a C int even when the entry count fits,
    # so they are checked in the same way as the entry count.
    cdef int out_nnz = _safe_multiply(A.nnz, B.nnz)
    cdef int rows_out = _safe_multiply(A.nrows, B.nrows)
    cdef int cols_out = _safe_multiply(A.ncols, B.ncols)

    init_CSR(C, out_nnz, rows_out, cols_out)

    _zcsr_kron_core(A.data, A.indices, A.indptr,
                    B.data, B.indices, B.indptr,
                    C,
                    A.nrows, B.nrows, B.ncols)
394
395
@cython.boundscheck(False)
@cython.wraparound(False)
cdef void _zcsr_kron_core(double complex * dataA, int * indsA, int * indptrA,
                     double complex * dataB, int * indsB, int * indptrB,
                     CSR_Matrix * out,
                     int rowsA, int rowsB, int colsB) nogil:
    """
    Fill `out` with the kronecker product of the CSR operands A and B.

    Output row ``ii*rowsB + jj`` is formed from row ii of A and row jj
    of B: for each entry of the A row, the whole B row is copied,
    scaled by that entry and with its columns offset by
    ``indsA * colsB``.  The start of each output row is read from
    ``out.indptr``, so ``out.indptr[0]`` must already be set by the
    caller's initialisation (presumably zero -- TODO confirm against
    init_CSR).
    """
    cdef size_t ii, jj, posA
    cdef int out_row = 0
    cdef int dest, posB
    cdef int startA, stopA, startB, stopB, lenA, lenB

    for ii in range(rowsA):
        startA = indptrA[ii]
        stopA = indptrA[ii+1]
        lenA = stopA - startA

        for jj in range(rowsB):
            startB = indptrB[jj]
            stopB = indptrB[jj+1]
            lenB = stopB - startB

            # Output row holds one scaled copy of the B row per A entry
            dest = out.indptr[out_row]
            out.indptr[out_row+1] = dest + lenA * lenB
            out_row += 1

            for posA in range(startA, stopA):
                for posB in range(startB, stopB):
                    out.indices[dest] = indsA[posA] * colsB + indsB[posB]
                    out.data[dest] = dataA[posA] * dataB[posB]
                    dest += 1
432
433
@cython.boundscheck(False)
@cython.wraparound(False)
def zcsr_transpose(object A):
    """
    Return the transpose of a complex sparse matrix in CSR format.

    The input must expose ``data``, ``indices``, ``indptr`` and
    ``shape``; the result has the two dimensions swapped and the same
    number of stored entries.
    """
    cdef complex[::1] vals = A.data
    cdef int[::1] cols = A.indices
    cdef int[::1] rowptr = A.indptr
    cdef int n_rows = A.shape[0]
    cdef int n_cols = A.shape[1]
    cdef CSR_Matrix result

    # Output has swapped dimensions and as many entries as the input
    init_CSR(&result, vals.shape[0], n_cols, n_rows)
    _zcsr_trans_core(&vals[0], &cols[0], &rowptr[0],
                     &result, n_rows, n_cols)
    return CSR_to_scipy(&result)
452
453
@cython.boundscheck(False)
@cython.wraparound(False)
cdef void _zcsr_transpose(CSR_Matrix * A, CSR_Matrix * B):
    """
    C-level transpose: initialise B with the swapped shape of A and
    fill it with the transposed entries.
    """
    init_CSR(B, A.nnz, A.ncols, A.nrows)
    _zcsr_trans_core(A.data, A.indices, A.indptr,
                     B,
                     A.nrows, A.ncols)
463
@cython.boundscheck(False)
@cython.wraparound(False)
cdef void _zcsr_trans_core(double complex * data, int * ind, int * ptr,
                     CSR_Matrix * out,
                     int nrows, int ncols) nogil:
    """
    Write the transpose of the CSR arrays (data, ind, ptr) into `out`.

    Steps: count the entries of every input column into
    ``out.indptr``, turn the counts into offsets with a running sum,
    scatter the entries to their output rows (advancing each offset as
    it is used), then shift ``out.indptr`` back by one slot so it
    again points to the start of each row.  The counting step
    increments ``out.indptr`` in place, so it relies on the caller's
    initialisation having zeroed it (TODO confirm against init_CSR).
    """
    cdef int col, dest, prev, cur
    cdef size_t row, pos, idx

    # Count entries per input column (stored one slot to the right)
    for row in range(nrows):
        for pos in range(ptr[row], ptr[row+1]):
            out.indptr[ind[pos] + 1] += 1

    # Running sum turns counts into starting offsets
    for idx in range(ncols):
        out.indptr[idx+1] += out.indptr[idx]

    # Scatter every entry to the next free slot of its output row
    for row in range(nrows):
        for pos in range(ptr[row], ptr[row+1]):
            col = ind[pos]
            dest = out.indptr[col]
            out.indptr[col] = dest + 1
            out.data[dest] = data[pos]
            out.indices[dest] = row

    # Each offset now marks the end of its row; shift right by one slot
    prev = 0
    for idx in range(ncols + 1):
        cur = out.indptr[idx]
        out.indptr[idx] = prev
        prev = cur
493
494
495
@cython.boundscheck(False)
@cython.wraparound(False)
def zcsr_adjoint(object A):
    """
    Return the adjoint (conjugate transpose) of a complex sparse
    matrix in CSR format.

    The input must expose ``data``, ``indices``, ``indptr`` and
    ``shape``; the result has the two dimensions swapped and the same
    number of stored entries.
    """
    cdef complex[::1] vals = A.data
    cdef int[::1] cols = A.indices
    cdef int[::1] rowptr = A.indptr
    cdef int n_rows = A.shape[0]
    cdef int n_cols = A.shape[1]
    cdef CSR_Matrix result

    # Output has swapped dimensions and as many entries as the input
    init_CSR(&result, vals.shape[0], n_cols, n_rows)
    _zcsr_adjoint_core(&vals[0], &cols[0], &rowptr[0],
                       &result, n_rows, n_cols)
    return CSR_to_scipy(&result)
514
515
@cython.boundscheck(False)
@cython.wraparound(False)
cdef void _zcsr_adjoint(CSR_Matrix * A, CSR_Matrix * B):
    """
    C-level adjoint: initialise B with the swapped shape of A and
    fill it with the conjugate-transposed entries.
    """
    init_CSR(B, A.nnz, A.ncols, A.nrows)
    _zcsr_adjoint_core(A.data, A.indices, A.indptr,
                       B,
                       A.nrows, A.ncols)
526
527
@cython.boundscheck(False)
@cython.wraparound(False)
cdef void _zcsr_adjoint_core(double complex * data, int * ind, int * ptr,
                     CSR_Matrix * out,
                     int nrows, int ncols) nogil:
    """
    Write the adjoint of the CSR arrays (data, ind, ptr) into `out`.

    Same procedure as the plain transpose, except that every value is
    complex-conjugated as it is scattered: count the entries of each
    input column into ``out.indptr``, form offsets with a running sum,
    scatter the entries, then shift ``out.indptr`` back by one slot.
    The counting step increments ``out.indptr`` in place, so it relies
    on the caller's initialisation having zeroed it (TODO confirm
    against init_CSR).
    """
    cdef int col, dest, prev, cur
    cdef size_t row, pos, idx

    # Count entries per input column (stored one slot to the right)
    for row in range(nrows):
        for pos in range(ptr[row], ptr[row+1]):
            out.indptr[ind[pos] + 1] += 1

    # Running sum turns counts into starting offsets
    for idx in range(ncols):
        out.indptr[idx+1] += out.indptr[idx]

    # Scatter the conjugated entries to the next free slot of their row
    for row in range(nrows):
        for pos in range(ptr[row], ptr[row+1]):
            col = ind[pos]
            dest = out.indptr[col]
            out.indptr[col] = dest + 1
            out.data[dest] = conj(data[pos])
            out.indices[dest] = row

    # Each offset now marks the end of its row; shift right by one slot
    prev = 0
    for idx in range(ncols + 1):
        cur = out.indptr[idx]
        out.indptr[idx] = prev
        prev = cur
557
558
@cython.boundscheck(False)
@cython.wraparound(False)
def zcsr_isherm(object A not None, double tol = 1e-12):
    """
    Determines if a given input sparse CSR matrix is Hermitian
    to within a specified floating-point tolerance.

    Parameters
    ----------
    A : csr_matrix
        Input sparse matrix.
    tol : float, optional
        Absolute tolerance on the difference between an entry and the
        conjugate of its mirror entry (default 1e-12).

    Returns
    -------
    isherm : int
        One if matrix is Hermitian, zero otherwise.  A non-square
        matrix always gives zero.

    Raises
    ------
    MemoryError
        If the temporary index-pointer array cannot be allocated.

    Notes
    -----
    This implementation is essentially an adjoint calculation
    where the data and indices are not stored, but checked
    elementwise to see if they match those of the input matrix.
    Thus we do not need to build the actual adjoint.  Here we
    only need a temp array of output indptr.

    The stored structure must match as well: a stored entry whose
    mirror position is not stored makes the test fail.
    """
    cdef complex[::1] data = A.data
    cdef int[::1] ind = A.indices
    cdef int[::1] ptr = A.indptr
    cdef int nrows = A.shape[0]
    cdef int ncols = A.shape[1]

    cdef int k, nxt, isherm = 1
    cdef size_t ii, jj
    cdef complex tmp, tmp2

    if nrows != ncols:
        return 0

    cdef int * out_ptr = <int *>PyDataMem_NEW( (ncols+1) * sizeof(int))
    # A failed allocation must not be written to
    if out_ptr == NULL:
        raise MemoryError()

    memset(out_ptr, 0, (ncols+1) * sizeof(int))

    # Count entries per column, stored one slot to the right
    for ii in range(nrows):
        for jj in range(ptr[ii], ptr[ii+1]):
            k = ind[jj] + 1
            out_ptr[k] += 1

    # Running sum: out_ptr[c] becomes the adjoint's offset for row c
    for ii in range(nrows):
        out_ptr[ii+1] += out_ptr[ii]

    for ii in range(nrows):
        for jj in range(ptr[ii], ptr[ii+1]):
            k = ind[jj]
            # Slot this entry would occupy in the adjoint
            nxt = out_ptr[k]
            out_ptr[k] += 1
            # structure test: the input must hold column ii at that slot
            if ind[nxt] != ii:
                isherm = 0
                break
            tmp = conj(data[jj])
            tmp2 = data[nxt]
            # data test
            if abs(tmp-tmp2) > tol:
                isherm = 0
                break
        # Stop scanning rows as soon as one mismatch has been found
        if not isherm:
            break

    PyDataMem_FREE(out_ptr)
    return isherm
632
@cython.overflowcheck(True)
cdef _safe_multiply(int A, int B):
    """
    Return the product A*B of two C ints.

    The multiplication is compiled with Cython's ``overflowcheck``
    directive enabled, so a product that does not fit in a C int
    raises instead of silently wrapping.
    """
    cdef int product = A*B
    return product
640
641
642
@cython.boundscheck(False)
@cython.wraparound(False)
def zcsr_trace(object A, bool isherm):
    """
    Trace of a complex sparse matrix in CSR format.

    For each row the first stored entry on the diagonal is added to
    the running sum.  The real part alone is returned when `isherm` is
    true or the imaginary part of the sum is exactly zero; otherwise
    the complex sum is returned.
    """
    cdef complex[::1] vals = A.data
    cdef int[::1] cols = A.indices
    cdef int[::1] rowptr = A.indptr
    cdef int n_rows = rowptr.shape[0]-1
    cdef size_t row, pos
    cdef complex total = 0

    for row in range(n_rows):
        for pos in range(rowptr[row], rowptr[row+1]):
            if cols[pos] == row:
                # Diagonal entry found; move on to the next row
                total += vals[pos]
                break

    if isherm or imag(total) == 0:
        return real(total)
    return total
662
663
@cython.boundscheck(False)
@cython.wraparound(False)
def zcsr_proj(object A, bool is_ket=1):
    """
    Computes the projection operator
    from a given ket or bra vector
    in CSR format.  The flag 'is_ket'
    is True if passed a ket.

    This is ~3x faster than doing the
    conjugate transpose and sparse multiplication
    directly.  Also, does not need a temp matrix.

    The output holds ``nnz**2`` entries, where nnz is the number of
    stored entries of the input vector.
    """
    cdef complex[::1] data = A.data
    cdef int[::1] ind = A.indices
    cdef int[::1] ptr = A.indptr
    cdef int dim
    cdef int nnz
    cdef int nnz_sq

    cdef int slot = 0, row_idx, running, target, pos
    cdef size_t jj, kk

    if is_ket:
        dim = A.shape[0]
        nnz = ptr[dim]
    else:
        dim = A.shape[1]
        nnz = ptr[1]
    nnz_sq = nnz*nnz

    cdef CSR_Matrix out
    init_CSR(&out, nnz_sq, dim)

    if is_ket:
        # Row pointers scale with nnz; every non-empty input row
        # contributes its row number as a column of each output row.
        for jj in range(dim):
            out.indptr[jj] = ptr[jj]*nnz
            if ptr[jj+1] != ptr[jj]:
                row_idx = jj
                for kk in range(nnz):
                    out.indices[slot+kk*nnz] = row_idx
                slot += 1
        # Final pointer is the total number of entries
        out.indptr[dim] = nnz_sq

        # Outer product of the stored values with their conjugates
        for jj in range(nnz):
            for kk in range(nnz):
                out.data[jj*nnz+kk] = data[jj]*conj(data[kk])

    else:
        # Every output row repeats the bra's column indices
        for kk in range(nnz):
            for jj in range(nnz):
                out.indices[kk*nnz+jj] = ind[jj]
                out.data[kk*nnz+jj] = conj(data[kk])*data[jj]

        # Walk the stored entries from last to first, filling the row
        # pointers downwards from the final row.
        running = nnz_sq
        row_idx = dim
        pos = nnz - 1
        while pos >= 0:
            target = ind[pos]
            while row_idx > target:
                out.indptr[row_idx] = running
                row_idx -= 1
            running -= nnz
            pos -= 1

    return CSR_to_scipy(&out)
729
730
731
@cython.boundscheck(False)
@cython.wraparound(False)
def zcsr_inner(object A, object B, bool bra_ket):
    """
    Computes the inner-product <A|B> between ket-ket,
    or bra-ket vectors in sparse CSR format.

    Parameters
    ----------
    A : CSR matrix object
        Left vector: a bra if `bra_ket` is true, otherwise a ket.
    B : CSR matrix object
        Right vector, a ket.
    bra_ket : bool
        True if A is a bra (its values are used as stored), False if A
        is a ket (its values are conjugated).

    Returns
    -------
    inner : complex
        The inner product.

    Notes
    -----
    Only the first stored entry of each row of a ket is used.
    """
    cdef complex[::1] a_data = A.data
    cdef int[::1] a_ind = A.indices
    cdef int[::1] a_ptr = A.indptr

    cdef complex[::1] b_data = B.data
    cdef int[::1] b_ind = B.indices
    cdef int[::1] b_ptr = B.indptr
    cdef int nrows = B.shape[0]

    cdef double complex inner = 0
    cdef size_t kk
    cdef int a_idx, b_idx

    if bra_ket:
        # The column index of each bra entry names the ket row it pairs
        # with, so that row is looked up directly instead of scanning all
        # rows of B for every entry.  Indices outside B are skipped, as a
        # scan over the rows of B would never match them.
        for kk in range(a_ind.shape[0]):
            a_idx = a_ind[kk]
            if 0 <= a_idx < nrows:
                b_idx = b_ptr[a_idx]
                if (b_ptr[a_idx+1]-b_idx) != 0:
                    inner += a_data[kk]*b_data[b_idx]
    else:
        # Ket-ket: pair the rows that are non-empty in both vectors
        for kk in range(nrows):
            a_idx = a_ptr[kk]
            b_idx = b_ptr[kk]
            if (a_ptr[kk+1]-a_idx) != 0:
                if (b_ptr[kk+1]-b_idx) != 0:
                    inner += conj(a_data[a_idx])*b_data[b_idx]

    return inner
769