1// Copyright ©2015 The Gonum Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5package blas64 6 7import ( 8 "gonum.org/v1/gonum/blas" 9 "gonum.org/v1/gonum/blas/gonum" 10) 11 12var blas64 blas.Float64 = gonum.Implementation{} 13 14// Use sets the BLAS float64 implementation to be used by subsequent BLAS calls. 15// The default implementation is 16// gonum.org/v1/gonum/blas/gonum.Implementation. 17func Use(b blas.Float64) { 18 blas64 = b 19} 20 21// Implementation returns the current BLAS float64 implementation. 22// 23// Implementation allows direct calls to the current the BLAS float64 implementation 24// giving finer control of parameters. 25func Implementation() blas.Float64 { 26 return blas64 27} 28 29// Vector represents a vector with an associated element increment. 30type Vector struct { 31 N int 32 Data []float64 33 Inc int 34} 35 36// General represents a matrix using the conventional storage scheme. 37type General struct { 38 Rows, Cols int 39 Data []float64 40 Stride int 41} 42 43// Band represents a band matrix using the band storage scheme. 44type Band struct { 45 Rows, Cols int 46 KL, KU int 47 Data []float64 48 Stride int 49} 50 51// Triangular represents a triangular matrix using the conventional storage scheme. 52type Triangular struct { 53 Uplo blas.Uplo 54 Diag blas.Diag 55 N int 56 Data []float64 57 Stride int 58} 59 60// TriangularBand represents a triangular matrix using the band storage scheme. 61type TriangularBand struct { 62 Uplo blas.Uplo 63 Diag blas.Diag 64 N, K int 65 Data []float64 66 Stride int 67} 68 69// TriangularPacked represents a triangular matrix using the packed storage scheme. 70type TriangularPacked struct { 71 Uplo blas.Uplo 72 Diag blas.Diag 73 N int 74 Data []float64 75} 76 77// Symmetric represents a symmetric matrix using the conventional storage scheme. 78type Symmetric struct { 79 Uplo blas.Uplo 80 N int 81 Data []float64 82 Stride int 83} 84 85// SymmetricBand represents a symmetric matrix using the band storage scheme. 86type SymmetricBand struct { 87 Uplo blas.Uplo 88 N, K int 89 Data []float64 90 Stride int 91} 92 93// SymmetricPacked represents a symmetric matrix using the packed storage scheme. 94type SymmetricPacked struct { 95 Uplo blas.Uplo 96 N int 97 Data []float64 98} 99 100// Level 1 101 102const ( 103 negInc = "blas64: negative vector increment" 104 badLength = "blas64: vector length mismatch" 105) 106 107// Dot computes the dot product of the two vectors: 108// \sum_i x[i]*y[i]. 109// Dot will panic if the lengths of x and y do not match. 110func Dot(x, y Vector) float64 { 111 if x.N != y.N { 112 panic(badLength) 113 } 114 return blas64.Ddot(x.N, x.Data, x.Inc, y.Data, y.Inc) 115} 116 117// Nrm2 computes the Euclidean norm of the vector x: 118// sqrt(\sum_i x[i]*x[i]). 119// 120// Nrm2 will panic if the vector increment is negative. 121func Nrm2(x Vector) float64 { 122 if x.Inc < 0 { 123 panic(negInc) 124 } 125 return blas64.Dnrm2(x.N, x.Data, x.Inc) 126} 127 128// Asum computes the sum of the absolute values of the elements of x: 129// \sum_i |x[i]|. 130// 131// Asum will panic if the vector increment is negative. 132func Asum(x Vector) float64 { 133 if x.Inc < 0 { 134 panic(negInc) 135 } 136 return blas64.Dasum(x.N, x.Data, x.Inc) 137} 138 139// Iamax returns the index of an element of x with the largest absolute value. 140// If there are multiple such indices the earliest is returned. 141// Iamax returns -1 if n == 0. 142// 143// Iamax will panic if the vector increment is negative. 144func Iamax(x Vector) int { 145 if x.Inc < 0 { 146 panic(negInc) 147 } 148 return blas64.Idamax(x.N, x.Data, x.Inc) 149} 150 151// Swap exchanges the elements of the two vectors: 152// x[i], y[i] = y[i], x[i] for all i. 153// Swap will panic if the lengths of x and y do not match. 154func Swap(x, y Vector) { 155 if x.N != y.N { 156 panic(badLength) 157 } 158 blas64.Dswap(x.N, x.Data, x.Inc, y.Data, y.Inc) 159} 160 161// Copy copies the elements of x into the elements of y: 162// y[i] = x[i] for all i. 163// Copy will panic if the lengths of x and y do not match. 164func Copy(x, y Vector) { 165 if x.N != y.N { 166 panic(badLength) 167 } 168 blas64.Dcopy(x.N, x.Data, x.Inc, y.Data, y.Inc) 169} 170 171// Axpy adds x scaled by alpha to y: 172// y[i] += alpha*x[i] for all i. 173// Axpy will panic if the lengths of x and y do not match. 174func Axpy(alpha float64, x, y Vector) { 175 if x.N != y.N { 176 panic(badLength) 177 } 178 blas64.Daxpy(x.N, alpha, x.Data, x.Inc, y.Data, y.Inc) 179} 180 181// Rotg computes the parameters of a Givens plane rotation so that 182// ⎡ c s⎤ ⎡a⎤ ⎡r⎤ 183// ⎣-s c⎦ * ⎣b⎦ = ⎣0⎦ 184// where a and b are the Cartesian coordinates of a given point. 185// c, s, and r are defined as 186// r = ±Sqrt(a^2 + b^2), 187// c = a/r, the cosine of the rotation angle, 188// s = a/r, the sine of the rotation angle, 189// and z is defined such that 190// if |a| > |b|, z = s, 191// otherwise if c != 0, z = 1/c, 192// otherwise z = 1. 193func Rotg(a, b float64) (c, s, r, z float64) { 194 return blas64.Drotg(a, b) 195} 196 197// Rotmg computes the modified Givens rotation. See 198// http://www.netlib.org/lapack/explore-html/df/deb/drotmg_8f.html 199// for more details. 200func Rotmg(d1, d2, b1, b2 float64) (p blas.DrotmParams, rd1, rd2, rb1 float64) { 201 return blas64.Drotmg(d1, d2, b1, b2) 202} 203 204// Rot applies a plane transformation to n points represented by the vectors x 205// and y: 206// x[i] = c*x[i] + s*y[i], 207// y[i] = -s*x[i] + c*y[i], for all i. 208func Rot(x, y Vector, c, s float64) { 209 if x.N != y.N { 210 panic(badLength) 211 } 212 blas64.Drot(x.N, x.Data, x.Inc, y.Data, y.Inc, c, s) 213} 214 215// Rotm applies the modified Givens rotation to n points represented by the 216// vectors x and y. 217func Rotm(x, y Vector, p blas.DrotmParams) { 218 if x.N != y.N { 219 panic(badLength) 220 } 221 blas64.Drotm(x.N, x.Data, x.Inc, y.Data, y.Inc, p) 222} 223 224// Scal scales the vector x by alpha: 225// x[i] *= alpha for all i. 226// 227// Scal will panic if the vector increment is negative. 228func Scal(alpha float64, x Vector) { 229 if x.Inc < 0 { 230 panic(negInc) 231 } 232 blas64.Dscal(x.N, alpha, x.Data, x.Inc) 233} 234 235// Level 2 236 237// Gemv computes 238// y = alpha * A * x + beta * y if t == blas.NoTrans, 239// y = alpha * Aᵀ * x + beta * y if t == blas.Trans or blas.ConjTrans, 240// where A is an m×n dense matrix, x and y are vectors, and alpha and beta are scalars. 241func Gemv(t blas.Transpose, alpha float64, a General, x Vector, beta float64, y Vector) { 242 blas64.Dgemv(t, a.Rows, a.Cols, alpha, a.Data, a.Stride, x.Data, x.Inc, beta, y.Data, y.Inc) 243} 244 245// Gbmv computes 246// y = alpha * A * x + beta * y if t == blas.NoTrans, 247// y = alpha * Aᵀ * x + beta * y if t == blas.Trans or blas.ConjTrans, 248// where A is an m×n band matrix, x and y are vectors, and alpha and beta are scalars. 249func Gbmv(t blas.Transpose, alpha float64, a Band, x Vector, beta float64, y Vector) { 250 blas64.Dgbmv(t, a.Rows, a.Cols, a.KL, a.KU, alpha, a.Data, a.Stride, x.Data, x.Inc, beta, y.Data, y.Inc) 251} 252 253// Trmv computes 254// x = A * x if t == blas.NoTrans, 255// x = Aᵀ * x if t == blas.Trans or blas.ConjTrans, 256// where A is an n×n triangular matrix, and x is a vector. 257func Trmv(t blas.Transpose, a Triangular, x Vector) { 258 blas64.Dtrmv(a.Uplo, t, a.Diag, a.N, a.Data, a.Stride, x.Data, x.Inc) 259} 260 261// Tbmv computes 262// x = A * x if t == blas.NoTrans, 263// x = Aᵀ * x if t == blas.Trans or blas.ConjTrans, 264// where A is an n×n triangular band matrix, and x is a vector. 265func Tbmv(t blas.Transpose, a TriangularBand, x Vector) { 266 blas64.Dtbmv(a.Uplo, t, a.Diag, a.N, a.K, a.Data, a.Stride, x.Data, x.Inc) 267} 268 269// Tpmv computes 270// x = A * x if t == blas.NoTrans, 271// x = Aᵀ * x if t == blas.Trans or blas.ConjTrans, 272// where A is an n×n triangular matrix in packed format, and x is a vector. 273func Tpmv(t blas.Transpose, a TriangularPacked, x Vector) { 274 blas64.Dtpmv(a.Uplo, t, a.Diag, a.N, a.Data, x.Data, x.Inc) 275} 276 277// Trsv solves 278// A * x = b if t == blas.NoTrans, 279// Aᵀ * x = b if t == blas.Trans or blas.ConjTrans, 280// where A is an n×n triangular matrix, and x and b are vectors. 281// 282// At entry to the function, x contains the values of b, and the result is 283// stored in-place into x. 284// 285// No test for singularity or near-singularity is included in this 286// routine. Such tests must be performed before calling this routine. 287func Trsv(t blas.Transpose, a Triangular, x Vector) { 288 blas64.Dtrsv(a.Uplo, t, a.Diag, a.N, a.Data, a.Stride, x.Data, x.Inc) 289} 290 291// Tbsv solves 292// A * x = b if t == blas.NoTrans, 293// Aᵀ * x = b if t == blas.Trans or blas.ConjTrans, 294// where A is an n×n triangular band matrix, and x and b are vectors. 295// 296// At entry to the function, x contains the values of b, and the result is 297// stored in place into x. 298// 299// No test for singularity or near-singularity is included in this 300// routine. Such tests must be performed before calling this routine. 301func Tbsv(t blas.Transpose, a TriangularBand, x Vector) { 302 blas64.Dtbsv(a.Uplo, t, a.Diag, a.N, a.K, a.Data, a.Stride, x.Data, x.Inc) 303} 304 305// Tpsv solves 306// A * x = b if t == blas.NoTrans, 307// Aᵀ * x = b if t == blas.Trans or blas.ConjTrans, 308// where A is an n×n triangular matrix in packed format, and x and b are 309// vectors. 310// 311// At entry to the function, x contains the values of b, and the result is 312// stored in place into x. 313// 314// No test for singularity or near-singularity is included in this 315// routine. Such tests must be performed before calling this routine. 316func Tpsv(t blas.Transpose, a TriangularPacked, x Vector) { 317 blas64.Dtpsv(a.Uplo, t, a.Diag, a.N, a.Data, x.Data, x.Inc) 318} 319 320// Symv computes 321// y = alpha * A * x + beta * y, 322// where A is an n×n symmetric matrix, x and y are vectors, and alpha and 323// beta are scalars. 324func Symv(alpha float64, a Symmetric, x Vector, beta float64, y Vector) { 325 blas64.Dsymv(a.Uplo, a.N, alpha, a.Data, a.Stride, x.Data, x.Inc, beta, y.Data, y.Inc) 326} 327 328// Sbmv performs 329// y = alpha * A * x + beta * y, 330// where A is an n×n symmetric band matrix, x and y are vectors, and alpha 331// and beta are scalars. 332func Sbmv(alpha float64, a SymmetricBand, x Vector, beta float64, y Vector) { 333 blas64.Dsbmv(a.Uplo, a.N, a.K, alpha, a.Data, a.Stride, x.Data, x.Inc, beta, y.Data, y.Inc) 334} 335 336// Spmv performs 337// y = alpha * A * x + beta * y, 338// where A is an n×n symmetric matrix in packed format, x and y are vectors, 339// and alpha and beta are scalars. 340func Spmv(alpha float64, a SymmetricPacked, x Vector, beta float64, y Vector) { 341 blas64.Dspmv(a.Uplo, a.N, alpha, a.Data, x.Data, x.Inc, beta, y.Data, y.Inc) 342} 343 344// Ger performs a rank-1 update 345// A += alpha * x * yᵀ, 346// where A is an m×n dense matrix, x and y are vectors, and alpha is a scalar. 347func Ger(alpha float64, x, y Vector, a General) { 348 blas64.Dger(a.Rows, a.Cols, alpha, x.Data, x.Inc, y.Data, y.Inc, a.Data, a.Stride) 349} 350 351// Syr performs a rank-1 update 352// A += alpha * x * xᵀ, 353// where A is an n×n symmetric matrix, x is a vector, and alpha is a scalar. 354func Syr(alpha float64, x Vector, a Symmetric) { 355 blas64.Dsyr(a.Uplo, a.N, alpha, x.Data, x.Inc, a.Data, a.Stride) 356} 357 358// Spr performs the rank-1 update 359// A += alpha * x * xᵀ, 360// where A is an n×n symmetric matrix in packed format, x is a vector, and 361// alpha is a scalar. 362func Spr(alpha float64, x Vector, a SymmetricPacked) { 363 blas64.Dspr(a.Uplo, a.N, alpha, x.Data, x.Inc, a.Data) 364} 365 366// Syr2 performs a rank-2 update 367// A += alpha * x * yᵀ + alpha * y * xᵀ, 368// where A is a symmetric n×n matrix, x and y are vectors, and alpha is a scalar. 369func Syr2(alpha float64, x, y Vector, a Symmetric) { 370 blas64.Dsyr2(a.Uplo, a.N, alpha, x.Data, x.Inc, y.Data, y.Inc, a.Data, a.Stride) 371} 372 373// Spr2 performs a rank-2 update 374// A += alpha * x * yᵀ + alpha * y * xᵀ, 375// where A is an n×n symmetric matrix in packed format, x and y are vectors, 376// and alpha is a scalar. 377func Spr2(alpha float64, x, y Vector, a SymmetricPacked) { 378 blas64.Dspr2(a.Uplo, a.N, alpha, x.Data, x.Inc, y.Data, y.Inc, a.Data) 379} 380 381// Level 3 382 383// Gemm computes 384// C = alpha * A * B + beta * C, 385// where A, B, and C are dense matrices, and alpha and beta are scalars. 386// tA and tB specify whether A or B are transposed. 387func Gemm(tA, tB blas.Transpose, alpha float64, a, b General, beta float64, c General) { 388 var m, n, k int 389 if tA == blas.NoTrans { 390 m, k = a.Rows, a.Cols 391 } else { 392 m, k = a.Cols, a.Rows 393 } 394 if tB == blas.NoTrans { 395 n = b.Cols 396 } else { 397 n = b.Rows 398 } 399 blas64.Dgemm(tA, tB, m, n, k, alpha, a.Data, a.Stride, b.Data, b.Stride, beta, c.Data, c.Stride) 400} 401 402// Symm performs 403// C = alpha * A * B + beta * C if s == blas.Left, 404// C = alpha * B * A + beta * C if s == blas.Right, 405// where A is an n×n or m×m symmetric matrix, B and C are m×n matrices, and 406// alpha is a scalar. 407func Symm(s blas.Side, alpha float64, a Symmetric, b General, beta float64, c General) { 408 var m, n int 409 if s == blas.Left { 410 m, n = a.N, b.Cols 411 } else { 412 m, n = b.Rows, a.N 413 } 414 blas64.Dsymm(s, a.Uplo, m, n, alpha, a.Data, a.Stride, b.Data, b.Stride, beta, c.Data, c.Stride) 415} 416 417// Syrk performs a symmetric rank-k update 418// C = alpha * A * Aᵀ + beta * C if t == blas.NoTrans, 419// C = alpha * Aᵀ * A + beta * C if t == blas.Trans or blas.ConjTrans, 420// where C is an n×n symmetric matrix, A is an n×k matrix if t == blas.NoTrans and 421// a k×n matrix otherwise, and alpha and beta are scalars. 422func Syrk(t blas.Transpose, alpha float64, a General, beta float64, c Symmetric) { 423 var n, k int 424 if t == blas.NoTrans { 425 n, k = a.Rows, a.Cols 426 } else { 427 n, k = a.Cols, a.Rows 428 } 429 blas64.Dsyrk(c.Uplo, t, n, k, alpha, a.Data, a.Stride, beta, c.Data, c.Stride) 430} 431 432// Syr2k performs a symmetric rank-2k update 433// C = alpha * A * Bᵀ + alpha * B * Aᵀ + beta * C if t == blas.NoTrans, 434// C = alpha * Aᵀ * B + alpha * Bᵀ * A + beta * C if t == blas.Trans or blas.ConjTrans, 435// where C is an n×n symmetric matrix, A and B are n×k matrices if t == NoTrans 436// and k×n matrices otherwise, and alpha and beta are scalars. 437func Syr2k(t blas.Transpose, alpha float64, a, b General, beta float64, c Symmetric) { 438 var n, k int 439 if t == blas.NoTrans { 440 n, k = a.Rows, a.Cols 441 } else { 442 n, k = a.Cols, a.Rows 443 } 444 blas64.Dsyr2k(c.Uplo, t, n, k, alpha, a.Data, a.Stride, b.Data, b.Stride, beta, c.Data, c.Stride) 445} 446 447// Trmm performs 448// B = alpha * A * B if tA == blas.NoTrans and s == blas.Left, 449// B = alpha * Aᵀ * B if tA == blas.Trans or blas.ConjTrans, and s == blas.Left, 450// B = alpha * B * A if tA == blas.NoTrans and s == blas.Right, 451// B = alpha * B * Aᵀ if tA == blas.Trans or blas.ConjTrans, and s == blas.Right, 452// where A is an n×n or m×m triangular matrix, B is an m×n matrix, and alpha is 453// a scalar. 454func Trmm(s blas.Side, tA blas.Transpose, alpha float64, a Triangular, b General) { 455 blas64.Dtrmm(s, a.Uplo, tA, a.Diag, b.Rows, b.Cols, alpha, a.Data, a.Stride, b.Data, b.Stride) 456} 457 458// Trsm solves 459// A * X = alpha * B if tA == blas.NoTrans and s == blas.Left, 460// Aᵀ * X = alpha * B if tA == blas.Trans or blas.ConjTrans, and s == blas.Left, 461// X * A = alpha * B if tA == blas.NoTrans and s == blas.Right, 462// X * Aᵀ = alpha * B if tA == blas.Trans or blas.ConjTrans, and s == blas.Right, 463// where A is an n×n or m×m triangular matrix, X and B are m×n matrices, and 464// alpha is a scalar. 465// 466// At entry to the function, X contains the values of B, and the result is 467// stored in-place into X. 468// 469// No check is made that A is invertible. 470func Trsm(s blas.Side, tA blas.Transpose, alpha float64, a Triangular, b General) { 471 blas64.Dtrsm(s, a.Uplo, tA, a.Diag, b.Rows, b.Cols, alpha, a.Data, a.Stride, b.Data, b.Stride) 472} 473