1// Copyright ©2015 The Gonum Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package blas64
6
7import (
8	"gonum.org/v1/gonum/blas"
9	"gonum.org/v1/gonum/blas/gonum"
10)
11
// blas64 is the BLAS float64 implementation that the functions in this file
// dispatch to. It is replaced by Use and returned by Implementation.
var blas64 blas.Float64 = gonum.Implementation{}
13
14// Use sets the BLAS float64 implementation to be used by subsequent BLAS calls.
15// The default implementation is
16// gonum.org/v1/gonum/blas/gonum.Implementation.
17func Use(b blas.Float64) {
18	blas64 = b
19}
20
21// Implementation returns the current BLAS float64 implementation.
22//
23// Implementation allows direct calls to the current the BLAS float64 implementation
24// giving finer control of parameters.
25func Implementation() blas.Float64 {
26	return blas64
27}
28
// Vector represents a vector with an associated element increment.
//
// The Level 1 wrappers in this package compare N between operands and forward
// N, Data and Inc unchanged to the BLAS implementation.
type Vector struct {
	N    int       // number of elements in the vector
	Data []float64 // storage holding the elements
	Inc  int       // increment between elements; some wrappers panic if it is negative
}
35
// General represents a matrix held in the conventional storage scheme.
//
// The wrappers in this package forward the fields unchanged to the BLAS
// implementation.
type General struct {
	Rows   int       // number of rows
	Cols   int       // number of columns
	Data   []float64 // storage holding the matrix elements
	Stride int       // stride handed to the BLAS implementation together with Data
}
42
// Band represents a band matrix held in the band storage scheme.
//
// The wrappers in this package forward the fields unchanged to the BLAS
// implementation.
type Band struct {
	Rows   int       // number of rows
	Cols   int       // number of columns
	KL     int       // forwarded as the kl argument (presumably the sub-diagonal count; confirm against package blas)
	KU     int       // forwarded as the ku argument (presumably the super-diagonal count; confirm against package blas)
	Data   []float64 // storage holding the band elements
	Stride int       // stride handed to the BLAS implementation together with Data
}
50
// Triangular represents a triangular matrix using the conventional storage scheme.
//
// The wrappers in this package forward the fields unchanged to the BLAS
// implementation.
type Triangular struct {
	Uplo   blas.Uplo // forwarded as the uplo argument of the BLAS call
	Diag   blas.Diag // forwarded as the diag argument of the BLAS call
	N      int       // order of the n×n matrix
	Data   []float64 // storage holding the matrix elements
	Stride int       // stride handed to the BLAS implementation together with Data
}
59
60// TriangularBand represents a triangular matrix using the band storage scheme.
61type TriangularBand struct {
62	Uplo   blas.Uplo
63	Diag   blas.Diag
64	N, K   int
65	Data   []float64
66	Stride int
67}
68
// TriangularPacked represents a triangular matrix using the packed storage scheme.
//
// The wrappers in this package forward the fields unchanged to the BLAS
// implementation.
type TriangularPacked struct {
	Uplo blas.Uplo // forwarded as the uplo argument of the BLAS call
	Diag blas.Diag // forwarded as the diag argument of the BLAS call
	N    int       // order of the n×n matrix
	Data []float64 // storage holding the packed elements
}
76
// Symmetric represents a symmetric matrix using the conventional storage scheme.
//
// The wrappers in this package forward the fields unchanged to the BLAS
// implementation.
type Symmetric struct {
	Uplo   blas.Uplo // forwarded as the uplo argument of the BLAS call
	N      int       // order of the n×n matrix
	Data   []float64 // storage holding the matrix elements
	Stride int       // stride handed to the BLAS implementation together with Data
}
84
85// SymmetricBand represents a symmetric matrix using the band storage scheme.
86type SymmetricBand struct {
87	Uplo   blas.Uplo
88	N, K   int
89	Data   []float64
90	Stride int
91}
92
// SymmetricPacked represents a symmetric matrix using the packed storage scheme.
//
// The wrappers in this package forward the fields unchanged to the BLAS
// implementation.
type SymmetricPacked struct {
	Uplo blas.Uplo // forwarded as the uplo argument of the BLAS call
	N    int       // order of the n×n matrix
	Data []float64 // storage holding the packed elements
}
99
100// Level 1
101
// Panic values raised by the argument checks of the Level 1 wrappers.
const (
	negInc    = "blas64: negative vector increment" // raised when a vector's Inc is negative
	badLength = "blas64: vector length mismatch"    // raised when two vectors differ in N
)
106
107// Dot computes the dot product of the two vectors:
108//  \sum_i x[i]*y[i].
109// Dot will panic if the lengths of x and y do not match.
110func Dot(x, y Vector) float64 {
111	if x.N != y.N {
112		panic(badLength)
113	}
114	return blas64.Ddot(x.N, x.Data, x.Inc, y.Data, y.Inc)
115}
116
117// Nrm2 computes the Euclidean norm of the vector x:
118//  sqrt(\sum_i x[i]*x[i]).
119//
120// Nrm2 will panic if the vector increment is negative.
121func Nrm2(x Vector) float64 {
122	if x.Inc < 0 {
123		panic(negInc)
124	}
125	return blas64.Dnrm2(x.N, x.Data, x.Inc)
126}
127
128// Asum computes the sum of the absolute values of the elements of x:
129//  \sum_i |x[i]|.
130//
131// Asum will panic if the vector increment is negative.
132func Asum(x Vector) float64 {
133	if x.Inc < 0 {
134		panic(negInc)
135	}
136	return blas64.Dasum(x.N, x.Data, x.Inc)
137}
138
139// Iamax returns the index of an element of x with the largest absolute value.
140// If there are multiple such indices the earliest is returned.
141// Iamax returns -1 if n == 0.
142//
143// Iamax will panic if the vector increment is negative.
144func Iamax(x Vector) int {
145	if x.Inc < 0 {
146		panic(negInc)
147	}
148	return blas64.Idamax(x.N, x.Data, x.Inc)
149}
150
151// Swap exchanges the elements of the two vectors:
152//  x[i], y[i] = y[i], x[i] for all i.
153// Swap will panic if the lengths of x and y do not match.
154func Swap(x, y Vector) {
155	if x.N != y.N {
156		panic(badLength)
157	}
158	blas64.Dswap(x.N, x.Data, x.Inc, y.Data, y.Inc)
159}
160
161// Copy copies the elements of x into the elements of y:
162//  y[i] = x[i] for all i.
163// Copy will panic if the lengths of x and y do not match.
164func Copy(x, y Vector) {
165	if x.N != y.N {
166		panic(badLength)
167	}
168	blas64.Dcopy(x.N, x.Data, x.Inc, y.Data, y.Inc)
169}
170
171// Axpy adds x scaled by alpha to y:
172//  y[i] += alpha*x[i] for all i.
173// Axpy will panic if the lengths of x and y do not match.
174func Axpy(alpha float64, x, y Vector) {
175	if x.N != y.N {
176		panic(badLength)
177	}
178	blas64.Daxpy(x.N, alpha, x.Data, x.Inc, y.Data, y.Inc)
179}
180
181// Rotg computes the parameters of a Givens plane rotation so that
182//  ⎡ c s⎤   ⎡a⎤   ⎡r⎤
183//  ⎣-s c⎦ * ⎣b⎦ = ⎣0⎦
184// where a and b are the Cartesian coordinates of a given point.
185// c, s, and r are defined as
186//  r = ±Sqrt(a^2 + b^2),
187//  c = a/r, the cosine of the rotation angle,
188//  s = a/r, the sine of the rotation angle,
189// and z is defined such that
190//  if |a| > |b|,        z = s,
191//  otherwise if c != 0, z = 1/c,
192//  otherwise            z = 1.
193func Rotg(a, b float64) (c, s, r, z float64) {
194	return blas64.Drotg(a, b)
195}
196
197// Rotmg computes the modified Givens rotation. See
198// http://www.netlib.org/lapack/explore-html/df/deb/drotmg_8f.html
199// for more details.
200func Rotmg(d1, d2, b1, b2 float64) (p blas.DrotmParams, rd1, rd2, rb1 float64) {
201	return blas64.Drotmg(d1, d2, b1, b2)
202}
203
204// Rot applies a plane transformation to n points represented by the vectors x
205// and y:
206//  x[i] =  c*x[i] + s*y[i],
207//  y[i] = -s*x[i] + c*y[i], for all i.
208func Rot(x, y Vector, c, s float64) {
209	if x.N != y.N {
210		panic(badLength)
211	}
212	blas64.Drot(x.N, x.Data, x.Inc, y.Data, y.Inc, c, s)
213}
214
215// Rotm applies the modified Givens rotation to n points represented by the
216// vectors x and y.
217func Rotm(x, y Vector, p blas.DrotmParams) {
218	if x.N != y.N {
219		panic(badLength)
220	}
221	blas64.Drotm(x.N, x.Data, x.Inc, y.Data, y.Inc, p)
222}
223
224// Scal scales the vector x by alpha:
225//  x[i] *= alpha for all i.
226//
227// Scal will panic if the vector increment is negative.
228func Scal(alpha float64, x Vector) {
229	if x.Inc < 0 {
230		panic(negInc)
231	}
232	blas64.Dscal(x.N, alpha, x.Data, x.Inc)
233}
234
235// Level 2
236
237// Gemv computes
238//  y = alpha * A * x + beta * y   if t == blas.NoTrans,
239//  y = alpha * Aᵀ * x + beta * y  if t == blas.Trans or blas.ConjTrans,
240// where A is an m×n dense matrix, x and y are vectors, and alpha and beta are scalars.
241func Gemv(t blas.Transpose, alpha float64, a General, x Vector, beta float64, y Vector) {
242	blas64.Dgemv(t, a.Rows, a.Cols, alpha, a.Data, a.Stride, x.Data, x.Inc, beta, y.Data, y.Inc)
243}
244
245// Gbmv computes
246//  y = alpha * A * x + beta * y   if t == blas.NoTrans,
247//  y = alpha * Aᵀ * x + beta * y  if t == blas.Trans or blas.ConjTrans,
248// where A is an m×n band matrix, x and y are vectors, and alpha and beta are scalars.
249func Gbmv(t blas.Transpose, alpha float64, a Band, x Vector, beta float64, y Vector) {
250	blas64.Dgbmv(t, a.Rows, a.Cols, a.KL, a.KU, alpha, a.Data, a.Stride, x.Data, x.Inc, beta, y.Data, y.Inc)
251}
252
253// Trmv computes
254//  x = A * x   if t == blas.NoTrans,
255//  x = Aᵀ * x  if t == blas.Trans or blas.ConjTrans,
256// where A is an n×n triangular matrix, and x is a vector.
257func Trmv(t blas.Transpose, a Triangular, x Vector) {
258	blas64.Dtrmv(a.Uplo, t, a.Diag, a.N, a.Data, a.Stride, x.Data, x.Inc)
259}
260
261// Tbmv computes
262//  x = A * x   if t == blas.NoTrans,
263//  x = Aᵀ * x  if t == blas.Trans or blas.ConjTrans,
264// where A is an n×n triangular band matrix, and x is a vector.
265func Tbmv(t blas.Transpose, a TriangularBand, x Vector) {
266	blas64.Dtbmv(a.Uplo, t, a.Diag, a.N, a.K, a.Data, a.Stride, x.Data, x.Inc)
267}
268
269// Tpmv computes
270//  x = A * x   if t == blas.NoTrans,
271//  x = Aᵀ * x  if t == blas.Trans or blas.ConjTrans,
272// where A is an n×n triangular matrix in packed format, and x is a vector.
273func Tpmv(t blas.Transpose, a TriangularPacked, x Vector) {
274	blas64.Dtpmv(a.Uplo, t, a.Diag, a.N, a.Data, x.Data, x.Inc)
275}
276
277// Trsv solves
278//  A * x = b   if t == blas.NoTrans,
279//  Aᵀ * x = b  if t == blas.Trans or blas.ConjTrans,
280// where A is an n×n triangular matrix, and x and b are vectors.
281//
282// At entry to the function, x contains the values of b, and the result is
283// stored in-place into x.
284//
285// No test for singularity or near-singularity is included in this
286// routine. Such tests must be performed before calling this routine.
287func Trsv(t blas.Transpose, a Triangular, x Vector) {
288	blas64.Dtrsv(a.Uplo, t, a.Diag, a.N, a.Data, a.Stride, x.Data, x.Inc)
289}
290
291// Tbsv solves
292//  A * x = b   if t == blas.NoTrans,
293//  Aᵀ * x = b  if t == blas.Trans or blas.ConjTrans,
294// where A is an n×n triangular band matrix, and x and b are vectors.
295//
296// At entry to the function, x contains the values of b, and the result is
297// stored in place into x.
298//
299// No test for singularity or near-singularity is included in this
300// routine. Such tests must be performed before calling this routine.
301func Tbsv(t blas.Transpose, a TriangularBand, x Vector) {
302	blas64.Dtbsv(a.Uplo, t, a.Diag, a.N, a.K, a.Data, a.Stride, x.Data, x.Inc)
303}
304
305// Tpsv solves
306//  A * x = b   if t == blas.NoTrans,
307//  Aᵀ * x = b  if t == blas.Trans or blas.ConjTrans,
308// where A is an n×n triangular matrix in packed format, and x and b are
309// vectors.
310//
311// At entry to the function, x contains the values of b, and the result is
312// stored in place into x.
313//
314// No test for singularity or near-singularity is included in this
315// routine. Such tests must be performed before calling this routine.
316func Tpsv(t blas.Transpose, a TriangularPacked, x Vector) {
317	blas64.Dtpsv(a.Uplo, t, a.Diag, a.N, a.Data, x.Data, x.Inc)
318}
319
320// Symv computes
321//  y = alpha * A * x + beta * y,
322// where A is an n×n symmetric matrix, x and y are vectors, and alpha and
323// beta are scalars.
324func Symv(alpha float64, a Symmetric, x Vector, beta float64, y Vector) {
325	blas64.Dsymv(a.Uplo, a.N, alpha, a.Data, a.Stride, x.Data, x.Inc, beta, y.Data, y.Inc)
326}
327
328// Sbmv performs
329//  y = alpha * A * x + beta * y,
330// where A is an n×n symmetric band matrix, x and y are vectors, and alpha
331// and beta are scalars.
332func Sbmv(alpha float64, a SymmetricBand, x Vector, beta float64, y Vector) {
333	blas64.Dsbmv(a.Uplo, a.N, a.K, alpha, a.Data, a.Stride, x.Data, x.Inc, beta, y.Data, y.Inc)
334}
335
336// Spmv performs
337//  y = alpha * A * x + beta * y,
338// where A is an n×n symmetric matrix in packed format, x and y are vectors,
339// and alpha and beta are scalars.
340func Spmv(alpha float64, a SymmetricPacked, x Vector, beta float64, y Vector) {
341	blas64.Dspmv(a.Uplo, a.N, alpha, a.Data, x.Data, x.Inc, beta, y.Data, y.Inc)
342}
343
344// Ger performs a rank-1 update
345//  A += alpha * x * yᵀ,
346// where A is an m×n dense matrix, x and y are vectors, and alpha is a scalar.
347func Ger(alpha float64, x, y Vector, a General) {
348	blas64.Dger(a.Rows, a.Cols, alpha, x.Data, x.Inc, y.Data, y.Inc, a.Data, a.Stride)
349}
350
351// Syr performs a rank-1 update
352//  A += alpha * x * xᵀ,
353// where A is an n×n symmetric matrix, x is a vector, and alpha is a scalar.
354func Syr(alpha float64, x Vector, a Symmetric) {
355	blas64.Dsyr(a.Uplo, a.N, alpha, x.Data, x.Inc, a.Data, a.Stride)
356}
357
358// Spr performs the rank-1 update
359//  A += alpha * x * xᵀ,
360// where A is an n×n symmetric matrix in packed format, x is a vector, and
361// alpha is a scalar.
362func Spr(alpha float64, x Vector, a SymmetricPacked) {
363	blas64.Dspr(a.Uplo, a.N, alpha, x.Data, x.Inc, a.Data)
364}
365
366// Syr2 performs a rank-2 update
367//  A += alpha * x * yᵀ + alpha * y * xᵀ,
368// where A is a symmetric n×n matrix, x and y are vectors, and alpha is a scalar.
369func Syr2(alpha float64, x, y Vector, a Symmetric) {
370	blas64.Dsyr2(a.Uplo, a.N, alpha, x.Data, x.Inc, y.Data, y.Inc, a.Data, a.Stride)
371}
372
373// Spr2 performs a rank-2 update
374//  A += alpha * x * yᵀ + alpha * y * xᵀ,
375// where A is an n×n symmetric matrix in packed format, x and y are vectors,
376// and alpha is a scalar.
377func Spr2(alpha float64, x, y Vector, a SymmetricPacked) {
378	blas64.Dspr2(a.Uplo, a.N, alpha, x.Data, x.Inc, y.Data, y.Inc, a.Data)
379}
380
381// Level 3
382
383// Gemm computes
384//  C = alpha * A * B + beta * C,
385// where A, B, and C are dense matrices, and alpha and beta are scalars.
386// tA and tB specify whether A or B are transposed.
387func Gemm(tA, tB blas.Transpose, alpha float64, a, b General, beta float64, c General) {
388	var m, n, k int
389	if tA == blas.NoTrans {
390		m, k = a.Rows, a.Cols
391	} else {
392		m, k = a.Cols, a.Rows
393	}
394	if tB == blas.NoTrans {
395		n = b.Cols
396	} else {
397		n = b.Rows
398	}
399	blas64.Dgemm(tA, tB, m, n, k, alpha, a.Data, a.Stride, b.Data, b.Stride, beta, c.Data, c.Stride)
400}
401
402// Symm performs
403//  C = alpha * A * B + beta * C  if s == blas.Left,
404//  C = alpha * B * A + beta * C  if s == blas.Right,
405// where A is an n×n or m×m symmetric matrix, B and C are m×n matrices, and
406// alpha is a scalar.
407func Symm(s blas.Side, alpha float64, a Symmetric, b General, beta float64, c General) {
408	var m, n int
409	if s == blas.Left {
410		m, n = a.N, b.Cols
411	} else {
412		m, n = b.Rows, a.N
413	}
414	blas64.Dsymm(s, a.Uplo, m, n, alpha, a.Data, a.Stride, b.Data, b.Stride, beta, c.Data, c.Stride)
415}
416
417// Syrk performs a symmetric rank-k update
418//  C = alpha * A * Aᵀ + beta * C  if t == blas.NoTrans,
419//  C = alpha * Aᵀ * A + beta * C  if t == blas.Trans or blas.ConjTrans,
420// where C is an n×n symmetric matrix, A is an n×k matrix if t == blas.NoTrans and
421// a k×n matrix otherwise, and alpha and beta are scalars.
422func Syrk(t blas.Transpose, alpha float64, a General, beta float64, c Symmetric) {
423	var n, k int
424	if t == blas.NoTrans {
425		n, k = a.Rows, a.Cols
426	} else {
427		n, k = a.Cols, a.Rows
428	}
429	blas64.Dsyrk(c.Uplo, t, n, k, alpha, a.Data, a.Stride, beta, c.Data, c.Stride)
430}
431
432// Syr2k performs a symmetric rank-2k update
433//  C = alpha * A * Bᵀ + alpha * B * Aᵀ + beta * C  if t == blas.NoTrans,
434//  C = alpha * Aᵀ * B + alpha * Bᵀ * A + beta * C  if t == blas.Trans or blas.ConjTrans,
435// where C is an n×n symmetric matrix, A and B are n×k matrices if t == NoTrans
436// and k×n matrices otherwise, and alpha and beta are scalars.
437func Syr2k(t blas.Transpose, alpha float64, a, b General, beta float64, c Symmetric) {
438	var n, k int
439	if t == blas.NoTrans {
440		n, k = a.Rows, a.Cols
441	} else {
442		n, k = a.Cols, a.Rows
443	}
444	blas64.Dsyr2k(c.Uplo, t, n, k, alpha, a.Data, a.Stride, b.Data, b.Stride, beta, c.Data, c.Stride)
445}
446
447// Trmm performs
448//  B = alpha * A * B   if tA == blas.NoTrans and s == blas.Left,
449//  B = alpha * Aᵀ * B  if tA == blas.Trans or blas.ConjTrans, and s == blas.Left,
450//  B = alpha * B * A   if tA == blas.NoTrans and s == blas.Right,
451//  B = alpha * B * Aᵀ  if tA == blas.Trans or blas.ConjTrans, and s == blas.Right,
452// where A is an n×n or m×m triangular matrix, B is an m×n matrix, and alpha is
453// a scalar.
454func Trmm(s blas.Side, tA blas.Transpose, alpha float64, a Triangular, b General) {
455	blas64.Dtrmm(s, a.Uplo, tA, a.Diag, b.Rows, b.Cols, alpha, a.Data, a.Stride, b.Data, b.Stride)
456}
457
458// Trsm solves
459//  A * X = alpha * B   if tA == blas.NoTrans and s == blas.Left,
460//  Aᵀ * X = alpha * B  if tA == blas.Trans or blas.ConjTrans, and s == blas.Left,
461//  X * A = alpha * B   if tA == blas.NoTrans and s == blas.Right,
462//  X * Aᵀ = alpha * B  if tA == blas.Trans or blas.ConjTrans, and s == blas.Right,
463// where A is an n×n or m×m triangular matrix, X and B are m×n matrices, and
464// alpha is a scalar.
465//
466// At entry to the function, X contains the values of B, and the result is
467// stored in-place into X.
468//
469// No check is made that A is invertible.
470func Trsm(s blas.Side, tA blas.Transpose, alpha float64, a Triangular, b General) {
471	blas64.Dtrsm(s, a.Uplo, tA, a.Diag, b.Rows, b.Cols, alpha, a.Data, a.Stride, b.Data, b.Stride)
472}
473