#!python
#cython: language_level=3
# This file is part of QuTiP: Quantum Toolbox in Python.
#
# Copyright (c) 2011 and later, The QuTiP Project.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# 1. Redistributions of source code must retain the above copyright notice,
#    this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#
# 3. Neither the name of the QuTiP: Quantum Toolbox in Python nor the names
#    of its contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
###############################################################################
import numpy as np
cimport numpy as cnp
cimport cython
from libcpp cimport bool
from libc.string cimport memset

cdef extern from "<complex>" namespace "std" nogil:
    double complex conj(double complex x)
    double real(double complex)
    double imag(double complex)
    double abs(double complex)

# NOTE(review): CSR_Matrix, init_CSR, shorten_CSR, sort_indices,
# CSR_to_scipy, fast_csr_matrix and PyDataMem_NEW/PyDataMem_FREE are not
# defined in this file; presumably they come from this include -- confirm.
include "sparse_routines.pxi"

@cython.boundscheck(False)
@cython.wraparound(False)
def zcsr_add(complex[::1] dataA, int[::1] indsA, int[::1] indptrA,
             complex[::1] dataB, int[::1] indsB, int[::1] indptrB,
             int nrows, int ncols,
             int Annz, int Bnnz,
             double complex alpha = 1):
    """
    Adds two sparse CSR matrices, returning ``A + alpha*B``.

    Like SciPy, the output is first allocated for the worst-case fill
    ``Annz + Bnnz`` and shortened afterwards if fewer entries survive.

    Parameters
    ----------
    dataA, indsA, indptrA : CSR data / column-index / row-pointer arrays of A.
    dataB, indsB, indptrB : CSR data / column-index / row-pointer arrays of B.
    nrows, ncols : int
        Shape shared by A, B and the result.
    Annz, Bnnz : int
        Number of stored entries in A and B.
    alpha : complex, optional
        Scale factor applied to B (default 1).

    Returns
    -------
    Sparse CSR matrix holding ``A + alpha*B``.
    """
    cdef int worse_fill = Annz + Bnnz
    cdef int nnz
    # Both matrices are zero matrices
    if Annz == 0 and Bnnz == 0:
        return fast_csr_matrix(([], [], []), shape=(nrows, ncols))
    # A is the zero matrix: the result is just alpha*B
    elif Annz == 0:
        return fast_csr_matrix((alpha*np.asarray(dataB), indsB, indptrB),
                               shape=(nrows, ncols))
    # B is the zero matrix: the result is A (alpha is irrelevant)
    elif Bnnz == 0:
        return fast_csr_matrix((dataA, indsA, indptrA),
                               shape=(nrows, ncols))
    # Output CSR_Matrix, sized for the worst case
    cdef CSR_Matrix out
    init_CSR(&out, worse_fill, nrows, ncols, worse_fill)

    nnz = _zcsr_add_core(&dataA[0], &indsA[0], &indptrA[0],
                         &dataB[0], &indsB[0], &indptrB[0],
                         alpha,
                         &out,
                         nrows, ncols)
    # Shorten data and indices if some entries cancelled / overlapped
    if out.nnz > nnz:
        shorten_CSR(&out, nnz)
    return CSR_to_scipy(&out)


@cython.boundscheck(False)
@cython.wraparound(False)
cdef void _zcsr_add(CSR_Matrix * A, CSR_Matrix * B, CSR_Matrix * C, double complex alpha):
    """
    Adds two sparse CSR matrices, writing ``A + alpha*B`` into C.

    Like SciPy, C is first allocated for the worst-case fill
    ``A.nnz + B.nnz`` and shortened afterwards if fewer entries survive.
    A and B are assumed to have the same shape (not checked here).
    """
    cdef int worse_fill = A.nnz + B.nnz
    cdef int nrows = A.nrows
    cdef int ncols = A.ncols
    cdef int nnz
    init_CSR(C, worse_fill, nrows, ncols, worse_fill)

    nnz = _zcsr_add_core(A.data, A.indices, A.indptr,
                         B.data, B.indices, B.indptr,
                         alpha, C, nrows, ncols)
    # Shorten data and indices if needed
    if C.nnz > nnz:
        shorten_CSR(C, nnz)


@cython.boundscheck(False)
@cython.wraparound(False)
cdef int _zcsr_add_core(double complex * Adata, int * Aind, int * Aptr,
                        double complex * Bdata, int * Bind, int * Bptr,
                        double complex alpha,
                        CSR_Matrix * C,
                        int nrows, int ncols) nogil:
    """
    Merge kernel behind zcsr_add / _zcsr_add: writes ``A + alpha*B`` row
    by row into the preallocated C and returns the number of entries
    written (also stored in C.indptr[nrows]).

    Where both operands store an entry at the same column, the sum is kept
    only if it is non-zero. Entries present in just one operand are copied
    unconditionally. The merge walks both rows in column order, so it
    assumes column indices within each row of A and B are sorted.
    C must have room for Annz + Bnnz entries (the callers allocate that).
    """
    cdef int j1, j2, kc = 0
    cdef int ka, kb, ka_max, kb_max
    cdef size_t ii
    cdef double complex tmp
    C.indptr[0] = 0
    if alpha != 1:
        # General case: scale every B entry by alpha.
        for ii in range(nrows):
            ka = Aptr[ii]
            kb = Bptr[ii]
            ka_max = Aptr[ii+1]-1
            kb_max = Bptr[ii+1]-1
            while (ka <= ka_max) or (kb <= kb_max):
                # An exhausted row gets the sentinel column ncols+1, which
                # compares greater than any real column index.
                if ka <= ka_max:
                    j1 = Aind[ka]
                else:
                    j1 = ncols+1

                if kb <= kb_max:
                    j2 = Bind[kb]
                else:
                    j2 = ncols+1

                if j1 == j2:
                    tmp = Adata[ka] + alpha*Bdata[kb]
                    if tmp != 0:
                        C.data[kc] = tmp
                        C.indices[kc] = j1
                        kc += 1
                    ka += 1
                    kb += 1
                elif j1 < j2:
                    C.data[kc] = Adata[ka]
                    C.indices[kc] = j1
                    ka += 1
                    kc += 1
                elif j1 > j2:
                    C.data[kc] = alpha*Bdata[kb]
                    C.indices[kc] = j2
                    kb += 1
                    kc += 1

            C.indptr[ii+1] = kc
    else:
        # alpha == 1: same merge, B entries added without scaling.
        for ii in range(nrows):
            ka = Aptr[ii]
            kb = Bptr[ii]
            ka_max = Aptr[ii+1]-1
            kb_max = Bptr[ii+1]-1
            while (ka <= ka_max) or (kb <= kb_max):
                if ka <= ka_max:
                    j1 = Aind[ka]
                else:
                    j1 = ncols+1

                if kb <= kb_max:
                    j2 = Bind[kb]
                else:
                    j2 = ncols+1

                if j1 == j2:
                    tmp = Adata[ka] + Bdata[kb]
                    if tmp != 0:
                        C.data[kc] = tmp
                        C.indices[kc] = j1
                        kc += 1
                    ka += 1
                    kb += 1
                elif j1 < j2:
                    C.data[kc] = Adata[ka]
                    C.indices[kc] = j1
                    ka += 1
                    kc += 1
                elif j1 > j2:
                    C.data[kc] = Bdata[kb]
                    C.indices[kc] = j2
                    kb += 1
                    kc += 1

            C.indptr[ii+1] = kc
    return kc


@cython.boundscheck(False)
@cython.wraparound(False)
def zcsr_mult(object A, object B, int sorted = 1):
    """
    Computes the product ``A * B`` of two complex sparse CSR matrices.

    Two passes: ``_zcsr_mult_pass1`` counts an upper bound on the number
    of stored entries, then ``_zcsr_mult_pass2`` fills in the values and
    drops entries whose accumulated sum is exactly zero.

    Parameters
    ----------
    A, B : csr_matrix
        Operands. A.shape[1] is assumed to equal B.shape[0]; this is not
        checked here.
    sorted : int, optional
        If non-zero (default), sort the column indices of the result;
        pass 2 emits them unsorted.

    Returns
    -------
    Sparse CSR matrix of shape (A.shape[0], B.shape[1]).
    """
    cdef complex [::1] dataA = A.data
    cdef int[::1] indsA = A.indices
    cdef int[::1] indptrA = A.indptr
    cdef int Annz = A.nnz

    cdef complex [::1] dataB = B.data
    cdef int[::1] indsB = B.indices
    cdef int[::1] indptrB = B.indptr
    cdef int Bnnz = B.nnz

    cdef int nrows = A.shape[0]
    cdef int ncols = B.shape[1]

    # Either operand is a zero matrix, so the product is too
    if Annz == 0 or Bnnz == 0:
        return fast_csr_matrix(shape=(nrows, ncols))

    cdef int nnz
    cdef CSR_Matrix out

    nnz = _zcsr_mult_pass1(&dataA[0], &indsA[0], &indptrA[0],
                           &dataB[0], &indsB[0], &indptrB[0],
                           nrows, ncols)

    if nnz == 0:
        return fast_csr_matrix(shape=(nrows, ncols))

    init_CSR(&out, nnz, nrows, ncols)
    _zcsr_mult_pass2(&dataA[0], &indsA[0], &indptrA[0],
                     &dataB[0], &indsB[0], &indptrB[0],
                     &out,
                     nrows, ncols)

    # Shorten data and indices if pass 2 dropped exact zeros
    if out.nnz > out.indptr[out.nrows]:
        shorten_CSR(&out, out.indptr[out.nrows])

    if sorted:
        sort_indices(&out)
    return CSR_to_scipy(&out)


@cython.boundscheck(False)
@cython.wraparound(False)
cdef void _zcsr_mult(CSR_Matrix * A, CSR_Matrix * B, CSR_Matrix * C):
    """
    Computes ``A * B`` into C using the same two passes as zcsr_mult.
    The result always has its column indices sorted.
    """
    # Typed as a C int; left undeclared it would be a Python object.
    cdef int nnz

    nnz = _zcsr_mult_pass1(A.data, A.indices, A.indptr,
                           B.data, B.indices, B.indptr,
                           A.nrows, B.ncols)

    init_CSR(C, nnz, A.nrows, B.ncols)
    _zcsr_mult_pass2(A.data, A.indices, A.indptr,
                     B.data, B.indices, B.indptr,
                     C,
                     A.nrows, B.ncols)

    # Shorten data and indices if needed
    if C.nnz > C.indptr[C.nrows]:
        shorten_CSR(C, C.indptr[C.nrows])
    sort_indices(C)


@cython.boundscheck(False)
@cython.wraparound(False)
cdef int _zcsr_mult_pass1(double complex * Adata, int * Aind, int * Aptr,
                          double complex * Bdata, int * Bind, int * Bptr,
                          int nrows, int ncols) nogil:
    """
    Symbolic pass of the sparse product: counts, row by row, the distinct
    output columns that A*B touches. The count is an upper bound on the
    stored entries of the product, because pass 2 may drop exact zeros.
    Adata and Bdata are not read.
    """
    cdef int j, k, nnz = 0
    cdef size_t ii,jj,kk
    # mask[k] records the last output row that touched column k.
    # NOTE(review): PyDataMem_NEW is not checked for NULL here.
    cdef int * mask = <int *>PyDataMem_NEW(ncols*sizeof(int))
    for ii in range(ncols):
        mask[ii] = -1
    # Pass 1
    for ii in range(nrows):
        for jj in range(Aptr[ii], Aptr[ii+1]):
            j = Aind[jj]
            for kk in range(Bptr[j], Bptr[j+1]):
                k = Bind[kk]
                if mask[k] != ii:
                    mask[k] = ii
                    nnz += 1
    PyDataMem_FREE(mask)
    return nnz


@cython.boundscheck(False)
@cython.wraparound(False)
cdef void _zcsr_mult_pass2(double complex * Adata, int * Aind, int * Aptr,
                           double complex * Bdata, int * Bind, int * Bptr,
                           CSR_Matrix * C,
                           int nrows, int ncols) nogil:
    """
    Numeric pass of the sparse product: fills C (preallocated with the
    pass-1 count) with A*B.

    Each output row is accumulated in the dense scratch array ``sums``.
    The touched columns are threaded through a linked list in ``nxt``:
    -1 means "not in the list" and -2 terminates it. Columns are emitted
    in reverse insertion order, so they are unsorted. Entries whose sum is
    exactly zero are dropped.
    """
    cdef int head, length, temp, j, k, nnz = 0
    cdef size_t ii,jj,kk
    cdef double complex val
    # NOTE(review): PyDataMem_NEW results are not checked for NULL here.
    cdef double complex * sums = <double complex *>PyDataMem_NEW(ncols * sizeof(double complex))
    cdef int * nxt = <int *>PyDataMem_NEW(ncols*sizeof(int))

    memset(&sums[0], 0, ncols * sizeof(double complex))
    for ii in range(ncols):
        nxt[ii] = -1

    C.indptr[0] = 0
    for ii in range(nrows):
        head = -2
        length = 0
        for jj in range(Aptr[ii], Aptr[ii+1]):
            j = Aind[jj]
            val = Adata[jj]
            for kk in range(Bptr[j], Bptr[j+1]):
                k = Bind[kk]
                sums[k] += val*Bdata[kk]
                if nxt[k] == -1:
                    # First touch of column k in this row: push onto list
                    nxt[k] = head
                    head = k
                    length += 1

        for jj in range(length):
            if sums[head] != 0:
                C.indices[nnz] = head
                C.data[nnz] = sums[head]
                nnz += 1
            # Pop the list head and reset scratch state for the next row
            temp = head
            head = nxt[head]
            nxt[temp] = -1
            sums[temp] = 0

        C.indptr[ii+1] = nnz

    # Free temp arrays
    PyDataMem_FREE(sums)
    PyDataMem_FREE(nxt)


@cython.boundscheck(False)
@cython.wraparound(False)
def zcsr_kron(object A, object B):
    """
    Computes the Kronecker product between two complex
    sparse matrices in CSR format.

    The output nnz and shape are computed with overflow-checked
    multiplication, so an oversized result raises OverflowError instead of
    wrapping around.
    """
    cdef complex[::1] dataA = A.data
    cdef int[::1] indsA = A.indices
    cdef int[::1] indptrA = A.indptr
    cdef int rowsA = A.shape[0]
    cdef int colsA = A.shape[1]

    cdef complex[::1] dataB = B.data
    cdef int[::1] indsB = B.indices
    cdef int[::1] indptrB = B.indptr
    cdef int rowsB = B.shape[0]
    cdef int colsB = B.shape[1]

    cdef int out_nnz = _safe_multiply(dataA.shape[0], dataB.shape[0])
    cdef int rows_out = _safe_multiply(rowsA, rowsB)
    cdef int cols_out = _safe_multiply(colsA, colsB)

    cdef CSR_Matrix out
    init_CSR(&out, out_nnz, rows_out, cols_out)

    _zcsr_kron_core(&dataA[0], &indsA[0], &indptrA[0],
                    &dataB[0], &indsB[0], &indptrB[0],
                    &out,
                    rowsA, rowsB, colsB)
    return CSR_to_scipy(&out)


@cython.boundscheck(False)
@cython.wraparound(False)
cdef void _zcsr_kron(CSR_Matrix * A, CSR_Matrix * B, CSR_Matrix * C):
    """
    Computes the Kronecker product between two complex
    sparse matrices in CSR format, writing it into C.

    Size products use the overflow-checked _safe_multiply, as zcsr_kron does.
    """

    cdef int out_nnz = _safe_multiply(A.nnz, B.nnz)
    cdef int rows_out = _safe_multiply(A.nrows, B.nrows)
    cdef int cols_out = _safe_multiply(A.ncols, B.ncols)

    init_CSR(C, out_nnz, rows_out, cols_out)

    _zcsr_kron_core(A.data, A.indices, A.indptr,
                    B.data, B.indices, B.indptr,
                    C,
                    A.nrows, B.nrows, B.ncols)


@cython.boundscheck(False)
@cython.wraparound(False)
cdef void _zcsr_kron_core(double complex * dataA, int * indsA, int * indptrA,
                          double complex * dataB, int * indsB, int * indptrB,
                          CSR_Matrix * out,
                          int rowsA, int rowsB, int colsB) nogil:
    """
    Fills ``out`` with kron(A, B).

    Output row ``ii*rowsB + jj`` holds, for each stored entry of A's row ii
    (in order), a copy of B's row jj. Column indices are offset by
    ``indsA*colsB`` and values are multiplied by the A entry. If A's and
    B's rows are sorted, so are the output rows.
    Assumes out.indptr[0] is 0 on entry (presumably set by init_CSR --
    confirm).
    """
    cdef size_t ii, jj, ptrA, ptr
    cdef int row = 0
    cdef int ptr_start, ptr_end
    cdef int row_startA, row_endA, row_startB, row_endB, distA, distB, ptrB

    for ii in range(rowsA):
        row_startA = indptrA[ii]
        row_endA = indptrA[ii+1]
        distA = row_endA - row_startA

        for jj in range(rowsB):
            row_startB = indptrB[jj]
            row_endB = indptrB[jj+1]
            distB = row_endB - row_startB

            # [ptr_start, ptr_end) is the slot for one copy of B's row jj
            ptr_start = out.indptr[row]
            ptr_end = ptr_start + distB

            out.indptr[row+1] = out.indptr[row] + distA * distB
            row += 1

            for ptrA in range(row_startA, row_endA):
                ptrB = row_startB
                for ptr in range(ptr_start, ptr_end):
                    out.indices[ptr] = indsA[ptrA] * colsB + indsB[ptrB]
                    out.data[ptr] = dataA[ptrA] * dataB[ptrB]
                    ptrB += 1

                ptr_start += distB
                ptr_end += distB


@cython.boundscheck(False)
@cython.wraparound(False)
def zcsr_transpose(object A):
    """
    Transpose of a sparse matrix in CSR format.
    The result has sorted column indices.
    """
    cdef complex[::1] data = A.data
    cdef int[::1] ind = A.indices
    cdef int[::1] ptr = A.indptr
    cdef int nrows = A.shape[0]
    cdef int ncols = A.shape[1]

    cdef CSR_Matrix out
    init_CSR(&out, data.shape[0], ncols, nrows)

    _zcsr_trans_core(&data[0], &ind[0], &ptr[0],
                     &out, nrows, ncols)
    return CSR_to_scipy(&out)


@cython.boundscheck(False)
@cython.wraparound(False)
cdef void _zcsr_transpose(CSR_Matrix * A, CSR_Matrix * B):
    """
    Transpose of a sparse matrix in CSR format, written into B.
    """
    init_CSR(B, A.nnz, A.ncols, A.nrows)

    _zcsr_trans_core(A.data, A.indices, A.indptr, B, A.nrows, A.ncols)

@cython.boundscheck(False)
@cython.wraparound(False)
cdef void _zcsr_trans_core(double complex * data, int * ind, int * ptr,
                           CSR_Matrix * out,
                           int nrows, int ncols) nogil:
    """
    Counting-sort transpose of an (nrows x ncols) CSR matrix into ``out``.

    out.indptr is accumulated with ``+=``, so it must be zero on entry
    (presumably guaranteed by init_CSR -- confirm). Input rows are visited
    in increasing order, so each output row gets sorted column indices.
    """
    cdef int k, nxt
    cdef size_t ii, jj

    # 1) Count entries per input column, stored one slot to the right.
    for ii in range(nrows):
        for jj in range(ptr[ii], ptr[ii+1]):
            k = ind[jj] + 1
            out.indptr[k] += 1

    # 2) Prefix sum: out.indptr[k] = start of output row k.
    for ii in range(ncols):
        out.indptr[ii+1] += out.indptr[ii]

    # 3) Scatter, using out.indptr[k] as the write cursor for row k.
    for ii in range(nrows):
        for jj in range(ptr[ii], ptr[ii+1]):
            k = ind[jj]
            nxt = out.indptr[k]
            out.data[nxt] = data[jj]
            out.indices[nxt] = ii
            out.indptr[k] = nxt + 1

    # 4) The cursors now hold row ends; shift right to restore row starts.
    for ii in range(ncols,0,-1):
        out.indptr[ii] = out.indptr[ii-1]

    out.indptr[0] = 0


@cython.boundscheck(False)
@cython.wraparound(False)
def zcsr_adjoint(object A):
    """
    Adjoint (conjugate transpose) of a sparse matrix in CSR format.
    The result has sorted column indices.
    """
    cdef complex[::1] data = A.data
    cdef int[::1] ind = A.indices
    cdef int[::1] ptr = A.indptr
    cdef int nrows = A.shape[0]
    cdef int ncols = A.shape[1]

    cdef CSR_Matrix out
    init_CSR(&out, data.shape[0], ncols, nrows)

    _zcsr_adjoint_core(&data[0], &ind[0], &ptr[0],
                       &out, nrows, ncols)
    return CSR_to_scipy(&out)


@cython.boundscheck(False)
@cython.wraparound(False)
cdef void _zcsr_adjoint(CSR_Matrix * A, CSR_Matrix * B):
    """
    Adjoint (conjugate transpose) of a sparse matrix in CSR format,
    written into B.
    """
    init_CSR(B, A.nnz, A.ncols, A.nrows)

    _zcsr_adjoint_core(A.data, A.indices, A.indptr,
                       B, A.nrows, A.ncols)


@cython.boundscheck(False)
@cython.wraparound(False)
cdef void _zcsr_adjoint_core(double complex * data, int * ind, int * ptr,
                             CSR_Matrix * out,
                             int nrows, int ncols) nogil:
    """
    Same counting-sort algorithm as _zcsr_trans_core, except each value is
    conjugated. out.indptr must be zero on entry (presumably guaranteed by
    init_CSR -- confirm).
    """
    cdef int k, nxt
    cdef size_t ii, jj

    # Count entries per input column, stored one slot to the right.
    for ii in range(nrows):
        for jj in range(ptr[ii], ptr[ii+1]):
            k = ind[jj] + 1
            out.indptr[k] += 1

    # Prefix sum: out.indptr[k] = start of output row k.
    for ii in range(ncols):
        out.indptr[ii+1] += out.indptr[ii]

    # Scatter conjugated values, using out.indptr[k] as a write cursor.
    for ii in range(nrows):
        for jj in range(ptr[ii], ptr[ii+1]):
            k = ind[jj]
            nxt = out.indptr[k]
            out.data[nxt] = conj(data[jj])
            out.indices[nxt] = ii
            out.indptr[k] = nxt + 1

    # Shift the cursors (row ends) back into row starts.
    for ii in range(ncols,0,-1):
        out.indptr[ii] = out.indptr[ii-1]

    out.indptr[0] = 0


@cython.boundscheck(False)
@cython.wraparound(False)
def zcsr_isherm(object A not None, double tol = 1e-12):
    """
    Determines if a given input sparse CSR matrix is Hermitian
    to within a specified floating-point tolerance.

    Parameters
    ----------
    A : csr_matrix
        Input sparse matrix.
    tol : float (default 1e-12)
        Maximum allowed ``|conj(A[i, j]) - A[j, i]|``.

    Returns
    -------
    isherm : int
        One if matrix is Hermitian, zero otherwise.

    Raises
    ------
    MemoryError
        If the temporary index array cannot be allocated.

    Notes
    -----
    This implementation is essentially an adjoint calculation
    where the data and indices are not stored, but checked
    elementwise to see if they match those of the input matrix.
    Thus we do not need to build the actual adjoint. Here we
    only need a temp array of output indptr.
    Because the adjoint built this way has sorted indices, the check
    presumably expects A to have sorted indices as well -- confirm.
    """
    cdef complex[::1] data = A.data
    cdef int[::1] ind = A.indices
    cdef int[::1] ptr = A.indptr
    cdef int nrows = A.shape[0]
    cdef int ncols = A.shape[1]

    cdef int k, nxt, isherm = 1
    cdef size_t ii, jj
    cdef complex tmp, tmp2

    # A non-square matrix cannot be Hermitian.
    if nrows != ncols:
        return 0

    # out_ptr plays the role of the adjoint's indptr (ncols + 1 entries).
    cdef int * out_ptr = <int *>PyDataMem_NEW( (ncols+1) * sizeof(int))
    if out_ptr == NULL:
        raise MemoryError()

    memset(&out_ptr[0], 0, (ncols+1) * sizeof(int))

    # Count stored entries per column (shifted by one for the prefix sum).
    for ii in range(nrows):
        for jj in range(ptr[ii], ptr[ii+1]):
            k = ind[jj] + 1
            out_ptr[k] += 1

    # Prefix sum: out_ptr[k] = start of column k in the adjoint.
    for ii in range(nrows):
        out_ptr[ii+1] += out_ptr[ii]

    for ii in range(nrows):
        for jj in range(ptr[ii], ptr[ii+1]):
            k = ind[jj]
            # Slot the adjoint would use for entry (ii, k); compare it with
            # what A itself stores at that slot.
            nxt = out_ptr[k]
            out_ptr[k] += 1
            # structure test
            if ind[nxt] != ii:
                isherm = 0
                break
            tmp = conj(data[jj])
            tmp2 = data[nxt]
            # data test
            if abs(tmp-tmp2) > tol:
                isherm = 0
                break
        else:
            continue
        # Only reached when the inner loop hit `break`: stop scanning.
        break

    PyDataMem_FREE(out_ptr)
    return isherm

@cython.overflowcheck(True)
cdef _safe_multiply(int A, int B):
    """
    Computes A*B and checks for overflow.

    With cython.overflowcheck enabled, an int overflow raises OverflowError
    instead of silently wrapping.
    """
    cdef int C = A*B
    return C


@cython.boundscheck(False)
@cython.wraparound(False)
def zcsr_trace(object A, bool isherm):
    """
    Trace of a complex sparse CSR matrix.

    In each row, only the first stored entry whose column equals the row
    index is added.

    Parameters
    ----------
    A : csr_matrix
        Input matrix. The row count is taken from ``len(A.indptr) - 1``.
    isherm : bool
        If true, the imaginary part is discarded.

    Returns
    -------
    float
        ``real(tr)``, when the imaginary part is exactly zero or `isherm`
        is true.
    complex
        The full trace, otherwise.
    """
    cdef complex[::1] data = A.data
    cdef int[::1] ind = A.indices
    cdef int[::1] ptr = A.indptr
    cdef int nrows = ptr.shape[0]-1
    cdef size_t ii, jj
    cdef complex tr = 0

    for ii in range(nrows):
        for jj in range(ptr[ii], ptr[ii+1]):
            if ind[jj] == ii:
                tr += data[jj]
                break
    if imag(tr) == 0 or isherm:
        return real(tr)
    else:
        return tr


@cython.boundscheck(False)
@cython.wraparound(False)
def zcsr_proj(object A, bool is_ket=1):
    """
    Computes the projection operator
    from a given ket or bra vector
    in CSR format. The flag 'is_ket'
    is True if passed a ket.

    For a ket the result is ``A * A^dagger``; for a bra it is
    ``A^dagger * A``. Either way it is dense on the support of the vector,
    i.e. it has nnz**2 stored entries. That count is computed with an
    overflow check and raises OverflowError if it does not fit in a C int.

    This is ~3x faster than doing the
    conjugate transpose and sparse multiplication
    directly. Also, does not need a temp matrix.
    """
    cdef complex[::1] data = A.data
    cdef int[::1] ind = A.indices
    cdef int[::1] ptr = A.indptr
    cdef int nrows
    cdef int nnz
    cdef int nnz_sq

    cdef int offset = 0, new_idx, count, change_idx
    cdef size_t jj, kk

    if is_ket:
        nrows = A.shape[0]
        nnz = ptr[nrows]
    else:
        nrows = A.shape[1]
        nnz = ptr[1]

    # Checked multiply: an unchecked nnz**2 could silently wrap and
    # under-allocate the output.
    nnz_sq = _safe_multiply(nnz, nnz)

    cdef CSR_Matrix out
    init_CSR(&out, nnz_sq, nrows)

    if is_ket:
        # Compute new ptrs and inds
        for jj in range(nrows):
            # A ket row holds 0 or 1 entries; each non-empty row becomes a
            # full row of nnz entries in the projector.
            out.indptr[jj] = ptr[jj]*nnz
            if ptr[jj+1] != ptr[jj]:
                new_idx = jj
                # The offset-th non-empty row is the offset-th column of
                # every output row.
                for kk in range(nnz):
                    out.indices[offset+kk*nnz] = new_idx
                offset += 1
        # set nnz in new ptr
        out.indptr[nrows] = nnz_sq

        # Compute the data
        for jj in range(nnz):
            for kk in range(nnz):
                out.data[jj*nnz+kk] = data[jj]*conj(data[kk])

    else:
        # Walk the bra's entries backwards, filling each output row block
        # and the indptr entries above it. Rows up to and including ind[0]
        # keep indptr == 0, presumably from init_CSR zero-initialisation
        # -- confirm.
        count = nnz_sq
        new_idx = nrows
        for kk in range(nnz-1,-1,-1):
            for jj in range(nnz-1,-1,-1):
                out.indices[offset+jj] = ind[jj]
                out.data[kk*nnz+jj] = conj(data[kk])*data[jj]
            offset += nnz
            change_idx = ind[kk]
            while new_idx > change_idx:
                out.indptr[new_idx] = count
                new_idx -= 1
            count -= nnz


    return CSR_to_scipy(&out)


@cython.boundscheck(False)
@cython.wraparound(False)
def zcsr_inner(object A, object B, bool bra_ket):
    """
    Computes the inner-product <A|B> between ket-ket,
    or bra-ket vectors in sparse CSR format.

    Parameters
    ----------
    A : csr_matrix
        Bra (if `bra_ket`) or ket. In the ket case A is conjugated.
    B : csr_matrix
        Ket; its row count sets the vector length.
    bra_ket : bool
        True if A is a bra, False if A is a ket.

    Returns
    -------
    complex
    """
    cdef complex[::1] a_data = A.data
    cdef int[::1] a_ind = A.indices
    cdef int[::1] a_ptr = A.indptr

    cdef complex[::1] b_data = B.data
    cdef int[::1] b_ind = B.indices
    cdef int[::1] b_ptr = B.indptr
    cdef int nrows = B.shape[0]

    cdef double complex inner = 0
    cdef size_t jj, kk
    cdef int a_idx, b_idx

    if bra_ket:
        for kk in range(a_ind.shape[0]):
            a_idx = a_ind[kk]
            # Bra column a_idx pairs with ket row a_idx. Look that row up
            # directly; the previous version scanned every ket row for
            # every bra entry, which was O(nnz * nrows).
            if 0 <= a_idx < nrows and (b_ptr[a_idx+1] - b_ptr[a_idx]) != 0:
                inner += a_data[kk]*b_data[b_ptr[a_idx]]
    else:
        for kk in range(nrows):
            a_idx = a_ptr[kk]
            b_idx = b_ptr[kk]
            # Contribute only where both kets store an entry in this row.
            if (a_ptr[kk+1]-a_idx) != 0:
                if (b_ptr[kk+1]-b_idx) != 0:
                    inner += conj(a_data[a_idx])*b_data[b_idx]

    return inner