// Copyright ©2019 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package gonum

import (
	"math/cmplx"

	"gonum.org/v1/gonum/blas"
	"gonum.org/v1/gonum/internal/asm/c128"
)

// Compile-time assertion that Implementation provides the complex128
// Level 3 BLAS routines.
var _ blas.Complex128Level3 = Implementation{}

// Zgemm performs one of the matrix-matrix operations
//
//	C = alpha * op(A) * op(B) + beta * C
//
// where op(X) is one of
//
//	op(X) = X  or  op(X) = Xᵀ  or  op(X) = Xᴴ,
//
// alpha and beta are scalars, and A, B and C are matrices, with op(A) an m×k matrix,
// op(B) a k×n matrix and C an m×n matrix.
//
// All matrices are stored in row-major order with leading dimensions lda, ldb
// and ldc. When beta is zero, the elements of C are overwritten without their
// input values being read.
//
// Zgemm panics if tA or tB is not blas.NoTrans, blas.Trans or blas.ConjTrans,
// if m, n or k is negative, if a leading dimension is smaller than the row
// length of its stored matrix, or if a slice is too short to hold its matrix.
func (Implementation) Zgemm(tA, tB blas.Transpose, m, n, k int, alpha complex128, a []complex128, lda int, b []complex128, ldb int, beta complex128, c []complex128, ldc int) {
	switch tA {
	default:
		panic(badTranspose)
	case blas.NoTrans, blas.Trans, blas.ConjTrans:
	}
	switch tB {
	default:
		panic(badTranspose)
	case blas.NoTrans, blas.Trans, blas.ConjTrans:
	}
	switch {
	case m < 0:
		panic(mLT0)
	case n < 0:
		panic(nLT0)
	case k < 0:
		panic(kLT0)
	}
	// rowA×colA and rowB×colB are the shapes of A and B as stored. They are
	// the transposed shapes of op(A) and op(B) when a (conjugate) transpose
	// is requested.
	rowA, colA := m, k
	if tA != blas.NoTrans {
		rowA, colA = k, m
	}
	if lda < max(1, colA) {
		panic(badLdA)
	}
	rowB, colB := k, n
	if tB != blas.NoTrans {
		rowB, colB = n, k
	}
	if ldb < max(1, colB) {
		panic(badLdB)
	}
	if ldc < max(1, n) {
		panic(badLdC)
	}

	// Quick return if possible.
	if m == 0 || n == 0 {
		return
	}

	// For zero matrix size the following slice length checks are trivially satisfied.
	if len(a) < (rowA-1)*lda+colA {
		panic(shortA)
	}
	if len(b) < (rowB-1)*ldb+colB {
		panic(shortB)
	}
	if len(c) < (m-1)*ldc+n {
		panic(shortC)
	}

	// Quick return if possible: the product term vanishes and C is kept as is.
	if (alpha == 0 || k == 0) && beta == 1 {
		return
	}

	// With alpha == 0 only the scaling of C by beta remains.
	if alpha == 0 {
		if beta == 0 {
			for i := 0; i < m; i++ {
				for j := 0; j < n; j++ {
					c[i*ldc+j] = 0
				}
			}
		} else {
			for i := 0; i < m; i++ {
				for j := 0; j < n; j++ {
					c[i*ldc+j] *= beta
				}
			}
		}
		return
	}

	switch tA {
	case blas.NoTrans:
		// op(A) = A. Each row i of C is first scaled by beta (or cleared),
		// then alpha*A[i,l] times row l of op(B) is added for every l.
		switch tB {
		case blas.NoTrans:
			// Form C = alpha * A * B + beta * C.
			for i := 0; i < m; i++ {
				switch {
				case beta == 0:
					for j := 0; j < n; j++ {
						c[i*ldc+j] = 0
					}
				case beta != 1:
					for j := 0; j < n; j++ {
						c[i*ldc+j] *= beta
					}
				}
				for l := 0; l < k; l++ {
					tmp := alpha * a[i*lda+l]
					for j := 0; j < n; j++ {
						c[i*ldc+j] += tmp * b[l*ldb+j]
					}
				}
			}
		case blas.Trans:
			// Form C = alpha * A * Bᵀ + beta * C.
			// B is stored n×k, so op(B)[l,j] = B[j,l].
			for i := 0; i < m; i++ {
				switch {
				case beta == 0:
					for j := 0; j < n; j++ {
						c[i*ldc+j] = 0
					}
				case beta != 1:
					for j := 0; j < n; j++ {
						c[i*ldc+j] *= beta
					}
				}
				for l := 0; l < k; l++ {
					tmp := alpha * a[i*lda+l]
					for j := 0; j < n; j++ {
						c[i*ldc+j] += tmp * b[j*ldb+l]
					}
				}
			}
		case blas.ConjTrans:
			// Form C = alpha * A * Bᴴ + beta * C.
			for i := 0; i < m; i++ {
				switch {
				case beta == 0:
					for j := 0; j < n; j++ {
						c[i*ldc+j] = 0
					}
				case beta != 1:
					for j := 0; j < n; j++ {
						c[i*ldc+j] *= beta
					}
				}
				for l := 0; l < k; l++ {
					tmp := alpha * a[i*lda+l]
					for j := 0; j < n; j++ {
						c[i*ldc+j] += tmp * cmplx.Conj(b[j*ldb+l])
					}
				}
			}
		}
	case blas.Trans:
		// op(A) = Aᵀ with A stored k×m, so op(A)[i,l] = A[l,i]. Each element
		// of C is an explicit dot product over l, combined with beta*C
		// only when beta is non-zero.
		switch tB {
		case blas.NoTrans:
			// Form C = alpha * Aᵀ * B + beta * C.
			for i := 0; i < m; i++ {
				for j := 0; j < n; j++ {
					var tmp complex128
					for l := 0; l < k; l++ {
						tmp += a[l*lda+i] * b[l*ldb+j]
					}
					if beta == 0 {
						c[i*ldc+j] = alpha * tmp
					} else {
						c[i*ldc+j] = alpha*tmp + beta*c[i*ldc+j]
					}
				}
			}
		case blas.Trans:
			// Form C = alpha * Aᵀ * Bᵀ + beta * C.
			for i := 0; i < m; i++ {
				for j := 0; j < n; j++ {
					var tmp complex128
					for l := 0; l < k; l++ {
						tmp += a[l*lda+i] * b[j*ldb+l]
					}
					if beta == 0 {
						c[i*ldc+j] = alpha * tmp
					} else {
						c[i*ldc+j] = alpha*tmp + beta*c[i*ldc+j]
					}
				}
			}
		case blas.ConjTrans:
			// Form C = alpha * Aᵀ * Bᴴ + beta * C.
			for i := 0; i < m; i++ {
				for j := 0; j < n; j++ {
					var tmp complex128
					for l := 0; l < k; l++ {
						tmp += a[l*lda+i] * cmplx.Conj(b[j*ldb+l])
					}
					if beta == 0 {
						c[i*ldc+j] = alpha * tmp
					} else {
						c[i*ldc+j] = alpha*tmp + beta*c[i*ldc+j]
					}
				}
			}
		}
	case blas.ConjTrans:
		// op(A) = Aᴴ: as for Aᵀ, but every element of A is conjugated.
		switch tB {
		case blas.NoTrans:
			// Form C = alpha * Aᴴ * B + beta * C.
			for i := 0; i < m; i++ {
				for j := 0; j < n; j++ {
					var tmp complex128
					for l := 0; l < k; l++ {
						tmp += cmplx.Conj(a[l*lda+i]) * b[l*ldb+j]
					}
					if beta == 0 {
						c[i*ldc+j] = alpha * tmp
					} else {
						c[i*ldc+j] = alpha*tmp + beta*c[i*ldc+j]
					}
				}
			}
		case blas.Trans:
			// Form C = alpha * Aᴴ * Bᵀ + beta * C.
			for i := 0; i < m; i++ {
				for j := 0; j < n; j++ {
					var tmp complex128
					for l := 0; l < k; l++ {
						tmp += cmplx.Conj(a[l*lda+i]) * b[j*ldb+l]
					}
					if beta == 0 {
						c[i*ldc+j] = alpha * tmp
					} else {
						c[i*ldc+j] = alpha*tmp + beta*c[i*ldc+j]
					}
				}
			}
		case blas.ConjTrans:
			// Form C = alpha * Aᴴ * Bᴴ + beta * C.
			for i := 0; i < m; i++ {
				for j := 0; j < n; j++ {
					var tmp complex128
					for l := 0; l < k; l++ {
						tmp += cmplx.Conj(a[l*lda+i]) * cmplx.Conj(b[j*ldb+l])
					}
					if beta == 0 {
						c[i*ldc+j] = alpha * tmp
					} else {
						c[i*ldc+j] = alpha*tmp + beta*c[i*ldc+j]
					}
				}
			}
		}
	}
}

// Zhemm performs one of the matrix-matrix operations
//
//	C = alpha*A*B + beta*C  if side == blas.Left
//	C = alpha*B*A + beta*C  if side == blas.Right
//
// where alpha and beta are scalars, A is an m×m or n×n hermitian matrix and B
// and C are m×n matrices. The imaginary parts of the diagonal elements of A are
// assumed to be zero.
//
// Only the triangle of A selected by uplo is referenced; elements of the other
// triangle are obtained by conjugating their mirror image. When beta is zero,
// the input values of C are not read.
func (Implementation) Zhemm(side blas.Side, uplo blas.Uplo, m, n int, alpha complex128, a []complex128, lda int, b []complex128, ldb int, beta complex128, c []complex128, ldc int) {
	// na is the order of the hermitian matrix A.
	na := m
	if side == blas.Right {
		na = n
	}
	switch {
	case side != blas.Left && side != blas.Right:
		panic(badSide)
	case uplo != blas.Lower && uplo != blas.Upper:
		panic(badUplo)
	case m < 0:
		panic(mLT0)
	case n < 0:
		panic(nLT0)
	case lda < max(1, na):
		panic(badLdA)
	case ldb < max(1, n):
		panic(badLdB)
	case ldc < max(1, n):
		panic(badLdC)
	}

	// Quick return if possible.
	if m == 0 || n == 0 {
		return
	}

	// For zero matrix size the following slice length checks are trivially satisfied.
	if len(a) < lda*(na-1)+na {
		panic(shortA)
	}
	if len(b) < ldb*(m-1)+n {
		panic(shortB)
	}
	if len(c) < ldc*(m-1)+n {
		panic(shortC)
	}

	// Quick return if possible.
	if alpha == 0 && beta == 1 {
		return
	}

	// With alpha == 0 only the scaling of C by beta remains.
	if alpha == 0 {
		if beta == 0 {
			for i := 0; i < m; i++ {
				ci := c[i*ldc : i*ldc+n]
				for j := range ci {
					ci[j] = 0
				}
			}
		} else {
			for i := 0; i < m; i++ {
				ci := c[i*ldc : i*ldc+n]
				c128.ScalUnitary(beta, ci)
			}
		}
		return
	}

	if side == blas.Left {
		// Form C = alpha*A*B + beta*C.
		// Row i of C starts as alpha*real(A[i,i])*B[i,:] (plus beta*C[i,:]),
		// then alpha*A[i,k]*B[k,:] is added for every k != i. A[i,k] comes
		// from the stored triangle directly or as conj(A[k,i]).
		for i := 0; i < m; i++ {
			atmp := alpha * complex(real(a[i*lda+i]), 0)
			bi := b[i*ldb : i*ldb+n]
			ci := c[i*ldc : i*ldc+n]
			if beta == 0 {
				for j, bij := range bi {
					ci[j] = atmp * bij
				}
			} else {
				for j, bij := range bi {
					ci[j] = atmp*bij + beta*ci[j]
				}
			}
			if uplo == blas.Upper {
				for k := 0; k < i; k++ {
					atmp = alpha * cmplx.Conj(a[k*lda+i])
					c128.AxpyUnitary(atmp, b[k*ldb:k*ldb+n], ci)
				}
				for k := i + 1; k < m; k++ {
					atmp = alpha * a[i*lda+k]
					c128.AxpyUnitary(atmp, b[k*ldb:k*ldb+n], ci)
				}
			} else {
				for k := 0; k < i; k++ {
					atmp = alpha * a[i*lda+k]
					c128.AxpyUnitary(atmp, b[k*ldb:k*ldb+n], ci)
				}
				for k := i + 1; k < m; k++ {
					atmp = alpha * cmplx.Conj(a[k*lda+i])
					c128.AxpyUnitary(atmp, b[k*ldb:k*ldb+n], ci)
				}
			}
		}
	} else {
		// Form C = alpha*B*A + beta*C.
		// For each column j, the off-diagonal entries of row j of A in the
		// stored triangle are used twice:
		//   - they are scattered into the elements of row i of C that were
		//     already initialized in earlier iterations;
		//   - conjugated, they are gathered into tmp for element (i, j).
		// The order of j (descending for Upper, ascending for Lower) means
		// c[i*ldc+j] still holds its input value when it is combined with beta.
		if uplo == blas.Upper {
			for i := 0; i < m; i++ {
				for j := n - 1; j >= 0; j-- {
					abij := alpha * b[i*ldb+j]
					aj := a[j*lda+j+1 : j*lda+n]
					bi := b[i*ldb+j+1 : i*ldb+n]
					ci := c[i*ldc+j+1 : i*ldc+n]
					var tmp complex128
					for k, ajk := range aj {
						ci[k] += abij * ajk
						tmp += bi[k] * cmplx.Conj(ajk)
					}
					// The diagonal of a hermitian matrix is real.
					ajj := complex(real(a[j*lda+j]), 0)
					if beta == 0 {
						c[i*ldc+j] = abij*ajj + alpha*tmp
					} else {
						c[i*ldc+j] = abij*ajj + alpha*tmp + beta*c[i*ldc+j]
					}
				}
			}
		} else {
			for i := 0; i < m; i++ {
				for j := 0; j < n; j++ {
					abij := alpha * b[i*ldb+j]
					aj := a[j*lda : j*lda+j]
					bi := b[i*ldb : i*ldb+j]
					ci := c[i*ldc : i*ldc+j]
					var tmp complex128
					for k, ajk := range aj {
						ci[k] += abij * ajk
						tmp += bi[k] * cmplx.Conj(ajk)
					}
					// The diagonal of a hermitian matrix is real.
					ajj := complex(real(a[j*lda+j]), 0)
					if beta == 0 {
						c[i*ldc+j] = abij*ajj + alpha*tmp
					} else {
						c[i*ldc+j] = abij*ajj + alpha*tmp + beta*c[i*ldc+j]
					}
				}
			}
		}
	}
}

// Zherk performs one of the hermitian rank-k operations
//
//	C = alpha*A*Aᴴ + beta*C  if trans == blas.NoTrans
//	C = alpha*Aᴴ*A + beta*C  if trans == blas.ConjTrans
//
// where alpha and beta are real scalars, C is an n×n hermitian matrix and A is
// an n×k matrix in the first case and a k×n matrix in the second case.
//
// The imaginary parts of the diagonal elements of C are assumed to be zero, and
// on return they will be set to zero. Only the triangle of C selected by uplo
// is referenced and updated; when beta is zero its input values are not read.
415func (Implementation) Zherk(uplo blas.Uplo, trans blas.Transpose, n, k int, alpha float64, a []complex128, lda int, beta float64, c []complex128, ldc int) { 416 var rowA, colA int 417 switch trans { 418 default: 419 panic(badTranspose) 420 case blas.NoTrans: 421 rowA, colA = n, k 422 case blas.ConjTrans: 423 rowA, colA = k, n 424 } 425 switch { 426 case uplo != blas.Lower && uplo != blas.Upper: 427 panic(badUplo) 428 case n < 0: 429 panic(nLT0) 430 case k < 0: 431 panic(kLT0) 432 case lda < max(1, colA): 433 panic(badLdA) 434 case ldc < max(1, n): 435 panic(badLdC) 436 } 437 438 // Quick return if possible. 439 if n == 0 { 440 return 441 } 442 443 // For zero matrix size the following slice length checks are trivially satisfied. 444 if len(a) < (rowA-1)*lda+colA { 445 panic(shortA) 446 } 447 if len(c) < (n-1)*ldc+n { 448 panic(shortC) 449 } 450 451 // Quick return if possible. 452 if (alpha == 0 || k == 0) && beta == 1 { 453 return 454 } 455 456 if alpha == 0 { 457 if uplo == blas.Upper { 458 if beta == 0 { 459 for i := 0; i < n; i++ { 460 ci := c[i*ldc+i : i*ldc+n] 461 for j := range ci { 462 ci[j] = 0 463 } 464 } 465 } else { 466 for i := 0; i < n; i++ { 467 ci := c[i*ldc+i : i*ldc+n] 468 ci[0] = complex(beta*real(ci[0]), 0) 469 if i != n-1 { 470 c128.DscalUnitary(beta, ci[1:]) 471 } 472 } 473 } 474 } else { 475 if beta == 0 { 476 for i := 0; i < n; i++ { 477 ci := c[i*ldc : i*ldc+i+1] 478 for j := range ci { 479 ci[j] = 0 480 } 481 } 482 } else { 483 for i := 0; i < n; i++ { 484 ci := c[i*ldc : i*ldc+i+1] 485 if i != 0 { 486 c128.DscalUnitary(beta, ci[:i]) 487 } 488 ci[i] = complex(beta*real(ci[i]), 0) 489 } 490 } 491 } 492 return 493 } 494 495 calpha := complex(alpha, 0) 496 if trans == blas.NoTrans { 497 // Form C = alpha*A*Aᴴ + beta*C. 
498 cbeta := complex(beta, 0) 499 if uplo == blas.Upper { 500 for i := 0; i < n; i++ { 501 ci := c[i*ldc+i : i*ldc+n] 502 ai := a[i*lda : i*lda+k] 503 switch { 504 case beta == 0: 505 // Handle the i-th diagonal element of C. 506 ci[0] = complex(alpha*real(c128.DotcUnitary(ai, ai)), 0) 507 // Handle the remaining elements on the i-th row of C. 508 for jc := range ci[1:] { 509 j := i + 1 + jc 510 ci[jc+1] = calpha * c128.DotcUnitary(a[j*lda:j*lda+k], ai) 511 } 512 case beta != 1: 513 cii := calpha*c128.DotcUnitary(ai, ai) + cbeta*ci[0] 514 ci[0] = complex(real(cii), 0) 515 for jc, cij := range ci[1:] { 516 j := i + 1 + jc 517 ci[jc+1] = calpha*c128.DotcUnitary(a[j*lda:j*lda+k], ai) + cbeta*cij 518 } 519 default: 520 cii := calpha*c128.DotcUnitary(ai, ai) + ci[0] 521 ci[0] = complex(real(cii), 0) 522 for jc, cij := range ci[1:] { 523 j := i + 1 + jc 524 ci[jc+1] = calpha*c128.DotcUnitary(a[j*lda:j*lda+k], ai) + cij 525 } 526 } 527 } 528 } else { 529 for i := 0; i < n; i++ { 530 ci := c[i*ldc : i*ldc+i+1] 531 ai := a[i*lda : i*lda+k] 532 switch { 533 case beta == 0: 534 // Handle the first i-1 elements on the i-th row of C. 535 for j := range ci[:i] { 536 ci[j] = calpha * c128.DotcUnitary(a[j*lda:j*lda+k], ai) 537 } 538 // Handle the i-th diagonal element of C. 539 ci[i] = complex(alpha*real(c128.DotcUnitary(ai, ai)), 0) 540 case beta != 1: 541 for j, cij := range ci[:i] { 542 ci[j] = calpha*c128.DotcUnitary(a[j*lda:j*lda+k], ai) + cbeta*cij 543 } 544 cii := calpha*c128.DotcUnitary(ai, ai) + cbeta*ci[i] 545 ci[i] = complex(real(cii), 0) 546 default: 547 for j, cij := range ci[:i] { 548 ci[j] = calpha*c128.DotcUnitary(a[j*lda:j*lda+k], ai) + cij 549 } 550 cii := calpha*c128.DotcUnitary(ai, ai) + ci[i] 551 ci[i] = complex(real(cii), 0) 552 } 553 } 554 } 555 } else { 556 // Form C = alpha*Aᴴ*A + beta*C. 
557 if uplo == blas.Upper { 558 for i := 0; i < n; i++ { 559 ci := c[i*ldc+i : i*ldc+n] 560 switch { 561 case beta == 0: 562 for jc := range ci { 563 ci[jc] = 0 564 } 565 case beta != 1: 566 c128.DscalUnitary(beta, ci) 567 ci[0] = complex(real(ci[0]), 0) 568 default: 569 ci[0] = complex(real(ci[0]), 0) 570 } 571 for j := 0; j < k; j++ { 572 aji := cmplx.Conj(a[j*lda+i]) 573 if aji != 0 { 574 c128.AxpyUnitary(calpha*aji, a[j*lda+i:j*lda+n], ci) 575 } 576 } 577 c[i*ldc+i] = complex(real(c[i*ldc+i]), 0) 578 } 579 } else { 580 for i := 0; i < n; i++ { 581 ci := c[i*ldc : i*ldc+i+1] 582 switch { 583 case beta == 0: 584 for j := range ci { 585 ci[j] = 0 586 } 587 case beta != 1: 588 c128.DscalUnitary(beta, ci) 589 ci[i] = complex(real(ci[i]), 0) 590 default: 591 ci[i] = complex(real(ci[i]), 0) 592 } 593 for j := 0; j < k; j++ { 594 aji := cmplx.Conj(a[j*lda+i]) 595 if aji != 0 { 596 c128.AxpyUnitary(calpha*aji, a[j*lda:j*lda+i+1], ci) 597 } 598 } 599 c[i*ldc+i] = complex(real(c[i*ldc+i]), 0) 600 } 601 } 602 } 603} 604 605// Zher2k performs one of the hermitian rank-2k operations 606// C = alpha*A*Bᴴ + conj(alpha)*B*Aᴴ + beta*C if trans == blas.NoTrans 607// C = alpha*Aᴴ*B + conj(alpha)*Bᴴ*A + beta*C if trans == blas.ConjTrans 608// where alpha and beta are scalars with beta real, C is an n×n hermitian matrix 609// and A and B are n×k matrices in the first case and k×n matrices in the second case. 610// 611// The imaginary parts of the diagonal elements of C are assumed to be zero, and 612// on return they will be set to zero. 
613func (Implementation) Zher2k(uplo blas.Uplo, trans blas.Transpose, n, k int, alpha complex128, a []complex128, lda int, b []complex128, ldb int, beta float64, c []complex128, ldc int) { 614 var row, col int 615 switch trans { 616 default: 617 panic(badTranspose) 618 case blas.NoTrans: 619 row, col = n, k 620 case blas.ConjTrans: 621 row, col = k, n 622 } 623 switch { 624 case uplo != blas.Lower && uplo != blas.Upper: 625 panic(badUplo) 626 case n < 0: 627 panic(nLT0) 628 case k < 0: 629 panic(kLT0) 630 case lda < max(1, col): 631 panic(badLdA) 632 case ldb < max(1, col): 633 panic(badLdB) 634 case ldc < max(1, n): 635 panic(badLdC) 636 } 637 638 // Quick return if possible. 639 if n == 0 { 640 return 641 } 642 643 // For zero matrix size the following slice length checks are trivially satisfied. 644 if len(a) < (row-1)*lda+col { 645 panic(shortA) 646 } 647 if len(b) < (row-1)*ldb+col { 648 panic(shortB) 649 } 650 if len(c) < (n-1)*ldc+n { 651 panic(shortC) 652 } 653 654 // Quick return if possible. 655 if (alpha == 0 || k == 0) && beta == 1 { 656 return 657 } 658 659 if alpha == 0 { 660 if uplo == blas.Upper { 661 if beta == 0 { 662 for i := 0; i < n; i++ { 663 ci := c[i*ldc+i : i*ldc+n] 664 for j := range ci { 665 ci[j] = 0 666 } 667 } 668 } else { 669 for i := 0; i < n; i++ { 670 ci := c[i*ldc+i : i*ldc+n] 671 ci[0] = complex(beta*real(ci[0]), 0) 672 if i != n-1 { 673 c128.DscalUnitary(beta, ci[1:]) 674 } 675 } 676 } 677 } else { 678 if beta == 0 { 679 for i := 0; i < n; i++ { 680 ci := c[i*ldc : i*ldc+i+1] 681 for j := range ci { 682 ci[j] = 0 683 } 684 } 685 } else { 686 for i := 0; i < n; i++ { 687 ci := c[i*ldc : i*ldc+i+1] 688 if i != 0 { 689 c128.DscalUnitary(beta, ci[:i]) 690 } 691 ci[i] = complex(beta*real(ci[i]), 0) 692 } 693 } 694 } 695 return 696 } 697 698 conjalpha := cmplx.Conj(alpha) 699 cbeta := complex(beta, 0) 700 if trans == blas.NoTrans { 701 // Form C = alpha*A*Bᴴ + conj(alpha)*B*Aᴴ + beta*C. 
702 if uplo == blas.Upper { 703 for i := 0; i < n; i++ { 704 ci := c[i*ldc+i+1 : i*ldc+n] 705 ai := a[i*lda : i*lda+k] 706 bi := b[i*ldb : i*ldb+k] 707 if beta == 0 { 708 cii := alpha*c128.DotcUnitary(bi, ai) + conjalpha*c128.DotcUnitary(ai, bi) 709 c[i*ldc+i] = complex(real(cii), 0) 710 for jc := range ci { 711 j := i + 1 + jc 712 ci[jc] = alpha*c128.DotcUnitary(b[j*ldb:j*ldb+k], ai) + conjalpha*c128.DotcUnitary(a[j*lda:j*lda+k], bi) 713 } 714 } else { 715 cii := alpha*c128.DotcUnitary(bi, ai) + conjalpha*c128.DotcUnitary(ai, bi) + cbeta*c[i*ldc+i] 716 c[i*ldc+i] = complex(real(cii), 0) 717 for jc, cij := range ci { 718 j := i + 1 + jc 719 ci[jc] = alpha*c128.DotcUnitary(b[j*ldb:j*ldb+k], ai) + conjalpha*c128.DotcUnitary(a[j*lda:j*lda+k], bi) + cbeta*cij 720 } 721 } 722 } 723 } else { 724 for i := 0; i < n; i++ { 725 ci := c[i*ldc : i*ldc+i] 726 ai := a[i*lda : i*lda+k] 727 bi := b[i*ldb : i*ldb+k] 728 if beta == 0 { 729 for j := range ci { 730 ci[j] = alpha*c128.DotcUnitary(b[j*ldb:j*ldb+k], ai) + conjalpha*c128.DotcUnitary(a[j*lda:j*lda+k], bi) 731 } 732 cii := alpha*c128.DotcUnitary(bi, ai) + conjalpha*c128.DotcUnitary(ai, bi) 733 c[i*ldc+i] = complex(real(cii), 0) 734 } else { 735 for j, cij := range ci { 736 ci[j] = alpha*c128.DotcUnitary(b[j*ldb:j*ldb+k], ai) + conjalpha*c128.DotcUnitary(a[j*lda:j*lda+k], bi) + cbeta*cij 737 } 738 cii := alpha*c128.DotcUnitary(bi, ai) + conjalpha*c128.DotcUnitary(ai, bi) + cbeta*c[i*ldc+i] 739 c[i*ldc+i] = complex(real(cii), 0) 740 } 741 } 742 } 743 } else { 744 // Form C = alpha*Aᴴ*B + conj(alpha)*Bᴴ*A + beta*C. 
745 if uplo == blas.Upper { 746 for i := 0; i < n; i++ { 747 ci := c[i*ldc+i : i*ldc+n] 748 switch { 749 case beta == 0: 750 for jc := range ci { 751 ci[jc] = 0 752 } 753 case beta != 1: 754 c128.DscalUnitary(beta, ci) 755 ci[0] = complex(real(ci[0]), 0) 756 default: 757 ci[0] = complex(real(ci[0]), 0) 758 } 759 for j := 0; j < k; j++ { 760 aji := a[j*lda+i] 761 bji := b[j*ldb+i] 762 if aji != 0 { 763 c128.AxpyUnitary(alpha*cmplx.Conj(aji), b[j*ldb+i:j*ldb+n], ci) 764 } 765 if bji != 0 { 766 c128.AxpyUnitary(conjalpha*cmplx.Conj(bji), a[j*lda+i:j*lda+n], ci) 767 } 768 } 769 ci[0] = complex(real(ci[0]), 0) 770 } 771 } else { 772 for i := 0; i < n; i++ { 773 ci := c[i*ldc : i*ldc+i+1] 774 switch { 775 case beta == 0: 776 for j := range ci { 777 ci[j] = 0 778 } 779 case beta != 1: 780 c128.DscalUnitary(beta, ci) 781 ci[i] = complex(real(ci[i]), 0) 782 default: 783 ci[i] = complex(real(ci[i]), 0) 784 } 785 for j := 0; j < k; j++ { 786 aji := a[j*lda+i] 787 bji := b[j*ldb+i] 788 if aji != 0 { 789 c128.AxpyUnitary(alpha*cmplx.Conj(aji), b[j*ldb:j*ldb+i+1], ci) 790 } 791 if bji != 0 { 792 c128.AxpyUnitary(conjalpha*cmplx.Conj(bji), a[j*lda:j*lda+i+1], ci) 793 } 794 } 795 ci[i] = complex(real(ci[i]), 0) 796 } 797 } 798 } 799} 800 801// Zsymm performs one of the matrix-matrix operations 802// C = alpha*A*B + beta*C if side == blas.Left 803// C = alpha*B*A + beta*C if side == blas.Right 804// where alpha and beta are scalars, A is an m×m or n×n symmetric matrix and B 805// and C are m×n matrices. 
806func (Implementation) Zsymm(side blas.Side, uplo blas.Uplo, m, n int, alpha complex128, a []complex128, lda int, b []complex128, ldb int, beta complex128, c []complex128, ldc int) { 807 na := m 808 if side == blas.Right { 809 na = n 810 } 811 switch { 812 case side != blas.Left && side != blas.Right: 813 panic(badSide) 814 case uplo != blas.Lower && uplo != blas.Upper: 815 panic(badUplo) 816 case m < 0: 817 panic(mLT0) 818 case n < 0: 819 panic(nLT0) 820 case lda < max(1, na): 821 panic(badLdA) 822 case ldb < max(1, n): 823 panic(badLdB) 824 case ldc < max(1, n): 825 panic(badLdC) 826 } 827 828 // Quick return if possible. 829 if m == 0 || n == 0 { 830 return 831 } 832 833 // For zero matrix size the following slice length checks are trivially satisfied. 834 if len(a) < lda*(na-1)+na { 835 panic(shortA) 836 } 837 if len(b) < ldb*(m-1)+n { 838 panic(shortB) 839 } 840 if len(c) < ldc*(m-1)+n { 841 panic(shortC) 842 } 843 844 // Quick return if possible. 845 if alpha == 0 && beta == 1 { 846 return 847 } 848 849 if alpha == 0 { 850 if beta == 0 { 851 for i := 0; i < m; i++ { 852 ci := c[i*ldc : i*ldc+n] 853 for j := range ci { 854 ci[j] = 0 855 } 856 } 857 } else { 858 for i := 0; i < m; i++ { 859 ci := c[i*ldc : i*ldc+n] 860 c128.ScalUnitary(beta, ci) 861 } 862 } 863 return 864 } 865 866 if side == blas.Left { 867 // Form C = alpha*A*B + beta*C. 
868 for i := 0; i < m; i++ { 869 atmp := alpha * a[i*lda+i] 870 bi := b[i*ldb : i*ldb+n] 871 ci := c[i*ldc : i*ldc+n] 872 if beta == 0 { 873 for j, bij := range bi { 874 ci[j] = atmp * bij 875 } 876 } else { 877 for j, bij := range bi { 878 ci[j] = atmp*bij + beta*ci[j] 879 } 880 } 881 if uplo == blas.Upper { 882 for k := 0; k < i; k++ { 883 atmp = alpha * a[k*lda+i] 884 c128.AxpyUnitary(atmp, b[k*ldb:k*ldb+n], ci) 885 } 886 for k := i + 1; k < m; k++ { 887 atmp = alpha * a[i*lda+k] 888 c128.AxpyUnitary(atmp, b[k*ldb:k*ldb+n], ci) 889 } 890 } else { 891 for k := 0; k < i; k++ { 892 atmp = alpha * a[i*lda+k] 893 c128.AxpyUnitary(atmp, b[k*ldb:k*ldb+n], ci) 894 } 895 for k := i + 1; k < m; k++ { 896 atmp = alpha * a[k*lda+i] 897 c128.AxpyUnitary(atmp, b[k*ldb:k*ldb+n], ci) 898 } 899 } 900 } 901 } else { 902 // Form C = alpha*B*A + beta*C. 903 if uplo == blas.Upper { 904 for i := 0; i < m; i++ { 905 for j := n - 1; j >= 0; j-- { 906 abij := alpha * b[i*ldb+j] 907 aj := a[j*lda+j+1 : j*lda+n] 908 bi := b[i*ldb+j+1 : i*ldb+n] 909 ci := c[i*ldc+j+1 : i*ldc+n] 910 var tmp complex128 911 for k, ajk := range aj { 912 ci[k] += abij * ajk 913 tmp += bi[k] * ajk 914 } 915 if beta == 0 { 916 c[i*ldc+j] = abij*a[j*lda+j] + alpha*tmp 917 } else { 918 c[i*ldc+j] = abij*a[j*lda+j] + alpha*tmp + beta*c[i*ldc+j] 919 } 920 } 921 } 922 } else { 923 for i := 0; i < m; i++ { 924 for j := 0; j < n; j++ { 925 abij := alpha * b[i*ldb+j] 926 aj := a[j*lda : j*lda+j] 927 bi := b[i*ldb : i*ldb+j] 928 ci := c[i*ldc : i*ldc+j] 929 var tmp complex128 930 for k, ajk := range aj { 931 ci[k] += abij * ajk 932 tmp += bi[k] * ajk 933 } 934 if beta == 0 { 935 c[i*ldc+j] = abij*a[j*lda+j] + alpha*tmp 936 } else { 937 c[i*ldc+j] = abij*a[j*lda+j] + alpha*tmp + beta*c[i*ldc+j] 938 } 939 } 940 } 941 } 942 } 943} 944 945// Zsyrk performs one of the symmetric rank-k operations 946// C = alpha*A*Aᵀ + beta*C if trans == blas.NoTrans 947// C = alpha*Aᵀ*A + beta*C if trans == blas.Trans 948// where alpha and 
beta are scalars, C is an n×n symmetric matrix and A is 949// an n×k matrix in the first case and a k×n matrix in the second case. 950func (Implementation) Zsyrk(uplo blas.Uplo, trans blas.Transpose, n, k int, alpha complex128, a []complex128, lda int, beta complex128, c []complex128, ldc int) { 951 var rowA, colA int 952 switch trans { 953 default: 954 panic(badTranspose) 955 case blas.NoTrans: 956 rowA, colA = n, k 957 case blas.Trans: 958 rowA, colA = k, n 959 } 960 switch { 961 case uplo != blas.Lower && uplo != blas.Upper: 962 panic(badUplo) 963 case n < 0: 964 panic(nLT0) 965 case k < 0: 966 panic(kLT0) 967 case lda < max(1, colA): 968 panic(badLdA) 969 case ldc < max(1, n): 970 panic(badLdC) 971 } 972 973 // Quick return if possible. 974 if n == 0 { 975 return 976 } 977 978 // For zero matrix size the following slice length checks are trivially satisfied. 979 if len(a) < (rowA-1)*lda+colA { 980 panic(shortA) 981 } 982 if len(c) < (n-1)*ldc+n { 983 panic(shortC) 984 } 985 986 // Quick return if possible. 987 if (alpha == 0 || k == 0) && beta == 1 { 988 return 989 } 990 991 if alpha == 0 { 992 if uplo == blas.Upper { 993 if beta == 0 { 994 for i := 0; i < n; i++ { 995 ci := c[i*ldc+i : i*ldc+n] 996 for j := range ci { 997 ci[j] = 0 998 } 999 } 1000 } else { 1001 for i := 0; i < n; i++ { 1002 ci := c[i*ldc+i : i*ldc+n] 1003 c128.ScalUnitary(beta, ci) 1004 } 1005 } 1006 } else { 1007 if beta == 0 { 1008 for i := 0; i < n; i++ { 1009 ci := c[i*ldc : i*ldc+i+1] 1010 for j := range ci { 1011 ci[j] = 0 1012 } 1013 } 1014 } else { 1015 for i := 0; i < n; i++ { 1016 ci := c[i*ldc : i*ldc+i+1] 1017 c128.ScalUnitary(beta, ci) 1018 } 1019 } 1020 } 1021 return 1022 } 1023 1024 if trans == blas.NoTrans { 1025 // Form C = alpha*A*Aᵀ + beta*C. 
1026 if uplo == blas.Upper { 1027 for i := 0; i < n; i++ { 1028 ci := c[i*ldc+i : i*ldc+n] 1029 ai := a[i*lda : i*lda+k] 1030 for jc, cij := range ci { 1031 j := i + jc 1032 ci[jc] = beta*cij + alpha*c128.DotuUnitary(ai, a[j*lda:j*lda+k]) 1033 } 1034 } 1035 } else { 1036 for i := 0; i < n; i++ { 1037 ci := c[i*ldc : i*ldc+i+1] 1038 ai := a[i*lda : i*lda+k] 1039 for j, cij := range ci { 1040 ci[j] = beta*cij + alpha*c128.DotuUnitary(ai, a[j*lda:j*lda+k]) 1041 } 1042 } 1043 } 1044 } else { 1045 // Form C = alpha*Aᵀ*A + beta*C. 1046 if uplo == blas.Upper { 1047 for i := 0; i < n; i++ { 1048 ci := c[i*ldc+i : i*ldc+n] 1049 switch { 1050 case beta == 0: 1051 for jc := range ci { 1052 ci[jc] = 0 1053 } 1054 case beta != 1: 1055 for jc := range ci { 1056 ci[jc] *= beta 1057 } 1058 } 1059 for j := 0; j < k; j++ { 1060 aji := a[j*lda+i] 1061 if aji != 0 { 1062 c128.AxpyUnitary(alpha*aji, a[j*lda+i:j*lda+n], ci) 1063 } 1064 } 1065 } 1066 } else { 1067 for i := 0; i < n; i++ { 1068 ci := c[i*ldc : i*ldc+i+1] 1069 switch { 1070 case beta == 0: 1071 for j := range ci { 1072 ci[j] = 0 1073 } 1074 case beta != 1: 1075 for j := range ci { 1076 ci[j] *= beta 1077 } 1078 } 1079 for j := 0; j < k; j++ { 1080 aji := a[j*lda+i] 1081 if aji != 0 { 1082 c128.AxpyUnitary(alpha*aji, a[j*lda:j*lda+i+1], ci) 1083 } 1084 } 1085 } 1086 } 1087 } 1088} 1089 1090// Zsyr2k performs one of the symmetric rank-2k operations 1091// C = alpha*A*Bᵀ + alpha*B*Aᵀ + beta*C if trans == blas.NoTrans 1092// C = alpha*Aᵀ*B + alpha*Bᵀ*A + beta*C if trans == blas.Trans 1093// where alpha and beta are scalars, C is an n×n symmetric matrix and A and B 1094// are n×k matrices in the first case and k×n matrices in the second case. 
1095func (Implementation) Zsyr2k(uplo blas.Uplo, trans blas.Transpose, n, k int, alpha complex128, a []complex128, lda int, b []complex128, ldb int, beta complex128, c []complex128, ldc int) { 1096 var row, col int 1097 switch trans { 1098 default: 1099 panic(badTranspose) 1100 case blas.NoTrans: 1101 row, col = n, k 1102 case blas.Trans: 1103 row, col = k, n 1104 } 1105 switch { 1106 case uplo != blas.Lower && uplo != blas.Upper: 1107 panic(badUplo) 1108 case n < 0: 1109 panic(nLT0) 1110 case k < 0: 1111 panic(kLT0) 1112 case lda < max(1, col): 1113 panic(badLdA) 1114 case ldb < max(1, col): 1115 panic(badLdB) 1116 case ldc < max(1, n): 1117 panic(badLdC) 1118 } 1119 1120 // Quick return if possible. 1121 if n == 0 { 1122 return 1123 } 1124 1125 // For zero matrix size the following slice length checks are trivially satisfied. 1126 if len(a) < (row-1)*lda+col { 1127 panic(shortA) 1128 } 1129 if len(b) < (row-1)*ldb+col { 1130 panic(shortB) 1131 } 1132 if len(c) < (n-1)*ldc+n { 1133 panic(shortC) 1134 } 1135 1136 // Quick return if possible. 1137 if (alpha == 0 || k == 0) && beta == 1 { 1138 return 1139 } 1140 1141 if alpha == 0 { 1142 if uplo == blas.Upper { 1143 if beta == 0 { 1144 for i := 0; i < n; i++ { 1145 ci := c[i*ldc+i : i*ldc+n] 1146 for j := range ci { 1147 ci[j] = 0 1148 } 1149 } 1150 } else { 1151 for i := 0; i < n; i++ { 1152 ci := c[i*ldc+i : i*ldc+n] 1153 c128.ScalUnitary(beta, ci) 1154 } 1155 } 1156 } else { 1157 if beta == 0 { 1158 for i := 0; i < n; i++ { 1159 ci := c[i*ldc : i*ldc+i+1] 1160 for j := range ci { 1161 ci[j] = 0 1162 } 1163 } 1164 } else { 1165 for i := 0; i < n; i++ { 1166 ci := c[i*ldc : i*ldc+i+1] 1167 c128.ScalUnitary(beta, ci) 1168 } 1169 } 1170 } 1171 return 1172 } 1173 1174 if trans == blas.NoTrans { 1175 // Form C = alpha*A*Bᵀ + alpha*B*Aᵀ + beta*C. 
1176 if uplo == blas.Upper { 1177 for i := 0; i < n; i++ { 1178 ci := c[i*ldc+i : i*ldc+n] 1179 ai := a[i*lda : i*lda+k] 1180 bi := b[i*ldb : i*ldb+k] 1181 if beta == 0 { 1182 for jc := range ci { 1183 j := i + jc 1184 ci[jc] = alpha*c128.DotuUnitary(ai, b[j*ldb:j*ldb+k]) + alpha*c128.DotuUnitary(bi, a[j*lda:j*lda+k]) 1185 } 1186 } else { 1187 for jc, cij := range ci { 1188 j := i + jc 1189 ci[jc] = alpha*c128.DotuUnitary(ai, b[j*ldb:j*ldb+k]) + alpha*c128.DotuUnitary(bi, a[j*lda:j*lda+k]) + beta*cij 1190 } 1191 } 1192 } 1193 } else { 1194 for i := 0; i < n; i++ { 1195 ci := c[i*ldc : i*ldc+i+1] 1196 ai := a[i*lda : i*lda+k] 1197 bi := b[i*ldb : i*ldb+k] 1198 if beta == 0 { 1199 for j := range ci { 1200 ci[j] = alpha*c128.DotuUnitary(ai, b[j*ldb:j*ldb+k]) + alpha*c128.DotuUnitary(bi, a[j*lda:j*lda+k]) 1201 } 1202 } else { 1203 for j, cij := range ci { 1204 ci[j] = alpha*c128.DotuUnitary(ai, b[j*ldb:j*ldb+k]) + alpha*c128.DotuUnitary(bi, a[j*lda:j*lda+k]) + beta*cij 1205 } 1206 } 1207 } 1208 } 1209 } else { 1210 // Form C = alpha*Aᵀ*B + alpha*Bᵀ*A + beta*C. 
1211 if uplo == blas.Upper { 1212 for i := 0; i < n; i++ { 1213 ci := c[i*ldc+i : i*ldc+n] 1214 switch { 1215 case beta == 0: 1216 for jc := range ci { 1217 ci[jc] = 0 1218 } 1219 case beta != 1: 1220 for jc := range ci { 1221 ci[jc] *= beta 1222 } 1223 } 1224 for j := 0; j < k; j++ { 1225 aji := a[j*lda+i] 1226 bji := b[j*ldb+i] 1227 if aji != 0 { 1228 c128.AxpyUnitary(alpha*aji, b[j*ldb+i:j*ldb+n], ci) 1229 } 1230 if bji != 0 { 1231 c128.AxpyUnitary(alpha*bji, a[j*lda+i:j*lda+n], ci) 1232 } 1233 } 1234 } 1235 } else { 1236 for i := 0; i < n; i++ { 1237 ci := c[i*ldc : i*ldc+i+1] 1238 switch { 1239 case beta == 0: 1240 for j := range ci { 1241 ci[j] = 0 1242 } 1243 case beta != 1: 1244 for j := range ci { 1245 ci[j] *= beta 1246 } 1247 } 1248 for j := 0; j < k; j++ { 1249 aji := a[j*lda+i] 1250 bji := b[j*ldb+i] 1251 if aji != 0 { 1252 c128.AxpyUnitary(alpha*aji, b[j*ldb:j*ldb+i+1], ci) 1253 } 1254 if bji != 0 { 1255 c128.AxpyUnitary(alpha*bji, a[j*lda:j*lda+i+1], ci) 1256 } 1257 } 1258 } 1259 } 1260 } 1261} 1262 1263// Ztrmm performs one of the matrix-matrix operations 1264// B = alpha * op(A) * B if side == blas.Left, 1265// B = alpha * B * op(A) if side == blas.Right, 1266// where alpha is a scalar, B is an m×n matrix, A is a unit, or non-unit, 1267// upper or lower triangular matrix and op(A) is one of 1268// op(A) = A if trans == blas.NoTrans, 1269// op(A) = Aᵀ if trans == blas.Trans, 1270// op(A) = Aᴴ if trans == blas.ConjTrans. 
func (Implementation) Ztrmm(side blas.Side, uplo blas.Uplo, trans blas.Transpose, diag blas.Diag, m, n int, alpha complex128, a []complex128, lda int, b []complex128, ldb int) {
	// na is the order of the triangular matrix A.
	na := m
	if side == blas.Right {
		na = n
	}
	switch {
	case side != blas.Left && side != blas.Right:
		panic(badSide)
	case uplo != blas.Lower && uplo != blas.Upper:
		panic(badUplo)
	case trans != blas.NoTrans && trans != blas.Trans && trans != blas.ConjTrans:
		panic(badTranspose)
	case diag != blas.Unit && diag != blas.NonUnit:
		panic(badDiag)
	case m < 0:
		panic(mLT0)
	case n < 0:
		panic(nLT0)
	case lda < max(1, na):
		panic(badLdA)
	case ldb < max(1, n):
		panic(badLdB)
	}

	// Quick return if possible.
	if m == 0 || n == 0 {
		return
	}

	// For zero matrix size the following slice length checks are trivially satisfied.
	if len(a) < (na-1)*lda+na {
		panic(shortA)
	}
	if len(b) < (m-1)*ldb+n {
		panic(shortB)
	}

	// Quick return if possible: B is overwritten with zeros without being read.
	if alpha == 0 {
		for i := 0; i < m; i++ {
			bi := b[i*ldb : i*ldb+n]
			for j := range bi {
				bi[j] = 0
			}
		}
		return
	}

	// noConj selects the plain (unconjugated) transpose; it is only consulted
	// in the transposed branches below.
	noConj := trans != blas.ConjTrans
	// For a unit diagonal, A[i,i] is taken to be 1 and is never read.
	noUnit := diag == blas.NonUnit
	if side == blas.Left {
		if trans == blas.NoTrans {
			// Form B = alpha*A*B.
			// B is overwritten in place, so rows are processed in an order
			// that keeps every row they read still unmodified:
			//   - Upper: ascending; row i reads rows j > i.
			//   - Lower: descending; row i reads rows j < i.
			if uplo == blas.Upper {
				for i := 0; i < m; i++ {
					aii := alpha
					if noUnit {
						aii *= a[i*lda+i]
					}
					bi := b[i*ldb : i*ldb+n]
					for j := range bi {
						bi[j] *= aii
					}
					for ja, aij := range a[i*lda+i+1 : i*lda+m] {
						j := ja + i + 1
						if aij != 0 {
							c128.AxpyUnitary(alpha*aij, b[j*ldb:j*ldb+n], bi)
						}
					}
				}
			} else {
				for i := m - 1; i >= 0; i-- {
					aii := alpha
					if noUnit {
						aii *= a[i*lda+i]
					}
					bi := b[i*ldb : i*ldb+n]
					for j := range bi {
						bi[j] *= aii
					}
					for j, aij := range a[i*lda : i*lda+i] {
						if aij != 0 {
							c128.AxpyUnitary(alpha*aij, b[j*ldb:j*ldb+n], bi)
						}
					}
				}
			}
		} else {
			// Form B = alpha*Aᵀ*B or B = alpha*Aᴴ*B.
			// Row k of B is first added into each row j it contributes to
			// (weighted by alpha*A[k,j], conjugated for ConjTrans). Only
			// afterwards is it scaled by its own diagonal factor. The order
			// of k (descending for Upper, ascending for Lower) means row k
			// is still unmodified when it is read.
			if uplo == blas.Upper {
				for k := m - 1; k >= 0; k-- {
					bk := b[k*ldb : k*ldb+n]
					for ja, ajk := range a[k*lda+k+1 : k*lda+m] {
						if ajk == 0 {
							continue
						}
						j := k + 1 + ja
						if noConj {
							c128.AxpyUnitary(alpha*ajk, bk, b[j*ldb:j*ldb+n])
						} else {
							c128.AxpyUnitary(alpha*cmplx.Conj(ajk), bk, b[j*ldb:j*ldb+n])
						}
					}
					akk := alpha
					if noUnit {
						if noConj {
							akk *= a[k*lda+k]
						} else {
							akk *= cmplx.Conj(a[k*lda+k])
						}
					}
					// The scaling is skipped when the factor is exactly one.
					if akk != 1 {
						c128.ScalUnitary(akk, bk)
					}
				}
			} else {
				for k := 0; k < m; k++ {
					bk := b[k*ldb : k*ldb+n]
					for j, ajk := range a[k*lda : k*lda+k] {
						if ajk == 0 {
							continue
						}
						if noConj {
							c128.AxpyUnitary(alpha*ajk, bk, b[j*ldb:j*ldb+n])
						} else {
							c128.AxpyUnitary(alpha*cmplx.Conj(ajk), bk, b[j*ldb:j*ldb+n])
						}
					}
					akk := alpha
					if noUnit {
						if noConj {
							akk *= a[k*lda+k]
						} else {
							akk *= cmplx.Conj(a[k*lda+k])
						}
					}
					// The scaling is skipped when the factor is exactly one.
					if akk != 1 {
						c128.ScalUnitary(akk, bk)
					}
				}
			}
		}
	} else {
		if trans == blas.NoTrans {
			// Form B = alpha*B*A.
			// Each row of B is updated independently. Within row i, columns
			// are visited in an order that keeps bi[k] at its input value
			// until it is processed: descending for Upper, ascending for
			// Lower. At that point:
			//   - bi[k] becomes alpha*B[i,k]*A[k,k];
			//   - alpha*B[i,k] times the off-diagonal part of row k of A is
			//     added to the columns already visited.
			// When alpha*B[i,k] is zero, column k is left as is.
			if uplo == blas.Upper {
				for i := 0; i < m; i++ {
					bi := b[i*ldb : i*ldb+n]
					for k := n - 1; k >= 0; k-- {
						abik := alpha * bi[k]
						if abik == 0 {
							continue
						}
						bi[k] = abik
						if noUnit {
							bi[k] *= a[k*lda+k]
						}
						c128.AxpyUnitary(abik, a[k*lda+k+1:k*lda+n], bi[k+1:])
					}
				}
			} else {
				for i := 0; i < m; i++ {
					bi := b[i*ldb : i*ldb+n]
					for k := 0; k < n; k++ {
						abik := alpha * bi[k]
						if abik == 0 {
							continue
						}
						bi[k] = abik
						if noUnit {
							bi[k] *= a[k*lda+k]
						}
						c128.AxpyUnitary(abik, a[k*lda:k*lda+k], bi[:k])
					}
				}
			}
		} else {
			// Form B = alpha*B*Aᵀ or B = alpha*B*Aᴴ.
			// Element j of row i is built from three parts:
			//   - its own (diagonal-scaled) value;
			//   - the dot product of the stored off-diagonal part of row j of
			//     A with the elements of row i that are still unmodified;
			//   - the conjugated variant (c128.DotcUnitary) for ConjTrans.
			// Columns are visited ascending for Upper and descending for
			// Lower, so the elements read have not been overwritten yet.
			if uplo == blas.Upper {
				for i := 0; i < m; i++ {
					bi := b[i*ldb : i*ldb+n]
					for j, bij := range bi {
						if noConj {
							if noUnit {
								bij *= a[j*lda+j]
							}
							bij += c128.DotuUnitary(a[j*lda+j+1:j*lda+n], bi[j+1:n])
						} else {
							if noUnit {
								bij *= cmplx.Conj(a[j*lda+j])
							}
							bij += c128.DotcUnitary(a[j*lda+j+1:j*lda+n], bi[j+1:n])
						}
						bi[j] = alpha * bij
					}
				}
			} else {
				for i := 0; i < m; i++ {
					bi := b[i*ldb : i*ldb+n]
					for j := n - 1; j >= 0; j-- {
						bij := bi[j]
						if noConj {
							if noUnit {
								bij *= a[j*lda+j]
							}
							bij += c128.DotuUnitary(a[j*lda:j*lda+j], bi[:j])
						} else {
							if noUnit {
								bij *= cmplx.Conj(a[j*lda+j])
							}
							bij += c128.DotcUnitary(a[j*lda:j*lda+j], bi[:j])
						}
						bi[j] = alpha * bij
					}
				}
			}
		}
	}
}

// Ztrsm solves one of the matrix equations
//  op(A) * X = alpha * B if side == blas.Left,
//  X * op(A) = alpha * B if side == blas.Right,
// where alpha is a scalar, X and B are m×n matrices, A is a unit or
// non-unit, upper or lower triangular matrix and op(A) is one of
//  op(A) = A if transA == blas.NoTrans,
//  op(A) = Aᵀ if transA == blas.Trans,
//  op(A) = Aᴴ if transA == 
blas.ConjTrans. 1499// On return the matrix X is overwritten on B. 1500func (Implementation) Ztrsm(side blas.Side, uplo blas.Uplo, transA blas.Transpose, diag blas.Diag, m, n int, alpha complex128, a []complex128, lda int, b []complex128, ldb int) { 1501 na := m 1502 if side == blas.Right { 1503 na = n 1504 } 1505 switch { 1506 case side != blas.Left && side != blas.Right: 1507 panic(badSide) 1508 case uplo != blas.Lower && uplo != blas.Upper: 1509 panic(badUplo) 1510 case transA != blas.NoTrans && transA != blas.Trans && transA != blas.ConjTrans: 1511 panic(badTranspose) 1512 case diag != blas.Unit && diag != blas.NonUnit: 1513 panic(badDiag) 1514 case m < 0: 1515 panic(mLT0) 1516 case n < 0: 1517 panic(nLT0) 1518 case lda < max(1, na): 1519 panic(badLdA) 1520 case ldb < max(1, n): 1521 panic(badLdB) 1522 } 1523 1524 // Quick return if possible. 1525 if m == 0 || n == 0 { 1526 return 1527 } 1528 1529 // For zero matrix size the following slice length checks are trivially satisfied. 1530 if len(a) < (na-1)*lda+na { 1531 panic(shortA) 1532 } 1533 if len(b) < (m-1)*ldb+n { 1534 panic(shortB) 1535 } 1536 1537 if alpha == 0 { 1538 for i := 0; i < m; i++ { 1539 for j := 0; j < n; j++ { 1540 b[i*ldb+j] = 0 1541 } 1542 } 1543 return 1544 } 1545 1546 noConj := transA != blas.ConjTrans 1547 noUnit := diag == blas.NonUnit 1548 if side == blas.Left { 1549 if transA == blas.NoTrans { 1550 // Form B = alpha*inv(A)*B. 
1551 if uplo == blas.Upper { 1552 for i := m - 1; i >= 0; i-- { 1553 bi := b[i*ldb : i*ldb+n] 1554 if alpha != 1 { 1555 c128.ScalUnitary(alpha, bi) 1556 } 1557 for ka, aik := range a[i*lda+i+1 : i*lda+m] { 1558 k := i + 1 + ka 1559 if aik != 0 { 1560 c128.AxpyUnitary(-aik, b[k*ldb:k*ldb+n], bi) 1561 } 1562 } 1563 if noUnit { 1564 c128.ScalUnitary(1/a[i*lda+i], bi) 1565 } 1566 } 1567 } else { 1568 for i := 0; i < m; i++ { 1569 bi := b[i*ldb : i*ldb+n] 1570 if alpha != 1 { 1571 c128.ScalUnitary(alpha, bi) 1572 } 1573 for j, aij := range a[i*lda : i*lda+i] { 1574 if aij != 0 { 1575 c128.AxpyUnitary(-aij, b[j*ldb:j*ldb+n], bi) 1576 } 1577 } 1578 if noUnit { 1579 c128.ScalUnitary(1/a[i*lda+i], bi) 1580 } 1581 } 1582 } 1583 } else { 1584 // Form B = alpha*inv(Aᵀ)*B or B = alpha*inv(Aᴴ)*B. 1585 if uplo == blas.Upper { 1586 for i := 0; i < m; i++ { 1587 bi := b[i*ldb : i*ldb+n] 1588 if noUnit { 1589 if noConj { 1590 c128.ScalUnitary(1/a[i*lda+i], bi) 1591 } else { 1592 c128.ScalUnitary(1/cmplx.Conj(a[i*lda+i]), bi) 1593 } 1594 } 1595 for ja, aij := range a[i*lda+i+1 : i*lda+m] { 1596 if aij == 0 { 1597 continue 1598 } 1599 j := i + 1 + ja 1600 if noConj { 1601 c128.AxpyUnitary(-aij, bi, b[j*ldb:j*ldb+n]) 1602 } else { 1603 c128.AxpyUnitary(-cmplx.Conj(aij), bi, b[j*ldb:j*ldb+n]) 1604 } 1605 } 1606 if alpha != 1 { 1607 c128.ScalUnitary(alpha, bi) 1608 } 1609 } 1610 } else { 1611 for i := m - 1; i >= 0; i-- { 1612 bi := b[i*ldb : i*ldb+n] 1613 if noUnit { 1614 if noConj { 1615 c128.ScalUnitary(1/a[i*lda+i], bi) 1616 } else { 1617 c128.ScalUnitary(1/cmplx.Conj(a[i*lda+i]), bi) 1618 } 1619 } 1620 for j, aij := range a[i*lda : i*lda+i] { 1621 if aij == 0 { 1622 continue 1623 } 1624 if noConj { 1625 c128.AxpyUnitary(-aij, bi, b[j*ldb:j*ldb+n]) 1626 } else { 1627 c128.AxpyUnitary(-cmplx.Conj(aij), bi, b[j*ldb:j*ldb+n]) 1628 } 1629 } 1630 if alpha != 1 { 1631 c128.ScalUnitary(alpha, bi) 1632 } 1633 } 1634 } 1635 } 1636 } else { 1637 if transA == blas.NoTrans { 1638 // Form B = 
alpha*B*inv(A). 1639 if uplo == blas.Upper { 1640 for i := 0; i < m; i++ { 1641 bi := b[i*ldb : i*ldb+n] 1642 if alpha != 1 { 1643 c128.ScalUnitary(alpha, bi) 1644 } 1645 for j, bij := range bi { 1646 if bij == 0 { 1647 continue 1648 } 1649 if noUnit { 1650 bi[j] /= a[j*lda+j] 1651 } 1652 c128.AxpyUnitary(-bi[j], a[j*lda+j+1:j*lda+n], bi[j+1:n]) 1653 } 1654 } 1655 } else { 1656 for i := 0; i < m; i++ { 1657 bi := b[i*ldb : i*ldb+n] 1658 if alpha != 1 { 1659 c128.ScalUnitary(alpha, bi) 1660 } 1661 for j := n - 1; j >= 0; j-- { 1662 if bi[j] == 0 { 1663 continue 1664 } 1665 if noUnit { 1666 bi[j] /= a[j*lda+j] 1667 } 1668 c128.AxpyUnitary(-bi[j], a[j*lda:j*lda+j], bi[:j]) 1669 } 1670 } 1671 } 1672 } else { 1673 // Form B = alpha*B*inv(Aᵀ) or B = alpha*B*inv(Aᴴ). 1674 if uplo == blas.Upper { 1675 for i := 0; i < m; i++ { 1676 bi := b[i*ldb : i*ldb+n] 1677 for j := n - 1; j >= 0; j-- { 1678 bij := alpha * bi[j] 1679 if noConj { 1680 bij -= c128.DotuUnitary(a[j*lda+j+1:j*lda+n], bi[j+1:n]) 1681 if noUnit { 1682 bij /= a[j*lda+j] 1683 } 1684 } else { 1685 bij -= c128.DotcUnitary(a[j*lda+j+1:j*lda+n], bi[j+1:n]) 1686 if noUnit { 1687 bij /= cmplx.Conj(a[j*lda+j]) 1688 } 1689 } 1690 bi[j] = bij 1691 } 1692 } 1693 } else { 1694 for i := 0; i < m; i++ { 1695 bi := b[i*ldb : i*ldb+n] 1696 for j, bij := range bi { 1697 bij *= alpha 1698 if noConj { 1699 bij -= c128.DotuUnitary(a[j*lda:j*lda+j], bi[:j]) 1700 if noUnit { 1701 bij /= a[j*lda+j] 1702 } 1703 } else { 1704 bij -= c128.DotcUnitary(a[j*lda:j*lda+j], bi[:j]) 1705 if noUnit { 1706 bij /= cmplx.Conj(a[j*lda+j]) 1707 } 1708 } 1709 bi[j] = bij 1710 } 1711 } 1712 } 1713 } 1714 } 1715} 1716