1// Copyright ©2015 The Gonum Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package blas32
6
7import (
8	"gonum.org/v1/gonum/blas"
9	"gonum.org/v1/gonum/blas/gonum"
10)
11
12var blas32 blas.Float32 = gonum.Implementation{}
13
14// Use sets the BLAS float32 implementation to be used by subsequent BLAS calls.
15// The default implementation is
16// gonum.org/v1/gonum/blas/gonum.Implementation.
17func Use(b blas.Float32) {
18	blas32 = b
19}
20
21// Implementation returns the current BLAS float32 implementation.
22//
23// Implementation allows direct calls to the current the BLAS float32 implementation
24// giving finer control of parameters.
25func Implementation() blas.Float32 {
26	return blas32
27}
28
// Vector is a vector paired with the increment used to step between its
// elements.
type Vector struct {
	// Inc is the element increment handed to the BLAS routines.
	Inc int
	// Data is the slice the vector's elements are read from and written to.
	Data []float32
}
34
// General is a dense matrix held in the conventional storage scheme.
type General struct {
	// Rows is the number of rows of the matrix.
	Rows int
	// Cols is the number of columns of the matrix.
	Cols int
	// Stride is handed to the BLAS routines together with Data as the
	// stride of the matrix storage.
	Stride int
	// Data is the slice holding the matrix elements.
	Data []float32
}
41
// Band is a band matrix held in the band storage scheme.
type Band struct {
	// Rows is the number of rows of the matrix.
	Rows int
	// Cols is the number of columns of the matrix.
	Cols int
	// KL is forwarded to the BLAS band routines as the first band-width
	// argument (the sub-diagonal count in the BLAS convention).
	KL int
	// KU is forwarded to the BLAS band routines as the second band-width
	// argument (the super-diagonal count in the BLAS convention).
	KU int
	// Stride is handed to the BLAS routines together with Data as the
	// stride of the band storage.
	Stride int
	// Data is the slice holding the stored band elements.
	Data []float32
}
49
50// Triangular represents a triangular matrix using the conventional storage scheme.
51type Triangular struct {
52	N      int
53	Stride int
54	Data   []float32
55	Uplo   blas.Uplo
56	Diag   blas.Diag
57}
58
59// TriangularBand represents a triangular matrix using the band storage scheme.
60type TriangularBand struct {
61	N, K   int
62	Stride int
63	Data   []float32
64	Uplo   blas.Uplo
65	Diag   blas.Diag
66}
67
68// TriangularPacked represents a triangular matrix using the packed storage scheme.
69type TriangularPacked struct {
70	N    int
71	Data []float32
72	Uplo blas.Uplo
73	Diag blas.Diag
74}
75
76// Symmetric represents a symmetric matrix using the conventional storage scheme.
77type Symmetric struct {
78	N      int
79	Stride int
80	Data   []float32
81	Uplo   blas.Uplo
82}
83
84// SymmetricBand represents a symmetric matrix using the band storage scheme.
85type SymmetricBand struct {
86	N, K   int
87	Stride int
88	Data   []float32
89	Uplo   blas.Uplo
90}
91
92// SymmetricPacked represents a symmetric matrix using the packed storage scheme.
93type SymmetricPacked struct {
94	N    int
95	Data []float32
96	Uplo blas.Uplo
97}
98
99// Level 1
100
// negInc is the panic message used by the Level 1 wrappers that reject a
// negative vector increment.
const negInc = "blas32: negative vector increment"
102
103// Dot computes the dot product of the two vectors:
104//  \sum_i x[i]*y[i].
105func Dot(n int, x, y Vector) float32 {
106	return blas32.Sdot(n, x.Data, x.Inc, y.Data, y.Inc)
107}
108
109// DDot computes the dot product of the two vectors:
110//  \sum_i x[i]*y[i].
111func DDot(n int, x, y Vector) float64 {
112	return blas32.Dsdot(n, x.Data, x.Inc, y.Data, y.Inc)
113}
114
115// SDDot computes the dot product of the two vectors adding a constant:
116//  alpha + \sum_i x[i]*y[i].
117func SDDot(n int, alpha float32, x, y Vector) float32 {
118	return blas32.Sdsdot(n, alpha, x.Data, x.Inc, y.Data, y.Inc)
119}
120
121// Nrm2 computes the Euclidean norm of the vector x:
122//  sqrt(\sum_i x[i]*x[i]).
123//
124// Nrm2 will panic if the vector increment is negative.
125func Nrm2(n int, x Vector) float32 {
126	if x.Inc < 0 {
127		panic(negInc)
128	}
129	return blas32.Snrm2(n, x.Data, x.Inc)
130}
131
132// Asum computes the sum of the absolute values of the elements of x:
133//  \sum_i |x[i]|.
134//
135// Asum will panic if the vector increment is negative.
136func Asum(n int, x Vector) float32 {
137	if x.Inc < 0 {
138		panic(negInc)
139	}
140	return blas32.Sasum(n, x.Data, x.Inc)
141}
142
143// Iamax returns the index of an element of x with the largest absolute value.
144// If there are multiple such indices the earliest is returned.
145// Iamax returns -1 if n == 0.
146//
147// Iamax will panic if the vector increment is negative.
148func Iamax(n int, x Vector) int {
149	if x.Inc < 0 {
150		panic(negInc)
151	}
152	return blas32.Isamax(n, x.Data, x.Inc)
153}
154
155// Swap exchanges the elements of the two vectors:
156//  x[i], y[i] = y[i], x[i] for all i.
157func Swap(n int, x, y Vector) {
158	blas32.Sswap(n, x.Data, x.Inc, y.Data, y.Inc)
159}
160
161// Copy copies the elements of x into the elements of y:
162//  y[i] = x[i] for all i.
163func Copy(n int, x, y Vector) {
164	blas32.Scopy(n, x.Data, x.Inc, y.Data, y.Inc)
165}
166
167// Axpy adds x scaled by alpha to y:
168//  y[i] += alpha*x[i] for all i.
169func Axpy(n int, alpha float32, x, y Vector) {
170	blas32.Saxpy(n, alpha, x.Data, x.Inc, y.Data, y.Inc)
171}
172
173// Rotg computes the parameters of a Givens plane rotation so that
174//  ⎡ c s⎤   ⎡a⎤   ⎡r⎤
175//  ⎣-s c⎦ * ⎣b⎦ = ⎣0⎦
176// where a and b are the Cartesian coordinates of a given point.
177// c, s, and r are defined as
178//  r = ±Sqrt(a^2 + b^2),
179//  c = a/r, the cosine of the rotation angle,
180//  s = a/r, the sine of the rotation angle,
181// and z is defined such that
182//  if |a| > |b|,        z = s,
183//  otherwise if c != 0, z = 1/c,
184//  otherwise            z = 1.
185func Rotg(a, b float32) (c, s, r, z float32) {
186	return blas32.Srotg(a, b)
187}
188
189// Rotmg computes the modified Givens rotation. See
190// http://www.netlib.org/lapack/explore-html/df/deb/drotmg_8f.html
191// for more details.
192func Rotmg(d1, d2, b1, b2 float32) (p blas.SrotmParams, rd1, rd2, rb1 float32) {
193	return blas32.Srotmg(d1, d2, b1, b2)
194}
195
196// Rot applies a plane transformation to n points represented by the vectors x
197// and y:
198//  x[i] =  c*x[i] + s*y[i],
199//  y[i] = -s*x[i] + c*y[i], for all i.
200func Rot(n int, x, y Vector, c, s float32) {
201	blas32.Srot(n, x.Data, x.Inc, y.Data, y.Inc, c, s)
202}
203
204// Rotm applies the modified Givens rotation to n points represented by the
205// vectors x and y.
206func Rotm(n int, x, y Vector, p blas.SrotmParams) {
207	blas32.Srotm(n, x.Data, x.Inc, y.Data, y.Inc, p)
208}
209
210// Scal scales the vector x by alpha:
211//  x[i] *= alpha for all i.
212//
213// Scal will panic if the vector increment is negative.
214func Scal(n int, alpha float32, x Vector) {
215	if x.Inc < 0 {
216		panic(negInc)
217	}
218	blas32.Sscal(n, alpha, x.Data, x.Inc)
219}
220
221// Level 2
222
223// Gemv computes
224//  y = alpha * A * x + beta * y,   if t == blas.NoTrans,
225//  y = alpha * A^T * x + beta * y, if t == blas.Trans or blas.ConjTrans,
226// where A is an m×n dense matrix, x and y are vectors, and alpha and beta are scalars.
227func Gemv(t blas.Transpose, alpha float32, a General, x Vector, beta float32, y Vector) {
228	blas32.Sgemv(t, a.Rows, a.Cols, alpha, a.Data, a.Stride, x.Data, x.Inc, beta, y.Data, y.Inc)
229}
230
231// Gbmv computes
232//  y = alpha * A * x + beta * y,   if t == blas.NoTrans,
233//  y = alpha * A^T * x + beta * y, if t == blas.Trans or blas.ConjTrans,
234// where A is an m×n band matrix, x and y are vectors, and alpha and beta are scalars.
235func Gbmv(t blas.Transpose, alpha float32, a Band, x Vector, beta float32, y Vector) {
236	blas32.Sgbmv(t, a.Rows, a.Cols, a.KL, a.KU, alpha, a.Data, a.Stride, x.Data, x.Inc, beta, y.Data, y.Inc)
237}
238
239// Trmv computes
240//  x = A * x,   if t == blas.NoTrans,
241//  x = A^T * x, if t == blas.Trans or blas.ConjTrans,
242// where A is an n×n triangular matrix, and x is a vector.
243func Trmv(t blas.Transpose, a Triangular, x Vector) {
244	blas32.Strmv(a.Uplo, t, a.Diag, a.N, a.Data, a.Stride, x.Data, x.Inc)
245}
246
247// Tbmv computes
248//  x = A * x,   if t == blas.NoTrans,
249//  x = A^T * x, if t == blas.Trans or blas.ConjTrans,
250// where A is an n×n triangular band matrix, and x is a vector.
251func Tbmv(t blas.Transpose, a TriangularBand, x Vector) {
252	blas32.Stbmv(a.Uplo, t, a.Diag, a.N, a.K, a.Data, a.Stride, x.Data, x.Inc)
253}
254
255// Tpmv computes
256//  x = A * x,   if t == blas.NoTrans,
257//  x = A^T * x, if t == blas.Trans or blas.ConjTrans,
258// where A is an n×n triangular matrix in packed format, and x is a vector.
259func Tpmv(t blas.Transpose, a TriangularPacked, x Vector) {
260	blas32.Stpmv(a.Uplo, t, a.Diag, a.N, a.Data, x.Data, x.Inc)
261}
262
263// Trsv solves
264//  A * x = b,   if t == blas.NoTrans,
265//  A^T * x = b, if t == blas.Trans or blas.ConjTrans,
266// where A is an n×n triangular matrix, and x and b are vectors.
267//
268// At entry to the function, x contains the values of b, and the result is
269// stored in-place into x.
270//
271// No test for singularity or near-singularity is included in this
272// routine. Such tests must be performed before calling this routine.
273func Trsv(t blas.Transpose, a Triangular, x Vector) {
274	blas32.Strsv(a.Uplo, t, a.Diag, a.N, a.Data, a.Stride, x.Data, x.Inc)
275}
276
277// Tbsv solves
278//  A * x = b,   if t == blas.NoTrans,
279//  A^T * x = b, if t == blas.Trans or blas.ConjTrans,
280// where A is an n×n triangular band matrix, and x and b are vectors.
281//
282// At entry to the function, x contains the values of b, and the result is
283// stored in place into x.
284//
285// No test for singularity or near-singularity is included in this
286// routine. Such tests must be performed before calling this routine.
287func Tbsv(t blas.Transpose, a TriangularBand, x Vector) {
288	blas32.Stbsv(a.Uplo, t, a.Diag, a.N, a.K, a.Data, a.Stride, x.Data, x.Inc)
289}
290
291// Tpsv solves
292//  A * x = b,   if t == blas.NoTrans,
293//  A^T * x = b, if t == blas.Trans or blas.ConjTrans,
294// where A is an n×n triangular matrix in packed format, and x and b are
295// vectors.
296//
297// At entry to the function, x contains the values of b, and the result is
298// stored in place into x.
299//
300// No test for singularity or near-singularity is included in this
301// routine. Such tests must be performed before calling this routine.
302func Tpsv(t blas.Transpose, a TriangularPacked, x Vector) {
303	blas32.Stpsv(a.Uplo, t, a.Diag, a.N, a.Data, x.Data, x.Inc)
304}
305
306// Symv computes
307//    y = alpha * A * x + beta * y,
308// where A is an n×n symmetric matrix, x and y are vectors, and alpha and
309// beta are scalars.
310func Symv(alpha float32, a Symmetric, x Vector, beta float32, y Vector) {
311	blas32.Ssymv(a.Uplo, a.N, alpha, a.Data, a.Stride, x.Data, x.Inc, beta, y.Data, y.Inc)
312}
313
314// Sbmv performs
315//  y = alpha * A * x + beta * y,
316// where A is an n×n symmetric band matrix, x and y are vectors, and alpha
317// and beta are scalars.
318func Sbmv(alpha float32, a SymmetricBand, x Vector, beta float32, y Vector) {
319	blas32.Ssbmv(a.Uplo, a.N, a.K, alpha, a.Data, a.Stride, x.Data, x.Inc, beta, y.Data, y.Inc)
320}
321
322// Spmv performs
323//    y = alpha * A * x + beta * y,
324// where A is an n×n symmetric matrix in packed format, x and y are vectors,
325// and alpha and beta are scalars.
326func Spmv(alpha float32, a SymmetricPacked, x Vector, beta float32, y Vector) {
327	blas32.Sspmv(a.Uplo, a.N, alpha, a.Data, x.Data, x.Inc, beta, y.Data, y.Inc)
328}
329
330// Ger performs a rank-1 update
331//  A += alpha * x * y^T,
332// where A is an m×n dense matrix, x and y are vectors, and alpha is a scalar.
333func Ger(alpha float32, x, y Vector, a General) {
334	blas32.Sger(a.Rows, a.Cols, alpha, x.Data, x.Inc, y.Data, y.Inc, a.Data, a.Stride)
335}
336
337// Syr performs a rank-1 update
338//  A += alpha * x * x^T,
339// where A is an n×n symmetric matrix, x is a vector, and alpha is a scalar.
340func Syr(alpha float32, x Vector, a Symmetric) {
341	blas32.Ssyr(a.Uplo, a.N, alpha, x.Data, x.Inc, a.Data, a.Stride)
342}
343
344// Spr performs the rank-1 update
345//  A += alpha * x * x^T,
346// where A is an n×n symmetric matrix in packed format, x is a vector, and
347// alpha is a scalar.
348func Spr(alpha float32, x Vector, a SymmetricPacked) {
349	blas32.Sspr(a.Uplo, a.N, alpha, x.Data, x.Inc, a.Data)
350}
351
352// Syr2 performs a rank-2 update
353//  A += alpha * x * y^T + alpha * y * x^T,
354// where A is a symmetric n×n matrix, x and y are vectors, and alpha is a scalar.
355func Syr2(alpha float32, x, y Vector, a Symmetric) {
356	blas32.Ssyr2(a.Uplo, a.N, alpha, x.Data, x.Inc, y.Data, y.Inc, a.Data, a.Stride)
357}
358
359// Spr2 performs a rank-2 update
360//  A += alpha * x * y^T + alpha * y * x^T,
361// where A is an n×n symmetric matrix in packed format, x and y are vectors,
362// and alpha is a scalar.
363func Spr2(alpha float32, x, y Vector, a SymmetricPacked) {
364	blas32.Sspr2(a.Uplo, a.N, alpha, x.Data, x.Inc, y.Data, y.Inc, a.Data)
365}
366
367// Level 3
368
369// Gemm computes
370//  C = alpha * A * B + beta * C,
371// where A, B, and C are dense matrices, and alpha and beta are scalars.
372// tA and tB specify whether A or B are transposed.
373func Gemm(tA, tB blas.Transpose, alpha float32, a, b General, beta float32, c General) {
374	var m, n, k int
375	if tA == blas.NoTrans {
376		m, k = a.Rows, a.Cols
377	} else {
378		m, k = a.Cols, a.Rows
379	}
380	if tB == blas.NoTrans {
381		n = b.Cols
382	} else {
383		n = b.Rows
384	}
385	blas32.Sgemm(tA, tB, m, n, k, alpha, a.Data, a.Stride, b.Data, b.Stride, beta, c.Data, c.Stride)
386}
387
388// Symm performs
389//  C = alpha * A * B + beta * C, if s == blas.Left,
390//  C = alpha * B * A + beta * C, if s == blas.Right,
391// where A is an n×n or m×m symmetric matrix, B and C are m×n matrices, and
392// alpha is a scalar.
393func Symm(s blas.Side, alpha float32, a Symmetric, b General, beta float32, c General) {
394	var m, n int
395	if s == blas.Left {
396		m, n = a.N, b.Cols
397	} else {
398		m, n = b.Rows, a.N
399	}
400	blas32.Ssymm(s, a.Uplo, m, n, alpha, a.Data, a.Stride, b.Data, b.Stride, beta, c.Data, c.Stride)
401}
402
403// Syrk performs a symmetric rank-k update
404//  C = alpha * A * A^T + beta * C, if t == blas.NoTrans,
405//  C = alpha * A^T * A + beta * C, if t == blas.Trans or blas.ConjTrans,
406// where C is an n×n symmetric matrix, A is an n×k matrix if t == blas.NoTrans and
407// a k×n matrix otherwise, and alpha and beta are scalars.
408func Syrk(t blas.Transpose, alpha float32, a General, beta float32, c Symmetric) {
409	var n, k int
410	if t == blas.NoTrans {
411		n, k = a.Rows, a.Cols
412	} else {
413		n, k = a.Cols, a.Rows
414	}
415	blas32.Ssyrk(c.Uplo, t, n, k, alpha, a.Data, a.Stride, beta, c.Data, c.Stride)
416}
417
418// Syr2k performs a symmetric rank-2k update
419//  C = alpha * A * B^T + alpha * B * A^T + beta * C, if t == blas.NoTrans,
420//  C = alpha * A^T * B + alpha * B^T * A + beta * C, if t == blas.Trans or blas.ConjTrans,
421// where C is an n×n symmetric matrix, A and B are n×k matrices if t == NoTrans
422// and k×n matrices otherwise, and alpha and beta are scalars.
423func Syr2k(t blas.Transpose, alpha float32, a, b General, beta float32, c Symmetric) {
424	var n, k int
425	if t == blas.NoTrans {
426		n, k = a.Rows, a.Cols
427	} else {
428		n, k = a.Cols, a.Rows
429	}
430	blas32.Ssyr2k(c.Uplo, t, n, k, alpha, a.Data, a.Stride, b.Data, b.Stride, beta, c.Data, c.Stride)
431}
432
433// Trmm performs
434//  B = alpha * A * B,   if tA == blas.NoTrans and s == blas.Left,
435//  B = alpha * A^T * B, if tA == blas.Trans or blas.ConjTrans, and s == blas.Left,
436//  B = alpha * B * A,   if tA == blas.NoTrans and s == blas.Right,
437//  B = alpha * B * A^T, if tA == blas.Trans or blas.ConjTrans, and s == blas.Right,
438// where A is an n×n or m×m triangular matrix, B is an m×n matrix, and alpha is
439// a scalar.
440func Trmm(s blas.Side, tA blas.Transpose, alpha float32, a Triangular, b General) {
441	blas32.Strmm(s, a.Uplo, tA, a.Diag, b.Rows, b.Cols, alpha, a.Data, a.Stride, b.Data, b.Stride)
442}
443
444// Trsm solves
445//  A * X = alpha * B,   if tA == blas.NoTrans and s == blas.Left,
446//  A^T * X = alpha * B, if tA == blas.Trans or blas.ConjTrans, and s == blas.Left,
447//  X * A = alpha * B,   if tA == blas.NoTrans and s == blas.Right,
448//  X * A^T = alpha * B, if tA == blas.Trans or blas.ConjTrans, and s == blas.Right,
449// where A is an n×n or m×m triangular matrix, X and B are m×n matrices, and
450// alpha is a scalar.
451//
452// At entry to the function, X contains the values of B, and the result is
453// stored in-place into X.
454//
455// No check is made that A is invertible.
456func Trsm(s blas.Side, tA blas.Transpose, alpha float32, a Triangular, b General) {
457	blas32.Strsm(s, a.Uplo, tA, a.Diag, b.Rows, b.Cols, alpha, a.Data, a.Stride, b.Data, b.Stride)
458}
459