1// Copyright ©2015 The Gonum Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5package blas32 6 7import ( 8 "gonum.org/v1/gonum/blas" 9 "gonum.org/v1/gonum/blas/gonum" 10) 11 12var blas32 blas.Float32 = gonum.Implementation{} 13 14// Use sets the BLAS float32 implementation to be used by subsequent BLAS calls. 15// The default implementation is 16// gonum.org/v1/gonum/blas/gonum.Implementation. 17func Use(b blas.Float32) { 18 blas32 = b 19} 20 21// Implementation returns the current BLAS float32 implementation. 22// 23// Implementation allows direct calls to the current the BLAS float32 implementation 24// giving finer control of parameters. 25func Implementation() blas.Float32 { 26 return blas32 27} 28 29// Vector represents a vector with an associated element increment. 30type Vector struct { 31 Inc int 32 Data []float32 33} 34 35// General represents a matrix using the conventional storage scheme. 36type General struct { 37 Rows, Cols int 38 Stride int 39 Data []float32 40} 41 42// Band represents a band matrix using the band storage scheme. 43type Band struct { 44 Rows, Cols int 45 KL, KU int 46 Stride int 47 Data []float32 48} 49 50// Triangular represents a triangular matrix using the conventional storage scheme. 51type Triangular struct { 52 N int 53 Stride int 54 Data []float32 55 Uplo blas.Uplo 56 Diag blas.Diag 57} 58 59// TriangularBand represents a triangular matrix using the band storage scheme. 60type TriangularBand struct { 61 N, K int 62 Stride int 63 Data []float32 64 Uplo blas.Uplo 65 Diag blas.Diag 66} 67 68// TriangularPacked represents a triangular matrix using the packed storage scheme. 69type TriangularPacked struct { 70 N int 71 Data []float32 72 Uplo blas.Uplo 73 Diag blas.Diag 74} 75 76// Symmetric represents a symmetric matrix using the conventional storage scheme. 77type Symmetric struct { 78 N int 79 Stride int 80 Data []float32 81 Uplo blas.Uplo 82} 83 84// SymmetricBand represents a symmetric matrix using the band storage scheme. 85type SymmetricBand struct { 86 N, K int 87 Stride int 88 Data []float32 89 Uplo blas.Uplo 90} 91 92// SymmetricPacked represents a symmetric matrix using the packed storage scheme. 93type SymmetricPacked struct { 94 N int 95 Data []float32 96 Uplo blas.Uplo 97} 98 99// Level 1 100 101const negInc = "blas32: negative vector increment" 102 103// Dot computes the dot product of the two vectors: 104// \sum_i x[i]*y[i]. 105func Dot(n int, x, y Vector) float32 { 106 return blas32.Sdot(n, x.Data, x.Inc, y.Data, y.Inc) 107} 108 109// DDot computes the dot product of the two vectors: 110// \sum_i x[i]*y[i]. 111func DDot(n int, x, y Vector) float64 { 112 return blas32.Dsdot(n, x.Data, x.Inc, y.Data, y.Inc) 113} 114 115// SDDot computes the dot product of the two vectors adding a constant: 116// alpha + \sum_i x[i]*y[i]. 117func SDDot(n int, alpha float32, x, y Vector) float32 { 118 return blas32.Sdsdot(n, alpha, x.Data, x.Inc, y.Data, y.Inc) 119} 120 121// Nrm2 computes the Euclidean norm of the vector x: 122// sqrt(\sum_i x[i]*x[i]). 123// 124// Nrm2 will panic if the vector increment is negative. 125func Nrm2(n int, x Vector) float32 { 126 if x.Inc < 0 { 127 panic(negInc) 128 } 129 return blas32.Snrm2(n, x.Data, x.Inc) 130} 131 132// Asum computes the sum of the absolute values of the elements of x: 133// \sum_i |x[i]|. 134// 135// Asum will panic if the vector increment is negative. 136func Asum(n int, x Vector) float32 { 137 if x.Inc < 0 { 138 panic(negInc) 139 } 140 return blas32.Sasum(n, x.Data, x.Inc) 141} 142 143// Iamax returns the index of an element of x with the largest absolute value. 144// If there are multiple such indices the earliest is returned. 145// Iamax returns -1 if n == 0. 146// 147// Iamax will panic if the vector increment is negative. 148func Iamax(n int, x Vector) int { 149 if x.Inc < 0 { 150 panic(negInc) 151 } 152 return blas32.Isamax(n, x.Data, x.Inc) 153} 154 155// Swap exchanges the elements of the two vectors: 156// x[i], y[i] = y[i], x[i] for all i. 157func Swap(n int, x, y Vector) { 158 blas32.Sswap(n, x.Data, x.Inc, y.Data, y.Inc) 159} 160 161// Copy copies the elements of x into the elements of y: 162// y[i] = x[i] for all i. 163func Copy(n int, x, y Vector) { 164 blas32.Scopy(n, x.Data, x.Inc, y.Data, y.Inc) 165} 166 167// Axpy adds x scaled by alpha to y: 168// y[i] += alpha*x[i] for all i. 169func Axpy(n int, alpha float32, x, y Vector) { 170 blas32.Saxpy(n, alpha, x.Data, x.Inc, y.Data, y.Inc) 171} 172 173// Rotg computes the parameters of a Givens plane rotation so that 174// ⎡ c s⎤ ⎡a⎤ ⎡r⎤ 175// ⎣-s c⎦ * ⎣b⎦ = ⎣0⎦ 176// where a and b are the Cartesian coordinates of a given point. 177// c, s, and r are defined as 178// r = ±Sqrt(a^2 + b^2), 179// c = a/r, the cosine of the rotation angle, 180// s = a/r, the sine of the rotation angle, 181// and z is defined such that 182// if |a| > |b|, z = s, 183// otherwise if c != 0, z = 1/c, 184// otherwise z = 1. 185func Rotg(a, b float32) (c, s, r, z float32) { 186 return blas32.Srotg(a, b) 187} 188 189// Rotmg computes the modified Givens rotation. See 190// http://www.netlib.org/lapack/explore-html/df/deb/drotmg_8f.html 191// for more details. 192func Rotmg(d1, d2, b1, b2 float32) (p blas.SrotmParams, rd1, rd2, rb1 float32) { 193 return blas32.Srotmg(d1, d2, b1, b2) 194} 195 196// Rot applies a plane transformation to n points represented by the vectors x 197// and y: 198// x[i] = c*x[i] + s*y[i], 199// y[i] = -s*x[i] + c*y[i], for all i. 200func Rot(n int, x, y Vector, c, s float32) { 201 blas32.Srot(n, x.Data, x.Inc, y.Data, y.Inc, c, s) 202} 203 204// Rotm applies the modified Givens rotation to n points represented by the 205// vectors x and y. 206func Rotm(n int, x, y Vector, p blas.SrotmParams) { 207 blas32.Srotm(n, x.Data, x.Inc, y.Data, y.Inc, p) 208} 209 210// Scal scales the vector x by alpha: 211// x[i] *= alpha for all i. 212// 213// Scal will panic if the vector increment is negative. 214func Scal(n int, alpha float32, x Vector) { 215 if x.Inc < 0 { 216 panic(negInc) 217 } 218 blas32.Sscal(n, alpha, x.Data, x.Inc) 219} 220 221// Level 2 222 223// Gemv computes 224// y = alpha * A * x + beta * y, if t == blas.NoTrans, 225// y = alpha * A^T * x + beta * y, if t == blas.Trans or blas.ConjTrans, 226// where A is an m×n dense matrix, x and y are vectors, and alpha and beta are scalars. 227func Gemv(t blas.Transpose, alpha float32, a General, x Vector, beta float32, y Vector) { 228 blas32.Sgemv(t, a.Rows, a.Cols, alpha, a.Data, a.Stride, x.Data, x.Inc, beta, y.Data, y.Inc) 229} 230 231// Gbmv computes 232// y = alpha * A * x + beta * y, if t == blas.NoTrans, 233// y = alpha * A^T * x + beta * y, if t == blas.Trans or blas.ConjTrans, 234// where A is an m×n band matrix, x and y are vectors, and alpha and beta are scalars. 235func Gbmv(t blas.Transpose, alpha float32, a Band, x Vector, beta float32, y Vector) { 236 blas32.Sgbmv(t, a.Rows, a.Cols, a.KL, a.KU, alpha, a.Data, a.Stride, x.Data, x.Inc, beta, y.Data, y.Inc) 237} 238 239// Trmv computes 240// x = A * x, if t == blas.NoTrans, 241// x = A^T * x, if t == blas.Trans or blas.ConjTrans, 242// where A is an n×n triangular matrix, and x is a vector. 243func Trmv(t blas.Transpose, a Triangular, x Vector) { 244 blas32.Strmv(a.Uplo, t, a.Diag, a.N, a.Data, a.Stride, x.Data, x.Inc) 245} 246 247// Tbmv computes 248// x = A * x, if t == blas.NoTrans, 249// x = A^T * x, if t == blas.Trans or blas.ConjTrans, 250// where A is an n×n triangular band matrix, and x is a vector. 251func Tbmv(t blas.Transpose, a TriangularBand, x Vector) { 252 blas32.Stbmv(a.Uplo, t, a.Diag, a.N, a.K, a.Data, a.Stride, x.Data, x.Inc) 253} 254 255// Tpmv computes 256// x = A * x, if t == blas.NoTrans, 257// x = A^T * x, if t == blas.Trans or blas.ConjTrans, 258// where A is an n×n triangular matrix in packed format, and x is a vector. 259func Tpmv(t blas.Transpose, a TriangularPacked, x Vector) { 260 blas32.Stpmv(a.Uplo, t, a.Diag, a.N, a.Data, x.Data, x.Inc) 261} 262 263// Trsv solves 264// A * x = b, if t == blas.NoTrans, 265// A^T * x = b, if t == blas.Trans or blas.ConjTrans, 266// where A is an n×n triangular matrix, and x and b are vectors. 267// 268// At entry to the function, x contains the values of b, and the result is 269// stored in-place into x. 270// 271// No test for singularity or near-singularity is included in this 272// routine. Such tests must be performed before calling this routine. 273func Trsv(t blas.Transpose, a Triangular, x Vector) { 274 blas32.Strsv(a.Uplo, t, a.Diag, a.N, a.Data, a.Stride, x.Data, x.Inc) 275} 276 277// Tbsv solves 278// A * x = b, if t == blas.NoTrans, 279// A^T * x = b, if t == blas.Trans or blas.ConjTrans, 280// where A is an n×n triangular band matrix, and x and b are vectors. 281// 282// At entry to the function, x contains the values of b, and the result is 283// stored in place into x. 284// 285// No test for singularity or near-singularity is included in this 286// routine. Such tests must be performed before calling this routine. 287func Tbsv(t blas.Transpose, a TriangularBand, x Vector) { 288 blas32.Stbsv(a.Uplo, t, a.Diag, a.N, a.K, a.Data, a.Stride, x.Data, x.Inc) 289} 290 291// Tpsv solves 292// A * x = b, if t == blas.NoTrans, 293// A^T * x = b, if t == blas.Trans or blas.ConjTrans, 294// where A is an n×n triangular matrix in packed format, and x and b are 295// vectors. 296// 297// At entry to the function, x contains the values of b, and the result is 298// stored in place into x. 299// 300// No test for singularity or near-singularity is included in this 301// routine. Such tests must be performed before calling this routine. 302func Tpsv(t blas.Transpose, a TriangularPacked, x Vector) { 303 blas32.Stpsv(a.Uplo, t, a.Diag, a.N, a.Data, x.Data, x.Inc) 304} 305 306// Symv computes 307// y = alpha * A * x + beta * y, 308// where A is an n×n symmetric matrix, x and y are vectors, and alpha and 309// beta are scalars. 310func Symv(alpha float32, a Symmetric, x Vector, beta float32, y Vector) { 311 blas32.Ssymv(a.Uplo, a.N, alpha, a.Data, a.Stride, x.Data, x.Inc, beta, y.Data, y.Inc) 312} 313 314// Sbmv performs 315// y = alpha * A * x + beta * y, 316// where A is an n×n symmetric band matrix, x and y are vectors, and alpha 317// and beta are scalars. 318func Sbmv(alpha float32, a SymmetricBand, x Vector, beta float32, y Vector) { 319 blas32.Ssbmv(a.Uplo, a.N, a.K, alpha, a.Data, a.Stride, x.Data, x.Inc, beta, y.Data, y.Inc) 320} 321 322// Spmv performs 323// y = alpha * A * x + beta * y, 324// where A is an n×n symmetric matrix in packed format, x and y are vectors, 325// and alpha and beta are scalars. 326func Spmv(alpha float32, a SymmetricPacked, x Vector, beta float32, y Vector) { 327 blas32.Sspmv(a.Uplo, a.N, alpha, a.Data, x.Data, x.Inc, beta, y.Data, y.Inc) 328} 329 330// Ger performs a rank-1 update 331// A += alpha * x * y^T, 332// where A is an m×n dense matrix, x and y are vectors, and alpha is a scalar. 333func Ger(alpha float32, x, y Vector, a General) { 334 blas32.Sger(a.Rows, a.Cols, alpha, x.Data, x.Inc, y.Data, y.Inc, a.Data, a.Stride) 335} 336 337// Syr performs a rank-1 update 338// A += alpha * x * x^T, 339// where A is an n×n symmetric matrix, x is a vector, and alpha is a scalar. 340func Syr(alpha float32, x Vector, a Symmetric) { 341 blas32.Ssyr(a.Uplo, a.N, alpha, x.Data, x.Inc, a.Data, a.Stride) 342} 343 344// Spr performs the rank-1 update 345// A += alpha * x * x^T, 346// where A is an n×n symmetric matrix in packed format, x is a vector, and 347// alpha is a scalar. 348func Spr(alpha float32, x Vector, a SymmetricPacked) { 349 blas32.Sspr(a.Uplo, a.N, alpha, x.Data, x.Inc, a.Data) 350} 351 352// Syr2 performs a rank-2 update 353// A += alpha * x * y^T + alpha * y * x^T, 354// where A is a symmetric n×n matrix, x and y are vectors, and alpha is a scalar. 355func Syr2(alpha float32, x, y Vector, a Symmetric) { 356 blas32.Ssyr2(a.Uplo, a.N, alpha, x.Data, x.Inc, y.Data, y.Inc, a.Data, a.Stride) 357} 358 359// Spr2 performs a rank-2 update 360// A += alpha * x * y^T + alpha * y * x^T, 361// where A is an n×n symmetric matrix in packed format, x and y are vectors, 362// and alpha is a scalar. 363func Spr2(alpha float32, x, y Vector, a SymmetricPacked) { 364 blas32.Sspr2(a.Uplo, a.N, alpha, x.Data, x.Inc, y.Data, y.Inc, a.Data) 365} 366 367// Level 3 368 369// Gemm computes 370// C = alpha * A * B + beta * C, 371// where A, B, and C are dense matrices, and alpha and beta are scalars. 372// tA and tB specify whether A or B are transposed. 373func Gemm(tA, tB blas.Transpose, alpha float32, a, b General, beta float32, c General) { 374 var m, n, k int 375 if tA == blas.NoTrans { 376 m, k = a.Rows, a.Cols 377 } else { 378 m, k = a.Cols, a.Rows 379 } 380 if tB == blas.NoTrans { 381 n = b.Cols 382 } else { 383 n = b.Rows 384 } 385 blas32.Sgemm(tA, tB, m, n, k, alpha, a.Data, a.Stride, b.Data, b.Stride, beta, c.Data, c.Stride) 386} 387 388// Symm performs 389// C = alpha * A * B + beta * C, if s == blas.Left, 390// C = alpha * B * A + beta * C, if s == blas.Right, 391// where A is an n×n or m×m symmetric matrix, B and C are m×n matrices, and 392// alpha is a scalar. 393func Symm(s blas.Side, alpha float32, a Symmetric, b General, beta float32, c General) { 394 var m, n int 395 if s == blas.Left { 396 m, n = a.N, b.Cols 397 } else { 398 m, n = b.Rows, a.N 399 } 400 blas32.Ssymm(s, a.Uplo, m, n, alpha, a.Data, a.Stride, b.Data, b.Stride, beta, c.Data, c.Stride) 401} 402 403// Syrk performs a symmetric rank-k update 404// C = alpha * A * A^T + beta * C, if t == blas.NoTrans, 405// C = alpha * A^T * A + beta * C, if t == blas.Trans or blas.ConjTrans, 406// where C is an n×n symmetric matrix, A is an n×k matrix if t == blas.NoTrans and 407// a k×n matrix otherwise, and alpha and beta are scalars. 408func Syrk(t blas.Transpose, alpha float32, a General, beta float32, c Symmetric) { 409 var n, k int 410 if t == blas.NoTrans { 411 n, k = a.Rows, a.Cols 412 } else { 413 n, k = a.Cols, a.Rows 414 } 415 blas32.Ssyrk(c.Uplo, t, n, k, alpha, a.Data, a.Stride, beta, c.Data, c.Stride) 416} 417 418// Syr2k performs a symmetric rank-2k update 419// C = alpha * A * B^T + alpha * B * A^T + beta * C, if t == blas.NoTrans, 420// C = alpha * A^T * B + alpha * B^T * A + beta * C, if t == blas.Trans or blas.ConjTrans, 421// where C is an n×n symmetric matrix, A and B are n×k matrices if t == NoTrans 422// and k×n matrices otherwise, and alpha and beta are scalars. 423func Syr2k(t blas.Transpose, alpha float32, a, b General, beta float32, c Symmetric) { 424 var n, k int 425 if t == blas.NoTrans { 426 n, k = a.Rows, a.Cols 427 } else { 428 n, k = a.Cols, a.Rows 429 } 430 blas32.Ssyr2k(c.Uplo, t, n, k, alpha, a.Data, a.Stride, b.Data, b.Stride, beta, c.Data, c.Stride) 431} 432 433// Trmm performs 434// B = alpha * A * B, if tA == blas.NoTrans and s == blas.Left, 435// B = alpha * A^T * B, if tA == blas.Trans or blas.ConjTrans, and s == blas.Left, 436// B = alpha * B * A, if tA == blas.NoTrans and s == blas.Right, 437// B = alpha * B * A^T, if tA == blas.Trans or blas.ConjTrans, and s == blas.Right, 438// where A is an n×n or m×m triangular matrix, B is an m×n matrix, and alpha is 439// a scalar. 440func Trmm(s blas.Side, tA blas.Transpose, alpha float32, a Triangular, b General) { 441 blas32.Strmm(s, a.Uplo, tA, a.Diag, b.Rows, b.Cols, alpha, a.Data, a.Stride, b.Data, b.Stride) 442} 443 444// Trsm solves 445// A * X = alpha * B, if tA == blas.NoTrans and s == blas.Left, 446// A^T * X = alpha * B, if tA == blas.Trans or blas.ConjTrans, and s == blas.Left, 447// X * A = alpha * B, if tA == blas.NoTrans and s == blas.Right, 448// X * A^T = alpha * B, if tA == blas.Trans or blas.ConjTrans, and s == blas.Right, 449// where A is an n×n or m×m triangular matrix, X and B are m×n matrices, and 450// alpha is a scalar. 451// 452// At entry to the function, X contains the values of B, and the result is 453// stored in-place into X. 454// 455// No check is made that A is invertible. 456func Trsm(s blas.Side, tA blas.Transpose, alpha float32, a Triangular, b General) { 457 blas32.Strsm(s, a.Uplo, tA, a.Diag, b.Rows, b.Cols, alpha, a.Data, a.Stride, b.Data, b.Stride) 458} 459