1// Copyright ©2015 The Gonum Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package cblas64
6
7import (
8	"gonum.org/v1/gonum/blas"
9	"gonum.org/v1/gonum/blas/gonum"
10)
11
// cblas64 is the package-level BLAS complex64 implementation that the
// wrapper functions in this file dispatch to. It is initialized to
// gonum.Implementation{} and is reassigned by Use.
var cblas64 blas.Complex64 = gonum.Implementation{}
13
// Use sets the BLAS complex64 implementation to be used by subsequent BLAS calls.
// The default implementation is
// gonum.org/v1/gonum/blas/gonum.Implementation.
//
// Use stores b in a package-level variable with a plain assignment; it
// performs no nil check and no synchronization.
func Use(b blas.Complex64) {
	cblas64 = b
}
20
// Implementation returns the current BLAS complex64 implementation.
//
// Implementation allows direct calls to the current BLAS complex64
// implementation, giving finer control of parameters.
func Implementation() blas.Complex64 {
	return cblas64
}
28
// Vector represents a vector with an associated element increment.
type Vector struct {
	// Inc is the element increment; it is handed to the BLAS routines
	// together with Data.
	Inc int
	// Data holds the vector elements.
	Data []complex64
}
34
// General represents a dense matrix using the conventional storage scheme.
type General struct {
	// Rows is the number of rows of the matrix.
	Rows int
	// Cols is the number of columns of the matrix.
	Cols int
	// Stride is handed to the BLAS routines alongside Data.
	Stride int
	// Data holds the matrix elements.
	Data []complex64
}
41
// Band represents a band matrix using the band storage scheme.
type Band struct {
	// Rows is the number of rows of the matrix.
	Rows int
	// Cols is the number of columns of the matrix.
	Cols int
	// KL is forwarded to the BLAS band routines; presumably the number of
	// sub-diagonals (confirm against the blas package documentation).
	KL int
	// KU is forwarded to the BLAS band routines; presumably the number of
	// super-diagonals (confirm against the blas package documentation).
	KU int
	// Stride is handed to the BLAS routines alongside Data.
	Stride int
	// Data holds the matrix elements in band storage.
	Data []complex64
}
49
// Triangular represents a triangular matrix using the conventional storage scheme.
type Triangular struct {
	// N is the order of the matrix (the matrix is n×n).
	N int
	// Stride is handed to the BLAS routines alongside Data.
	Stride int
	// Data holds the matrix elements.
	Data []complex64
	// Uplo is forwarded to the BLAS routines as their blas.Uplo argument.
	Uplo blas.Uplo
	// Diag is forwarded to the BLAS routines as their blas.Diag argument.
	Diag blas.Diag
}
58
59// TriangularBand represents a triangular matrix using the band storage scheme.
60type TriangularBand struct {
61	N, K   int
62	Stride int
63	Data   []complex64
64	Uplo   blas.Uplo
65	Diag   blas.Diag
66}
67
// TriangularPacked represents a triangular matrix using the packed storage scheme.
type TriangularPacked struct {
	// N is the order of the matrix (the matrix is n×n).
	N int
	// Data holds the matrix elements in packed storage.
	Data []complex64
	// Uplo is forwarded to the BLAS routines as their blas.Uplo argument.
	Uplo blas.Uplo
	// Diag is forwarded to the BLAS routines as their blas.Diag argument.
	Diag blas.Diag
}
75
// Symmetric represents a symmetric matrix using the conventional storage scheme.
type Symmetric struct {
	// N is the order of the matrix (the matrix is n×n).
	N int
	// Stride is handed to the BLAS routines alongside Data.
	Stride int
	// Data holds the matrix elements.
	Data []complex64
	// Uplo is forwarded to the BLAS routines as their blas.Uplo argument.
	Uplo blas.Uplo
}
83
84// SymmetricBand represents a symmetric matrix using the band storage scheme.
85type SymmetricBand struct {
86	N, K   int
87	Stride int
88	Data   []complex64
89	Uplo   blas.Uplo
90}
91
// SymmetricPacked represents a symmetric matrix using the packed storage scheme.
type SymmetricPacked struct {
	// N is the order of the matrix (the matrix is n×n).
	N int
	// Data holds the matrix elements in packed storage.
	Data []complex64
	// Uplo is forwarded to the BLAS routines as their blas.Uplo argument.
	Uplo blas.Uplo
}
98
// Hermitian represents an Hermitian matrix using the conventional storage scheme.
//
// Hermitian is a distinct type with the same fields as Symmetric.
type Hermitian Symmetric
101
// HermitianBand represents an Hermitian matrix using the band storage scheme.
//
// HermitianBand is a distinct type with the same fields as SymmetricBand.
type HermitianBand SymmetricBand
104
// HermitianPacked represents an Hermitian matrix using the packed storage scheme.
//
// HermitianPacked is a distinct type with the same fields as SymmetricPacked.
type HermitianPacked SymmetricPacked
107
108// Level 1
109
// negInc is the panic message used by the Level 1 wrappers that reject a
// negative vector increment.
const negInc = "cblas64: negative vector increment"
111
112// Dotu computes the dot product of the two vectors without
113// complex conjugation:
114//  x^T * y
115func Dotu(n int, x, y Vector) complex64 {
116	return cblas64.Cdotu(n, x.Data, x.Inc, y.Data, y.Inc)
117}
118
119// Dotc computes the dot product of the two vectors with
120// complex conjugation:
121//  x^H * y.
122func Dotc(n int, x, y Vector) complex64 {
123	return cblas64.Cdotc(n, x.Data, x.Inc, y.Data, y.Inc)
124}
125
126// Nrm2 computes the Euclidean norm of the vector x:
127//  sqrt(\sum_i x[i] * x[i]).
128//
129// Nrm2 will panic if the vector increment is negative.
130func Nrm2(n int, x Vector) float32 {
131	if x.Inc < 0 {
132		panic(negInc)
133	}
134	return cblas64.Scnrm2(n, x.Data, x.Inc)
135}
136
137// Asum computes the sum of magnitudes of the real and imaginary parts of
138// elements of the vector x:
139//  \sum_i (|Re x[i]| + |Im x[i]|).
140//
141// Asum will panic if the vector increment is negative.
142func Asum(n int, x Vector) float32 {
143	if x.Inc < 0 {
144		panic(negInc)
145	}
146	return cblas64.Scasum(n, x.Data, x.Inc)
147}
148
149// Iamax returns the index of an element of x with the largest sum of
150// magnitudes of the real and imaginary parts (|Re x[i]|+|Im x[i]|).
151// If there are multiple such indices, the earliest is returned.
152//
153// Iamax returns -1 if n == 0.
154//
155// Iamax will panic if the vector increment is negative.
156func Iamax(n int, x Vector) int {
157	if x.Inc < 0 {
158		panic(negInc)
159	}
160	return cblas64.Icamax(n, x.Data, x.Inc)
161}
162
163// Swap exchanges the elements of two vectors:
164//  x[i], y[i] = y[i], x[i] for all i.
165func Swap(n int, x, y Vector) {
166	cblas64.Cswap(n, x.Data, x.Inc, y.Data, y.Inc)
167}
168
169// Copy copies the elements of x into the elements of y:
170//  y[i] = x[i] for all i.
171func Copy(n int, x, y Vector) {
172	cblas64.Ccopy(n, x.Data, x.Inc, y.Data, y.Inc)
173}
174
175// Axpy computes
176//  y = alpha * x + y,
177// where x and y are vectors, and alpha is a scalar.
178func Axpy(n int, alpha complex64, x, y Vector) {
179	cblas64.Caxpy(n, alpha, x.Data, x.Inc, y.Data, y.Inc)
180}
181
182// Scal computes
183//  x = alpha * x,
184// where x is a vector, and alpha is a scalar.
185//
186// Scal will panic if the vector increment is negative.
187func Scal(n int, alpha complex64, x Vector) {
188	if x.Inc < 0 {
189		panic(negInc)
190	}
191	cblas64.Cscal(n, alpha, x.Data, x.Inc)
192}
193
194// Dscal computes
195//  x = alpha * x,
196// where x is a vector, and alpha is a real scalar.
197//
198// Dscal will panic if the vector increment is negative.
199func Dscal(n int, alpha float32, x Vector) {
200	if x.Inc < 0 {
201		panic(negInc)
202	}
203	cblas64.Csscal(n, alpha, x.Data, x.Inc)
204}
205
206// Level 2
207
208// Gemv computes
209//  y = alpha * A * x + beta * y,   if t == blas.NoTrans,
210//  y = alpha * A^T * x + beta * y, if t == blas.Trans,
211//  y = alpha * A^H * x + beta * y, if t == blas.ConjTrans,
212// where A is an m×n dense matrix, x and y are vectors, and alpha and beta are
213// scalars.
214func Gemv(t blas.Transpose, alpha complex64, a General, x Vector, beta complex64, y Vector) {
215	cblas64.Cgemv(t, a.Rows, a.Cols, alpha, a.Data, a.Stride, x.Data, x.Inc, beta, y.Data, y.Inc)
216}
217
218// Gbmv computes
219//  y = alpha * A * x + beta * y,   if t == blas.NoTrans,
220//  y = alpha * A^T * x + beta * y, if t == blas.Trans,
221//  y = alpha * A^H * x + beta * y, if t == blas.ConjTrans,
222// where A is an m×n band matrix, x and y are vectors, and alpha and beta are
223// scalars.
224func Gbmv(t blas.Transpose, alpha complex64, a Band, x Vector, beta complex64, y Vector) {
225	cblas64.Cgbmv(t, a.Rows, a.Cols, a.KL, a.KU, alpha, a.Data, a.Stride, x.Data, x.Inc, beta, y.Data, y.Inc)
226}
227
228// Trmv computes
229//  x = A * x,   if t == blas.NoTrans,
230//  x = A^T * x, if t == blas.Trans,
231//  x = A^H * x, if t == blas.ConjTrans,
232// where A is an n×n triangular matrix, and x is a vector.
233func Trmv(t blas.Transpose, a Triangular, x Vector) {
234	cblas64.Ctrmv(a.Uplo, t, a.Diag, a.N, a.Data, a.Stride, x.Data, x.Inc)
235}
236
237// Tbmv computes
238//  x = A * x,   if t == blas.NoTrans,
239//  x = A^T * x, if t == blas.Trans,
240//  x = A^H * x, if t == blas.ConjTrans,
241// where A is an n×n triangular band matrix, and x is a vector.
242func Tbmv(t blas.Transpose, a TriangularBand, x Vector) {
243	cblas64.Ctbmv(a.Uplo, t, a.Diag, a.N, a.K, a.Data, a.Stride, x.Data, x.Inc)
244}
245
246// Tpmv computes
247//  x = A * x,   if t == blas.NoTrans,
248//  x = A^T * x, if t == blas.Trans,
249//  x = A^H * x, if t == blas.ConjTrans,
250// where A is an n×n triangular matrix in packed format, and x is a vector.
251func Tpmv(t blas.Transpose, a TriangularPacked, x Vector) {
252	cblas64.Ctpmv(a.Uplo, t, a.Diag, a.N, a.Data, x.Data, x.Inc)
253}
254
255// Trsv solves
256//  A * x = b,   if t == blas.NoTrans,
257//  A^T * x = b, if t == blas.Trans,
258//  A^H * x = b, if t == blas.ConjTrans,
259// where A is an n×n triangular matrix and x is a vector.
260//
261// At entry to the function, x contains the values of b, and the result is
262// stored in-place into x.
263//
264// No test for singularity or near-singularity is included in this
265// routine. Such tests must be performed before calling this routine.
266func Trsv(t blas.Transpose, a Triangular, x Vector) {
267	cblas64.Ctrsv(a.Uplo, t, a.Diag, a.N, a.Data, a.Stride, x.Data, x.Inc)
268}
269
270// Tbsv solves
271//  A * x = b,   if t == blas.NoTrans,
272//  A^T * x = b, if t == blas.Trans,
273//  A^H * x = b, if t == blas.ConjTrans,
274// where A is an n×n triangular band matrix, and x is a vector.
275//
276// At entry to the function, x contains the values of b, and the result is
277// stored in-place into x.
278//
279// No test for singularity or near-singularity is included in this
280// routine. Such tests must be performed before calling this routine.
281func Tbsv(t blas.Transpose, a TriangularBand, x Vector) {
282	cblas64.Ctbsv(a.Uplo, t, a.Diag, a.N, a.K, a.Data, a.Stride, x.Data, x.Inc)
283}
284
285// Tpsv solves
286//  A * x = b,   if t == blas.NoTrans,
287//  A^T * x = b, if t == blas.Trans,
288//  A^H * x = b, if t == blas.ConjTrans,
289// where A is an n×n triangular matrix in packed format and x is a vector.
290//
291// At entry to the function, x contains the values of b, and the result is
292// stored in-place into x.
293//
294// No test for singularity or near-singularity is included in this
295// routine. Such tests must be performed before calling this routine.
296func Tpsv(t blas.Transpose, a TriangularPacked, x Vector) {
297	cblas64.Ctpsv(a.Uplo, t, a.Diag, a.N, a.Data, x.Data, x.Inc)
298}
299
300// Hemv computes
301//  y = alpha * A * x + beta * y,
302// where A is an n×n Hermitian matrix, x and y are vectors, and alpha and
303// beta are scalars.
304func Hemv(alpha complex64, a Hermitian, x Vector, beta complex64, y Vector) {
305	cblas64.Chemv(a.Uplo, a.N, alpha, a.Data, a.Stride, x.Data, x.Inc, beta, y.Data, y.Inc)
306}
307
308// Hbmv performs
309//  y = alpha * A * x + beta * y,
310// where A is an n×n Hermitian band matrix, x and y are vectors, and alpha
311// and beta are scalars.
312func Hbmv(alpha complex64, a HermitianBand, x Vector, beta complex64, y Vector) {
313	cblas64.Chbmv(a.Uplo, a.N, a.K, alpha, a.Data, a.Stride, x.Data, x.Inc, beta, y.Data, y.Inc)
314}
315
316// Hpmv performs
317//  y = alpha * A * x + beta * y,
318// where A is an n×n Hermitian matrix in packed format, x and y are vectors,
319// and alpha and beta are scalars.
320func Hpmv(alpha complex64, a HermitianPacked, x Vector, beta complex64, y Vector) {
321	cblas64.Chpmv(a.Uplo, a.N, alpha, a.Data, x.Data, x.Inc, beta, y.Data, y.Inc)
322}
323
324// Geru performs a rank-1 update
325//  A += alpha * x * y^T,
326// where A is an m×n dense matrix, x and y are vectors, and alpha is a scalar.
327func Geru(alpha complex64, x, y Vector, a General) {
328	cblas64.Cgeru(a.Rows, a.Cols, alpha, x.Data, x.Inc, y.Data, y.Inc, a.Data, a.Stride)
329}
330
331// Gerc performs a rank-1 update
332//  A += alpha * x * y^H,
333// where A is an m×n dense matrix, x and y are vectors, and alpha is a scalar.
334func Gerc(alpha complex64, x, y Vector, a General) {
335	cblas64.Cgerc(a.Rows, a.Cols, alpha, x.Data, x.Inc, y.Data, y.Inc, a.Data, a.Stride)
336}
337
338// Her performs a rank-1 update
339//  A += alpha * x * y^T,
340// where A is an m×n Hermitian matrix, x and y are vectors, and alpha is a scalar.
341func Her(alpha float32, x Vector, a Hermitian) {
342	cblas64.Cher(a.Uplo, a.N, alpha, x.Data, x.Inc, a.Data, a.Stride)
343}
344
345// Hpr performs a rank-1 update
346//  A += alpha * x * x^H,
347// where A is an n×n Hermitian matrix in packed format, x is a vector, and
348// alpha is a scalar.
349func Hpr(alpha float32, x Vector, a HermitianPacked) {
350	cblas64.Chpr(a.Uplo, a.N, alpha, x.Data, x.Inc, a.Data)
351}
352
353// Her2 performs a rank-2 update
354//  A += alpha * x * y^H + conj(alpha) * y * x^H,
355// where A is an n×n Hermitian matrix, x and y are vectors, and alpha is a scalar.
356func Her2(alpha complex64, x, y Vector, a Hermitian) {
357	cblas64.Cher2(a.Uplo, a.N, alpha, x.Data, x.Inc, y.Data, y.Inc, a.Data, a.Stride)
358}
359
360// Hpr2 performs a rank-2 update
361//  A += alpha * x * y^H + conj(alpha) * y * x^H,
362// where A is an n×n Hermitian matrix in packed format, x and y are vectors,
363// and alpha is a scalar.
364func Hpr2(alpha complex64, x, y Vector, a HermitianPacked) {
365	cblas64.Chpr2(a.Uplo, a.N, alpha, x.Data, x.Inc, y.Data, y.Inc, a.Data)
366}
367
368// Level 3
369
370// Gemm computes
371//  C = alpha * A * B + beta * C,
372// where A, B, and C are dense matrices, and alpha and beta are scalars.
373// tA and tB specify whether A or B are transposed or conjugated.
374func Gemm(tA, tB blas.Transpose, alpha complex64, a, b General, beta complex64, c General) {
375	var m, n, k int
376	if tA == blas.NoTrans {
377		m, k = a.Rows, a.Cols
378	} else {
379		m, k = a.Cols, a.Rows
380	}
381	if tB == blas.NoTrans {
382		n = b.Cols
383	} else {
384		n = b.Rows
385	}
386	cblas64.Cgemm(tA, tB, m, n, k, alpha, a.Data, a.Stride, b.Data, b.Stride, beta, c.Data, c.Stride)
387}
388
389// Symm performs
390//  C = alpha * A * B + beta * C, if s == blas.Left,
391//  C = alpha * B * A + beta * C, if s == blas.Right,
392// where A is an n×n or m×m symmetric matrix, B and C are m×n matrices, and
393// alpha and beta are scalars.
394func Symm(s blas.Side, alpha complex64, a Symmetric, b General, beta complex64, c General) {
395	var m, n int
396	if s == blas.Left {
397		m, n = a.N, b.Cols
398	} else {
399		m, n = b.Rows, a.N
400	}
401	cblas64.Csymm(s, a.Uplo, m, n, alpha, a.Data, a.Stride, b.Data, b.Stride, beta, c.Data, c.Stride)
402}
403
404// Syrk performs a symmetric rank-k update
405//  C = alpha * A * A^T + beta * C, if t == blas.NoTrans,
406//  C = alpha * A^T * A + beta * C, if t == blas.Trans,
407// where C is an n×n symmetric matrix, A is an n×k matrix if t == blas.NoTrans
408// and a k×n matrix otherwise, and alpha and beta are scalars.
409func Syrk(t blas.Transpose, alpha complex64, a General, beta complex64, c Symmetric) {
410	var n, k int
411	if t == blas.NoTrans {
412		n, k = a.Rows, a.Cols
413	} else {
414		n, k = a.Cols, a.Rows
415	}
416	cblas64.Csyrk(c.Uplo, t, n, k, alpha, a.Data, a.Stride, beta, c.Data, c.Stride)
417}
418
419// Syr2k performs a symmetric rank-2k update
420//  C = alpha * A * B^T + alpha * B * A^T + beta * C, if t == blas.NoTrans,
421//  C = alpha * A^T * B + alpha * B^T * A + beta * C, if t == blas.Trans,
422// where C is an n×n symmetric matrix, A and B are n×k matrices if
423// t == blas.NoTrans and k×n otherwise, and alpha and beta are scalars.
424func Syr2k(t blas.Transpose, alpha complex64, a, b General, beta complex64, c Symmetric) {
425	var n, k int
426	if t == blas.NoTrans {
427		n, k = a.Rows, a.Cols
428	} else {
429		n, k = a.Cols, a.Rows
430	}
431	cblas64.Csyr2k(c.Uplo, t, n, k, alpha, a.Data, a.Stride, b.Data, b.Stride, beta, c.Data, c.Stride)
432}
433
434// Trmm performs
435//  B = alpha * A * B,   if tA == blas.NoTrans and s == blas.Left,
436//  B = alpha * A^T * B, if tA == blas.Trans and s == blas.Left,
437//  B = alpha * A^H * B, if tA == blas.ConjTrans and s == blas.Left,
438//  B = alpha * B * A,   if tA == blas.NoTrans and s == blas.Right,
439//  B = alpha * B * A^T, if tA == blas.Trans and s == blas.Right,
440//  B = alpha * B * A^H, if tA == blas.ConjTrans and s == blas.Right,
441// where A is an n×n or m×m triangular matrix, B is an m×n matrix, and alpha is
442// a scalar.
443func Trmm(s blas.Side, tA blas.Transpose, alpha complex64, a Triangular, b General) {
444	cblas64.Ctrmm(s, a.Uplo, tA, a.Diag, b.Rows, b.Cols, alpha, a.Data, a.Stride, b.Data, b.Stride)
445}
446
447// Trsm solves
448//  A * X = alpha * B,   if tA == blas.NoTrans and s == blas.Left,
449//  A^T * X = alpha * B, if tA == blas.Trans and s == blas.Left,
450//  A^H * X = alpha * B, if tA == blas.ConjTrans and s == blas.Left,
451//  X * A = alpha * B,   if tA == blas.NoTrans and s == blas.Right,
452//  X * A^T = alpha * B, if tA == blas.Trans and s == blas.Right,
453//  X * A^H = alpha * B, if tA == blas.ConjTrans and s == blas.Right,
454// where A is an n×n or m×m triangular matrix, X and B are m×n matrices, and
455// alpha is a scalar.
456//
457// At entry to the function, b contains the values of B, and the result is
458// stored in-place into b.
459//
460// No check is made that A is invertible.
461func Trsm(s blas.Side, tA blas.Transpose, alpha complex64, a Triangular, b General) {
462	cblas64.Ctrsm(s, a.Uplo, tA, a.Diag, b.Rows, b.Cols, alpha, a.Data, a.Stride, b.Data, b.Stride)
463}
464
465// Hemm performs
466//  C = alpha * A * B + beta * C, if s == blas.Left,
467//  C = alpha * B * A + beta * C, if s == blas.Right,
468// where A is an n×n or m×m Hermitian matrix, B and C are m×n matrices, and
469// alpha and beta are scalars.
470func Hemm(s blas.Side, alpha complex64, a Hermitian, b General, beta complex64, c General) {
471	var m, n int
472	if s == blas.Left {
473		m, n = a.N, b.Cols
474	} else {
475		m, n = b.Rows, a.N
476	}
477	cblas64.Chemm(s, a.Uplo, m, n, alpha, a.Data, a.Stride, b.Data, b.Stride, beta, c.Data, c.Stride)
478}
479
480// Herk performs the Hermitian rank-k update
481//  C = alpha * A * A^H + beta*C, if t == blas.NoTrans,
482//  C = alpha * A^H * A + beta*C, if t == blas.ConjTrans,
483// where C is an n×n Hermitian matrix, A is an n×k matrix if t == blas.NoTrans
484// and a k×n matrix otherwise, and alpha and beta are scalars.
485func Herk(t blas.Transpose, alpha float32, a General, beta float32, c Hermitian) {
486	var n, k int
487	if t == blas.NoTrans {
488		n, k = a.Rows, a.Cols
489	} else {
490		n, k = a.Cols, a.Rows
491	}
492	cblas64.Cherk(c.Uplo, t, n, k, alpha, a.Data, a.Stride, beta, c.Data, c.Stride)
493}
494
495// Her2k performs the Hermitian rank-2k update
496//  C = alpha * A * B^H + conj(alpha) * B * A^H + beta * C, if t == blas.NoTrans,
497//  C = alpha * A^H * B + conj(alpha) * B^H * A + beta * C, if t == blas.ConjTrans,
498// where C is an n×n Hermitian matrix, A and B are n×k matrices if t == NoTrans
499// and k×n matrices otherwise, and alpha and beta are scalars.
500func Her2k(t blas.Transpose, alpha complex64, a, b General, beta float32, c Hermitian) {
501	var n, k int
502	if t == blas.NoTrans {
503		n, k = a.Rows, a.Cols
504	} else {
505		n, k = a.Cols, a.Rows
506	}
507	cblas64.Cher2k(c.Uplo, t, n, k, alpha, a.Data, a.Stride, b.Data, b.Stride, beta, c.Data, c.Stride)
508}
509