1// Copyright ©2015 The Gonum Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5package cblas128 6 7import ( 8 "gonum.org/v1/gonum/blas" 9 "gonum.org/v1/gonum/blas/gonum" 10) 11 12var cblas128 blas.Complex128 = gonum.Implementation{} 13 14// Use sets the BLAS complex128 implementation to be used by subsequent BLAS calls. 15// The default implementation is 16// gonum.org/v1/gonum/blas/gonum.Implementation. 17func Use(b blas.Complex128) { 18 cblas128 = b 19} 20 21// Implementation returns the current BLAS complex128 implementation. 22// 23// Implementation allows direct calls to the current the BLAS complex128 implementation 24// giving finer control of parameters. 25func Implementation() blas.Complex128 { 26 return cblas128 27} 28 29// Vector represents a vector with an associated element increment. 30type Vector struct { 31 N int 32 Inc int 33 Data []complex128 34} 35 36// General represents a matrix using the conventional storage scheme. 37type General struct { 38 Rows, Cols int 39 Stride int 40 Data []complex128 41} 42 43// Band represents a band matrix using the band storage scheme. 44type Band struct { 45 Rows, Cols int 46 KL, KU int 47 Stride int 48 Data []complex128 49} 50 51// Triangular represents a triangular matrix using the conventional storage scheme. 52type Triangular struct { 53 N int 54 Stride int 55 Data []complex128 56 Uplo blas.Uplo 57 Diag blas.Diag 58} 59 60// TriangularBand represents a triangular matrix using the band storage scheme. 61type TriangularBand struct { 62 N, K int 63 Stride int 64 Data []complex128 65 Uplo blas.Uplo 66 Diag blas.Diag 67} 68 69// TriangularPacked represents a triangular matrix using the packed storage scheme. 70type TriangularPacked struct { 71 N int 72 Data []complex128 73 Uplo blas.Uplo 74 Diag blas.Diag 75} 76 77// Symmetric represents a symmetric matrix using the conventional storage scheme. 78type Symmetric struct { 79 N int 80 Stride int 81 Data []complex128 82 Uplo blas.Uplo 83} 84 85// SymmetricBand represents a symmetric matrix using the band storage scheme. 86type SymmetricBand struct { 87 N, K int 88 Stride int 89 Data []complex128 90 Uplo blas.Uplo 91} 92 93// SymmetricPacked represents a symmetric matrix using the packed storage scheme. 94type SymmetricPacked struct { 95 N int 96 Data []complex128 97 Uplo blas.Uplo 98} 99 100// Hermitian represents an Hermitian matrix using the conventional storage scheme. 101type Hermitian Symmetric 102 103// HermitianBand represents an Hermitian matrix using the band storage scheme. 104type HermitianBand SymmetricBand 105 106// HermitianPacked represents an Hermitian matrix using the packed storage scheme. 107type HermitianPacked SymmetricPacked 108 109// Level 1 110 111const ( 112 negInc = "cblas128: negative vector increment" 113 badLength = "cblas128: vector length mismatch" 114) 115 116// Dotu computes the dot product of the two vectors without 117// complex conjugation: 118// xᵀ * y. 119// Dotu will panic if the lengths of x and y do not match. 120func Dotu(x, y Vector) complex128 { 121 if x.N != y.N { 122 panic(badLength) 123 } 124 return cblas128.Zdotu(x.N, x.Data, x.Inc, y.Data, y.Inc) 125} 126 127// Dotc computes the dot product of the two vectors with 128// complex conjugation: 129// xᴴ * y. 130// Dotc will panic if the lengths of x and y do not match. 131func Dotc(x, y Vector) complex128 { 132 if x.N != y.N { 133 panic(badLength) 134 } 135 return cblas128.Zdotc(x.N, x.Data, x.Inc, y.Data, y.Inc) 136} 137 138// Nrm2 computes the Euclidean norm of the vector x: 139// sqrt(\sum_i x[i] * x[i]). 140// 141// Nrm2 will panic if the vector increment is negative. 142func Nrm2(x Vector) float64 { 143 if x.Inc < 0 { 144 panic(negInc) 145 } 146 return cblas128.Dznrm2(x.N, x.Data, x.Inc) 147} 148 149// Asum computes the sum of magnitudes of the real and imaginary parts of 150// elements of the vector x: 151// \sum_i (|Re x[i]| + |Im x[i]|). 152// 153// Asum will panic if the vector increment is negative. 154func Asum(x Vector) float64 { 155 if x.Inc < 0 { 156 panic(negInc) 157 } 158 return cblas128.Dzasum(x.N, x.Data, x.Inc) 159} 160 161// Iamax returns the index of an element of x with the largest sum of 162// magnitudes of the real and imaginary parts (|Re x[i]|+|Im x[i]|). 163// If there are multiple such indices, the earliest is returned. 164// 165// Iamax returns -1 if n == 0. 166// 167// Iamax will panic if the vector increment is negative. 168func Iamax(x Vector) int { 169 if x.Inc < 0 { 170 panic(negInc) 171 } 172 return cblas128.Izamax(x.N, x.Data, x.Inc) 173} 174 175// Swap exchanges the elements of two vectors: 176// x[i], y[i] = y[i], x[i] for all i. 177// Swap will panic if the lengths of x and y do not match. 178func Swap(x, y Vector) { 179 if x.N != y.N { 180 panic(badLength) 181 } 182 cblas128.Zswap(x.N, x.Data, x.Inc, y.Data, y.Inc) 183} 184 185// Copy copies the elements of x into the elements of y: 186// y[i] = x[i] for all i. 187// Copy will panic if the lengths of x and y do not match. 188func Copy(x, y Vector) { 189 if x.N != y.N { 190 panic(badLength) 191 } 192 cblas128.Zcopy(x.N, x.Data, x.Inc, y.Data, y.Inc) 193} 194 195// Axpy computes 196// y = alpha * x + y, 197// where x and y are vectors, and alpha is a scalar. 198// Axpy will panic if the lengths of x and y do not match. 199func Axpy(alpha complex128, x, y Vector) { 200 if x.N != y.N { 201 panic(badLength) 202 } 203 cblas128.Zaxpy(x.N, alpha, x.Data, x.Inc, y.Data, y.Inc) 204} 205 206// Scal computes 207// x = alpha * x, 208// where x is a vector, and alpha is a scalar. 209// 210// Scal will panic if the vector increment is negative. 211func Scal(alpha complex128, x Vector) { 212 if x.Inc < 0 { 213 panic(negInc) 214 } 215 cblas128.Zscal(x.N, alpha, x.Data, x.Inc) 216} 217 218// Dscal computes 219// x = alpha * x, 220// where x is a vector, and alpha is a real scalar. 221// 222// Dscal will panic if the vector increment is negative. 223func Dscal(alpha float64, x Vector) { 224 if x.Inc < 0 { 225 panic(negInc) 226 } 227 cblas128.Zdscal(x.N, alpha, x.Data, x.Inc) 228} 229 230// Level 2 231 232// Gemv computes 233// y = alpha * A * x + beta * y if t == blas.NoTrans, 234// y = alpha * Aᵀ * x + beta * y if t == blas.Trans, 235// y = alpha * Aᴴ * x + beta * y if t == blas.ConjTrans, 236// where A is an m×n dense matrix, x and y are vectors, and alpha and beta are 237// scalars. 238func Gemv(t blas.Transpose, alpha complex128, a General, x Vector, beta complex128, y Vector) { 239 cblas128.Zgemv(t, a.Rows, a.Cols, alpha, a.Data, a.Stride, x.Data, x.Inc, beta, y.Data, y.Inc) 240} 241 242// Gbmv computes 243// y = alpha * A * x + beta * y if t == blas.NoTrans, 244// y = alpha * Aᵀ * x + beta * y if t == blas.Trans, 245// y = alpha * Aᴴ * x + beta * y if t == blas.ConjTrans, 246// where A is an m×n band matrix, x and y are vectors, and alpha and beta are 247// scalars. 248func Gbmv(t blas.Transpose, alpha complex128, a Band, x Vector, beta complex128, y Vector) { 249 cblas128.Zgbmv(t, a.Rows, a.Cols, a.KL, a.KU, alpha, a.Data, a.Stride, x.Data, x.Inc, beta, y.Data, y.Inc) 250} 251 252// Trmv computes 253// x = A * x if t == blas.NoTrans, 254// x = Aᵀ * x if t == blas.Trans, 255// x = Aᴴ * x if t == blas.ConjTrans, 256// where A is an n×n triangular matrix, and x is a vector. 257func Trmv(t blas.Transpose, a Triangular, x Vector) { 258 cblas128.Ztrmv(a.Uplo, t, a.Diag, a.N, a.Data, a.Stride, x.Data, x.Inc) 259} 260 261// Tbmv computes 262// x = A * x if t == blas.NoTrans, 263// x = Aᵀ * x if t == blas.Trans, 264// x = Aᴴ * x if t == blas.ConjTrans, 265// where A is an n×n triangular band matrix, and x is a vector. 266func Tbmv(t blas.Transpose, a TriangularBand, x Vector) { 267 cblas128.Ztbmv(a.Uplo, t, a.Diag, a.N, a.K, a.Data, a.Stride, x.Data, x.Inc) 268} 269 270// Tpmv computes 271// x = A * x if t == blas.NoTrans, 272// x = Aᵀ * x if t == blas.Trans, 273// x = Aᴴ * x if t == blas.ConjTrans, 274// where A is an n×n triangular matrix in packed format, and x is a vector. 275func Tpmv(t blas.Transpose, a TriangularPacked, x Vector) { 276 cblas128.Ztpmv(a.Uplo, t, a.Diag, a.N, a.Data, x.Data, x.Inc) 277} 278 279// Trsv solves 280// A * x = b if t == blas.NoTrans, 281// Aᵀ * x = b if t == blas.Trans, 282// Aᴴ * x = b if t == blas.ConjTrans, 283// where A is an n×n triangular matrix and x is a vector. 284// 285// At entry to the function, x contains the values of b, and the result is 286// stored in-place into x. 287// 288// No test for singularity or near-singularity is included in this 289// routine. Such tests must be performed before calling this routine. 290func Trsv(t blas.Transpose, a Triangular, x Vector) { 291 cblas128.Ztrsv(a.Uplo, t, a.Diag, a.N, a.Data, a.Stride, x.Data, x.Inc) 292} 293 294// Tbsv solves 295// A * x = b if t == blas.NoTrans, 296// Aᵀ * x = b if t == blas.Trans, 297// Aᴴ * x = b if t == blas.ConjTrans, 298// where A is an n×n triangular band matrix, and x is a vector. 299// 300// At entry to the function, x contains the values of b, and the result is 301// stored in-place into x. 302// 303// No test for singularity or near-singularity is included in this 304// routine. Such tests must be performed before calling this routine. 305func Tbsv(t blas.Transpose, a TriangularBand, x Vector) { 306 cblas128.Ztbsv(a.Uplo, t, a.Diag, a.N, a.K, a.Data, a.Stride, x.Data, x.Inc) 307} 308 309// Tpsv solves 310// A * x = b if t == blas.NoTrans, 311// Aᵀ * x = b if t == blas.Trans, 312// Aᴴ * x = b if t == blas.ConjTrans, 313// where A is an n×n triangular matrix in packed format and x is a vector. 314// 315// At entry to the function, x contains the values of b, and the result is 316// stored in-place into x. 317// 318// No test for singularity or near-singularity is included in this 319// routine. Such tests must be performed before calling this routine. 320func Tpsv(t blas.Transpose, a TriangularPacked, x Vector) { 321 cblas128.Ztpsv(a.Uplo, t, a.Diag, a.N, a.Data, x.Data, x.Inc) 322} 323 324// Hemv computes 325// y = alpha * A * x + beta * y, 326// where A is an n×n Hermitian matrix, x and y are vectors, and alpha and 327// beta are scalars. 328func Hemv(alpha complex128, a Hermitian, x Vector, beta complex128, y Vector) { 329 cblas128.Zhemv(a.Uplo, a.N, alpha, a.Data, a.Stride, x.Data, x.Inc, beta, y.Data, y.Inc) 330} 331 332// Hbmv performs 333// y = alpha * A * x + beta * y, 334// where A is an n×n Hermitian band matrix, x and y are vectors, and alpha 335// and beta are scalars. 336func Hbmv(alpha complex128, a HermitianBand, x Vector, beta complex128, y Vector) { 337 cblas128.Zhbmv(a.Uplo, a.N, a.K, alpha, a.Data, a.Stride, x.Data, x.Inc, beta, y.Data, y.Inc) 338} 339 340// Hpmv performs 341// y = alpha * A * x + beta * y, 342// where A is an n×n Hermitian matrix in packed format, x and y are vectors, 343// and alpha and beta are scalars. 344func Hpmv(alpha complex128, a HermitianPacked, x Vector, beta complex128, y Vector) { 345 cblas128.Zhpmv(a.Uplo, a.N, alpha, a.Data, x.Data, x.Inc, beta, y.Data, y.Inc) 346} 347 348// Geru performs a rank-1 update 349// A += alpha * x * yᵀ, 350// where A is an m×n dense matrix, x and y are vectors, and alpha is a scalar. 351func Geru(alpha complex128, x, y Vector, a General) { 352 cblas128.Zgeru(a.Rows, a.Cols, alpha, x.Data, x.Inc, y.Data, y.Inc, a.Data, a.Stride) 353} 354 355// Gerc performs a rank-1 update 356// A += alpha * x * yᴴ, 357// where A is an m×n dense matrix, x and y are vectors, and alpha is a scalar. 358func Gerc(alpha complex128, x, y Vector, a General) { 359 cblas128.Zgerc(a.Rows, a.Cols, alpha, x.Data, x.Inc, y.Data, y.Inc, a.Data, a.Stride) 360} 361 362// Her performs a rank-1 update 363// A += alpha * x * yᵀ, 364// where A is an m×n Hermitian matrix, x and y are vectors, and alpha is a scalar. 365func Her(alpha float64, x Vector, a Hermitian) { 366 cblas128.Zher(a.Uplo, a.N, alpha, x.Data, x.Inc, a.Data, a.Stride) 367} 368 369// Hpr performs a rank-1 update 370// A += alpha * x * xᴴ, 371// where A is an n×n Hermitian matrix in packed format, x is a vector, and 372// alpha is a scalar. 373func Hpr(alpha float64, x Vector, a HermitianPacked) { 374 cblas128.Zhpr(a.Uplo, a.N, alpha, x.Data, x.Inc, a.Data) 375} 376 377// Her2 performs a rank-2 update 378// A += alpha * x * yᴴ + conj(alpha) * y * xᴴ, 379// where A is an n×n Hermitian matrix, x and y are vectors, and alpha is a scalar. 380func Her2(alpha complex128, x, y Vector, a Hermitian) { 381 cblas128.Zher2(a.Uplo, a.N, alpha, x.Data, x.Inc, y.Data, y.Inc, a.Data, a.Stride) 382} 383 384// Hpr2 performs a rank-2 update 385// A += alpha * x * yᴴ + conj(alpha) * y * xᴴ, 386// where A is an n×n Hermitian matrix in packed format, x and y are vectors, 387// and alpha is a scalar. 388func Hpr2(alpha complex128, x, y Vector, a HermitianPacked) { 389 cblas128.Zhpr2(a.Uplo, a.N, alpha, x.Data, x.Inc, y.Data, y.Inc, a.Data) 390} 391 392// Level 3 393 394// Gemm computes 395// C = alpha * A * B + beta * C, 396// where A, B, and C are dense matrices, and alpha and beta are scalars. 397// tA and tB specify whether A or B are transposed or conjugated. 398func Gemm(tA, tB blas.Transpose, alpha complex128, a, b General, beta complex128, c General) { 399 var m, n, k int 400 if tA == blas.NoTrans { 401 m, k = a.Rows, a.Cols 402 } else { 403 m, k = a.Cols, a.Rows 404 } 405 if tB == blas.NoTrans { 406 n = b.Cols 407 } else { 408 n = b.Rows 409 } 410 cblas128.Zgemm(tA, tB, m, n, k, alpha, a.Data, a.Stride, b.Data, b.Stride, beta, c.Data, c.Stride) 411} 412 413// Symm performs 414// C = alpha * A * B + beta * C if s == blas.Left, 415// C = alpha * B * A + beta * C if s == blas.Right, 416// where A is an n×n or m×m symmetric matrix, B and C are m×n matrices, and 417// alpha and beta are scalars. 418func Symm(s blas.Side, alpha complex128, a Symmetric, b General, beta complex128, c General) { 419 var m, n int 420 if s == blas.Left { 421 m, n = a.N, b.Cols 422 } else { 423 m, n = b.Rows, a.N 424 } 425 cblas128.Zsymm(s, a.Uplo, m, n, alpha, a.Data, a.Stride, b.Data, b.Stride, beta, c.Data, c.Stride) 426} 427 428// Syrk performs a symmetric rank-k update 429// C = alpha * A * Aᵀ + beta * C if t == blas.NoTrans, 430// C = alpha * Aᵀ * A + beta * C if t == blas.Trans, 431// where C is an n×n symmetric matrix, A is an n×k matrix if t == blas.NoTrans 432// and a k×n matrix otherwise, and alpha and beta are scalars. 433func Syrk(t blas.Transpose, alpha complex128, a General, beta complex128, c Symmetric) { 434 var n, k int 435 if t == blas.NoTrans { 436 n, k = a.Rows, a.Cols 437 } else { 438 n, k = a.Cols, a.Rows 439 } 440 cblas128.Zsyrk(c.Uplo, t, n, k, alpha, a.Data, a.Stride, beta, c.Data, c.Stride) 441} 442 443// Syr2k performs a symmetric rank-2k update 444// C = alpha * A * Bᵀ + alpha * B * Aᵀ + beta * C if t == blas.NoTrans, 445// C = alpha * Aᵀ * B + alpha * Bᵀ * A + beta * C if t == blas.Trans, 446// where C is an n×n symmetric matrix, A and B are n×k matrices if 447// t == blas.NoTrans and k×n otherwise, and alpha and beta are scalars. 448func Syr2k(t blas.Transpose, alpha complex128, a, b General, beta complex128, c Symmetric) { 449 var n, k int 450 if t == blas.NoTrans { 451 n, k = a.Rows, a.Cols 452 } else { 453 n, k = a.Cols, a.Rows 454 } 455 cblas128.Zsyr2k(c.Uplo, t, n, k, alpha, a.Data, a.Stride, b.Data, b.Stride, beta, c.Data, c.Stride) 456} 457 458// Trmm performs 459// B = alpha * A * B if tA == blas.NoTrans and s == blas.Left, 460// B = alpha * Aᵀ * B if tA == blas.Trans and s == blas.Left, 461// B = alpha * Aᴴ * B if tA == blas.ConjTrans and s == blas.Left, 462// B = alpha * B * A if tA == blas.NoTrans and s == blas.Right, 463// B = alpha * B * Aᵀ if tA == blas.Trans and s == blas.Right, 464// B = alpha * B * Aᴴ if tA == blas.ConjTrans and s == blas.Right, 465// where A is an n×n or m×m triangular matrix, B is an m×n matrix, and alpha is 466// a scalar. 467func Trmm(s blas.Side, tA blas.Transpose, alpha complex128, a Triangular, b General) { 468 cblas128.Ztrmm(s, a.Uplo, tA, a.Diag, b.Rows, b.Cols, alpha, a.Data, a.Stride, b.Data, b.Stride) 469} 470 471// Trsm solves 472// A * X = alpha * B if tA == blas.NoTrans and s == blas.Left, 473// Aᵀ * X = alpha * B if tA == blas.Trans and s == blas.Left, 474// Aᴴ * X = alpha * B if tA == blas.ConjTrans and s == blas.Left, 475// X * A = alpha * B if tA == blas.NoTrans and s == blas.Right, 476// X * Aᵀ = alpha * B if tA == blas.Trans and s == blas.Right, 477// X * Aᴴ = alpha * B if tA == blas.ConjTrans and s == blas.Right, 478// where A is an n×n or m×m triangular matrix, X and B are m×n matrices, and 479// alpha is a scalar. 480// 481// At entry to the function, b contains the values of B, and the result is 482// stored in-place into b. 483// 484// No check is made that A is invertible. 485func Trsm(s blas.Side, tA blas.Transpose, alpha complex128, a Triangular, b General) { 486 cblas128.Ztrsm(s, a.Uplo, tA, a.Diag, b.Rows, b.Cols, alpha, a.Data, a.Stride, b.Data, b.Stride) 487} 488 489// Hemm performs 490// C = alpha * A * B + beta * C if s == blas.Left, 491// C = alpha * B * A + beta * C if s == blas.Right, 492// where A is an n×n or m×m Hermitian matrix, B and C are m×n matrices, and 493// alpha and beta are scalars. 494func Hemm(s blas.Side, alpha complex128, a Hermitian, b General, beta complex128, c General) { 495 var m, n int 496 if s == blas.Left { 497 m, n = a.N, b.Cols 498 } else { 499 m, n = b.Rows, a.N 500 } 501 cblas128.Zhemm(s, a.Uplo, m, n, alpha, a.Data, a.Stride, b.Data, b.Stride, beta, c.Data, c.Stride) 502} 503 504// Herk performs the Hermitian rank-k update 505// C = alpha * A * Aᴴ + beta*C if t == blas.NoTrans, 506// C = alpha * Aᴴ * A + beta*C if t == blas.ConjTrans, 507// where C is an n×n Hermitian matrix, A is an n×k matrix if t == blas.NoTrans 508// and a k×n matrix otherwise, and alpha and beta are scalars. 509func Herk(t blas.Transpose, alpha float64, a General, beta float64, c Hermitian) { 510 var n, k int 511 if t == blas.NoTrans { 512 n, k = a.Rows, a.Cols 513 } else { 514 n, k = a.Cols, a.Rows 515 } 516 cblas128.Zherk(c.Uplo, t, n, k, alpha, a.Data, a.Stride, beta, c.Data, c.Stride) 517} 518 519// Her2k performs the Hermitian rank-2k update 520// C = alpha * A * Bᴴ + conj(alpha) * B * Aᴴ + beta * C if t == blas.NoTrans, 521// C = alpha * Aᴴ * B + conj(alpha) * Bᴴ * A + beta * C if t == blas.ConjTrans, 522// where C is an n×n Hermitian matrix, A and B are n×k matrices if t == NoTrans 523// and k×n matrices otherwise, and alpha and beta are scalars. 524func Her2k(t blas.Transpose, alpha complex128, a, b General, beta float64, c Hermitian) { 525 var n, k int 526 if t == blas.NoTrans { 527 n, k = a.Rows, a.Cols 528 } else { 529 n, k = a.Cols, a.Rows 530 } 531 cblas128.Zher2k(c.Uplo, t, n, k, alpha, a.Data, a.Stride, b.Data, b.Stride, beta, c.Data, c.Stride) 532} 533