1// Copyright ©2015 The Gonum Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package cblas128
6
7import (
8	"gonum.org/v1/gonum/blas"
9	"gonum.org/v1/gonum/blas/gonum"
10)
11
12var cblas128 blas.Complex128 = gonum.Implementation{}
13
14// Use sets the BLAS complex128 implementation to be used by subsequent BLAS calls.
15// The default implementation is
16// gonum.org/v1/gonum/blas/gonum.Implementation.
17func Use(b blas.Complex128) {
18	cblas128 = b
19}
20
21// Implementation returns the current BLAS complex128 implementation.
22//
23// Implementation allows direct calls to the current the BLAS complex128 implementation
24// giving finer control of parameters.
25func Implementation() blas.Complex128 {
26	return cblas128
27}
28
// Vector describes a strided complex vector: N logical elements stored in
// Data, with Inc as the increment between consecutive elements.
type Vector struct {
	N, Inc int
	Data   []complex128
}
35
// General describes a dense Rows×Cols complex matrix held in conventional
// (full) storage in Data, with Stride giving the storage stride of Data.
type General struct {
	Rows, Cols, Stride int
	Data               []complex128
}
42
// Band describes a Rows×Cols complex band matrix held in band storage.
// KL and KU are the lower and upper bandwidths, and Stride is the storage
// stride of Data.
type Band struct {
	Rows   int
	Cols   int
	KL     int
	KU     int
	Stride int
	Data   []complex128
}
50
51// Triangular represents a triangular matrix using the conventional storage scheme.
52type Triangular struct {
53	N      int
54	Stride int
55	Data   []complex128
56	Uplo   blas.Uplo
57	Diag   blas.Diag
58}
59
60// TriangularBand represents a triangular matrix using the band storage scheme.
61type TriangularBand struct {
62	N, K   int
63	Stride int
64	Data   []complex128
65	Uplo   blas.Uplo
66	Diag   blas.Diag
67}
68
69// TriangularPacked represents a triangular matrix using the packed storage scheme.
70type TriangularPacked struct {
71	N    int
72	Data []complex128
73	Uplo blas.Uplo
74	Diag blas.Diag
75}
76
77// Symmetric represents a symmetric matrix using the conventional storage scheme.
78type Symmetric struct {
79	N      int
80	Stride int
81	Data   []complex128
82	Uplo   blas.Uplo
83}
84
85// SymmetricBand represents a symmetric matrix using the band storage scheme.
86type SymmetricBand struct {
87	N, K   int
88	Stride int
89	Data   []complex128
90	Uplo   blas.Uplo
91}
92
93// SymmetricPacked represents a symmetric matrix using the packed storage scheme.
94type SymmetricPacked struct {
95	N    int
96	Data []complex128
97	Uplo blas.Uplo
98}
99
// Hermitian represents a Hermitian matrix using the conventional storage scheme.
//
// It is a distinct defined type whose underlying struct is that of Symmetric
// (N, Stride, Data, Uplo), so values convert between the two only with an
// explicit conversion. It is the matrix argument type of Hemv, Her, Her2,
// Hemm, Herk and Her2k.
type Hermitian Symmetric
102
// HermitianBand represents a Hermitian matrix using the band storage scheme.
//
// It is a distinct defined type whose underlying struct is that of
// SymmetricBand (N, K, Stride, Data, Uplo), and it is the matrix argument
// type of Hbmv.
type HermitianBand SymmetricBand
105
// HermitianPacked represents a Hermitian matrix using the packed storage scheme.
//
// It is a distinct defined type whose underlying struct is that of
// SymmetricPacked (N, Data, Uplo), and it is the matrix argument type of
// Hpmv, Hpr and Hpr2.
type HermitianPacked SymmetricPacked
108
109// Level 1
110
// negInc is the panic message used when a vector has a negative increment
// where one is not permitted.
const negInc = "cblas128: negative vector increment"

// badLength is the panic message used when two vectors that must have equal
// length do not.
const badLength = "cblas128: vector length mismatch"
115
116// Dotu computes the dot product of the two vectors without
117// complex conjugation:
118//  xᵀ * y.
119// Dotu will panic if the lengths of x and y do not match.
120func Dotu(x, y Vector) complex128 {
121	if x.N != y.N {
122		panic(badLength)
123	}
124	return cblas128.Zdotu(x.N, x.Data, x.Inc, y.Data, y.Inc)
125}
126
127// Dotc computes the dot product of the two vectors with
128// complex conjugation:
129//  xᴴ * y.
130// Dotc will panic if the lengths of x and y do not match.
131func Dotc(x, y Vector) complex128 {
132	if x.N != y.N {
133		panic(badLength)
134	}
135	return cblas128.Zdotc(x.N, x.Data, x.Inc, y.Data, y.Inc)
136}
137
138// Nrm2 computes the Euclidean norm of the vector x:
139//  sqrt(\sum_i x[i] * x[i]).
140//
141// Nrm2 will panic if the vector increment is negative.
142func Nrm2(x Vector) float64 {
143	if x.Inc < 0 {
144		panic(negInc)
145	}
146	return cblas128.Dznrm2(x.N, x.Data, x.Inc)
147}
148
149// Asum computes the sum of magnitudes of the real and imaginary parts of
150// elements of the vector x:
151//  \sum_i (|Re x[i]| + |Im x[i]|).
152//
153// Asum will panic if the vector increment is negative.
154func Asum(x Vector) float64 {
155	if x.Inc < 0 {
156		panic(negInc)
157	}
158	return cblas128.Dzasum(x.N, x.Data, x.Inc)
159}
160
161// Iamax returns the index of an element of x with the largest sum of
162// magnitudes of the real and imaginary parts (|Re x[i]|+|Im x[i]|).
163// If there are multiple such indices, the earliest is returned.
164//
165// Iamax returns -1 if n == 0.
166//
167// Iamax will panic if the vector increment is negative.
168func Iamax(x Vector) int {
169	if x.Inc < 0 {
170		panic(negInc)
171	}
172	return cblas128.Izamax(x.N, x.Data, x.Inc)
173}
174
175// Swap exchanges the elements of two vectors:
176//  x[i], y[i] = y[i], x[i] for all i.
177// Swap will panic if the lengths of x and y do not match.
178func Swap(x, y Vector) {
179	if x.N != y.N {
180		panic(badLength)
181	}
182	cblas128.Zswap(x.N, x.Data, x.Inc, y.Data, y.Inc)
183}
184
185// Copy copies the elements of x into the elements of y:
186//  y[i] = x[i] for all i.
187// Copy will panic if the lengths of x and y do not match.
188func Copy(x, y Vector) {
189	if x.N != y.N {
190		panic(badLength)
191	}
192	cblas128.Zcopy(x.N, x.Data, x.Inc, y.Data, y.Inc)
193}
194
195// Axpy computes
196//  y = alpha * x + y,
197// where x and y are vectors, and alpha is a scalar.
198// Axpy will panic if the lengths of x and y do not match.
199func Axpy(alpha complex128, x, y Vector) {
200	if x.N != y.N {
201		panic(badLength)
202	}
203	cblas128.Zaxpy(x.N, alpha, x.Data, x.Inc, y.Data, y.Inc)
204}
205
206// Scal computes
207//  x = alpha * x,
208// where x is a vector, and alpha is a scalar.
209//
210// Scal will panic if the vector increment is negative.
211func Scal(alpha complex128, x Vector) {
212	if x.Inc < 0 {
213		panic(negInc)
214	}
215	cblas128.Zscal(x.N, alpha, x.Data, x.Inc)
216}
217
218// Dscal computes
219//  x = alpha * x,
220// where x is a vector, and alpha is a real scalar.
221//
222// Dscal will panic if the vector increment is negative.
223func Dscal(alpha float64, x Vector) {
224	if x.Inc < 0 {
225		panic(negInc)
226	}
227	cblas128.Zdscal(x.N, alpha, x.Data, x.Inc)
228}
229
230// Level 2
231
232// Gemv computes
233//  y = alpha * A * x + beta * y   if t == blas.NoTrans,
234//  y = alpha * Aᵀ * x + beta * y  if t == blas.Trans,
235//  y = alpha * Aᴴ * x + beta * y  if t == blas.ConjTrans,
236// where A is an m×n dense matrix, x and y are vectors, and alpha and beta are
237// scalars.
238func Gemv(t blas.Transpose, alpha complex128, a General, x Vector, beta complex128, y Vector) {
239	cblas128.Zgemv(t, a.Rows, a.Cols, alpha, a.Data, a.Stride, x.Data, x.Inc, beta, y.Data, y.Inc)
240}
241
242// Gbmv computes
243//  y = alpha * A * x + beta * y   if t == blas.NoTrans,
244//  y = alpha * Aᵀ * x + beta * y  if t == blas.Trans,
245//  y = alpha * Aᴴ * x + beta * y  if t == blas.ConjTrans,
246// where A is an m×n band matrix, x and y are vectors, and alpha and beta are
247// scalars.
248func Gbmv(t blas.Transpose, alpha complex128, a Band, x Vector, beta complex128, y Vector) {
249	cblas128.Zgbmv(t, a.Rows, a.Cols, a.KL, a.KU, alpha, a.Data, a.Stride, x.Data, x.Inc, beta, y.Data, y.Inc)
250}
251
252// Trmv computes
253//  x = A * x   if t == blas.NoTrans,
254//  x = Aᵀ * x  if t == blas.Trans,
255//  x = Aᴴ * x  if t == blas.ConjTrans,
256// where A is an n×n triangular matrix, and x is a vector.
257func Trmv(t blas.Transpose, a Triangular, x Vector) {
258	cblas128.Ztrmv(a.Uplo, t, a.Diag, a.N, a.Data, a.Stride, x.Data, x.Inc)
259}
260
261// Tbmv computes
262//  x = A * x   if t == blas.NoTrans,
263//  x = Aᵀ * x  if t == blas.Trans,
264//  x = Aᴴ * x  if t == blas.ConjTrans,
265// where A is an n×n triangular band matrix, and x is a vector.
266func Tbmv(t blas.Transpose, a TriangularBand, x Vector) {
267	cblas128.Ztbmv(a.Uplo, t, a.Diag, a.N, a.K, a.Data, a.Stride, x.Data, x.Inc)
268}
269
270// Tpmv computes
271//  x = A * x   if t == blas.NoTrans,
272//  x = Aᵀ * x  if t == blas.Trans,
273//  x = Aᴴ * x  if t == blas.ConjTrans,
274// where A is an n×n triangular matrix in packed format, and x is a vector.
275func Tpmv(t blas.Transpose, a TriangularPacked, x Vector) {
276	cblas128.Ztpmv(a.Uplo, t, a.Diag, a.N, a.Data, x.Data, x.Inc)
277}
278
279// Trsv solves
280//  A * x = b   if t == blas.NoTrans,
281//  Aᵀ * x = b  if t == blas.Trans,
282//  Aᴴ * x = b  if t == blas.ConjTrans,
283// where A is an n×n triangular matrix and x is a vector.
284//
285// At entry to the function, x contains the values of b, and the result is
286// stored in-place into x.
287//
288// No test for singularity or near-singularity is included in this
289// routine. Such tests must be performed before calling this routine.
290func Trsv(t blas.Transpose, a Triangular, x Vector) {
291	cblas128.Ztrsv(a.Uplo, t, a.Diag, a.N, a.Data, a.Stride, x.Data, x.Inc)
292}
293
294// Tbsv solves
295//  A * x = b   if t == blas.NoTrans,
296//  Aᵀ * x = b  if t == blas.Trans,
297//  Aᴴ * x = b  if t == blas.ConjTrans,
298// where A is an n×n triangular band matrix, and x is a vector.
299//
300// At entry to the function, x contains the values of b, and the result is
301// stored in-place into x.
302//
303// No test for singularity or near-singularity is included in this
304// routine. Such tests must be performed before calling this routine.
305func Tbsv(t blas.Transpose, a TriangularBand, x Vector) {
306	cblas128.Ztbsv(a.Uplo, t, a.Diag, a.N, a.K, a.Data, a.Stride, x.Data, x.Inc)
307}
308
309// Tpsv solves
310//  A * x = b   if t == blas.NoTrans,
311//  Aᵀ * x = b  if t == blas.Trans,
312//  Aᴴ * x = b  if t == blas.ConjTrans,
313// where A is an n×n triangular matrix in packed format and x is a vector.
314//
315// At entry to the function, x contains the values of b, and the result is
316// stored in-place into x.
317//
318// No test for singularity or near-singularity is included in this
319// routine. Such tests must be performed before calling this routine.
320func Tpsv(t blas.Transpose, a TriangularPacked, x Vector) {
321	cblas128.Ztpsv(a.Uplo, t, a.Diag, a.N, a.Data, x.Data, x.Inc)
322}
323
324// Hemv computes
325//  y = alpha * A * x + beta * y,
326// where A is an n×n Hermitian matrix, x and y are vectors, and alpha and
327// beta are scalars.
328func Hemv(alpha complex128, a Hermitian, x Vector, beta complex128, y Vector) {
329	cblas128.Zhemv(a.Uplo, a.N, alpha, a.Data, a.Stride, x.Data, x.Inc, beta, y.Data, y.Inc)
330}
331
332// Hbmv performs
333//  y = alpha * A * x + beta * y,
334// where A is an n×n Hermitian band matrix, x and y are vectors, and alpha
335// and beta are scalars.
336func Hbmv(alpha complex128, a HermitianBand, x Vector, beta complex128, y Vector) {
337	cblas128.Zhbmv(a.Uplo, a.N, a.K, alpha, a.Data, a.Stride, x.Data, x.Inc, beta, y.Data, y.Inc)
338}
339
340// Hpmv performs
341//  y = alpha * A * x + beta * y,
342// where A is an n×n Hermitian matrix in packed format, x and y are vectors,
343// and alpha and beta are scalars.
344func Hpmv(alpha complex128, a HermitianPacked, x Vector, beta complex128, y Vector) {
345	cblas128.Zhpmv(a.Uplo, a.N, alpha, a.Data, x.Data, x.Inc, beta, y.Data, y.Inc)
346}
347
348// Geru performs a rank-1 update
349//  A += alpha * x * yᵀ,
350// where A is an m×n dense matrix, x and y are vectors, and alpha is a scalar.
351func Geru(alpha complex128, x, y Vector, a General) {
352	cblas128.Zgeru(a.Rows, a.Cols, alpha, x.Data, x.Inc, y.Data, y.Inc, a.Data, a.Stride)
353}
354
355// Gerc performs a rank-1 update
356//  A += alpha * x * yᴴ,
357// where A is an m×n dense matrix, x and y are vectors, and alpha is a scalar.
358func Gerc(alpha complex128, x, y Vector, a General) {
359	cblas128.Zgerc(a.Rows, a.Cols, alpha, x.Data, x.Inc, y.Data, y.Inc, a.Data, a.Stride)
360}
361
362// Her performs a rank-1 update
363//  A += alpha * x * yᵀ,
364// where A is an m×n Hermitian matrix, x and y are vectors, and alpha is a scalar.
365func Her(alpha float64, x Vector, a Hermitian) {
366	cblas128.Zher(a.Uplo, a.N, alpha, x.Data, x.Inc, a.Data, a.Stride)
367}
368
369// Hpr performs a rank-1 update
370//  A += alpha * x * xᴴ,
371// where A is an n×n Hermitian matrix in packed format, x is a vector, and
372// alpha is a scalar.
373func Hpr(alpha float64, x Vector, a HermitianPacked) {
374	cblas128.Zhpr(a.Uplo, a.N, alpha, x.Data, x.Inc, a.Data)
375}
376
377// Her2 performs a rank-2 update
378//  A += alpha * x * yᴴ + conj(alpha) * y * xᴴ,
379// where A is an n×n Hermitian matrix, x and y are vectors, and alpha is a scalar.
380func Her2(alpha complex128, x, y Vector, a Hermitian) {
381	cblas128.Zher2(a.Uplo, a.N, alpha, x.Data, x.Inc, y.Data, y.Inc, a.Data, a.Stride)
382}
383
384// Hpr2 performs a rank-2 update
385//  A += alpha * x * yᴴ + conj(alpha) * y * xᴴ,
386// where A is an n×n Hermitian matrix in packed format, x and y are vectors,
387// and alpha is a scalar.
388func Hpr2(alpha complex128, x, y Vector, a HermitianPacked) {
389	cblas128.Zhpr2(a.Uplo, a.N, alpha, x.Data, x.Inc, y.Data, y.Inc, a.Data)
390}
391
392// Level 3
393
394// Gemm computes
395//  C = alpha * A * B + beta * C,
396// where A, B, and C are dense matrices, and alpha and beta are scalars.
397// tA and tB specify whether A or B are transposed or conjugated.
398func Gemm(tA, tB blas.Transpose, alpha complex128, a, b General, beta complex128, c General) {
399	var m, n, k int
400	if tA == blas.NoTrans {
401		m, k = a.Rows, a.Cols
402	} else {
403		m, k = a.Cols, a.Rows
404	}
405	if tB == blas.NoTrans {
406		n = b.Cols
407	} else {
408		n = b.Rows
409	}
410	cblas128.Zgemm(tA, tB, m, n, k, alpha, a.Data, a.Stride, b.Data, b.Stride, beta, c.Data, c.Stride)
411}
412
413// Symm performs
414//  C = alpha * A * B + beta * C  if s == blas.Left,
415//  C = alpha * B * A + beta * C  if s == blas.Right,
416// where A is an n×n or m×m symmetric matrix, B and C are m×n matrices, and
417// alpha and beta are scalars.
418func Symm(s blas.Side, alpha complex128, a Symmetric, b General, beta complex128, c General) {
419	var m, n int
420	if s == blas.Left {
421		m, n = a.N, b.Cols
422	} else {
423		m, n = b.Rows, a.N
424	}
425	cblas128.Zsymm(s, a.Uplo, m, n, alpha, a.Data, a.Stride, b.Data, b.Stride, beta, c.Data, c.Stride)
426}
427
428// Syrk performs a symmetric rank-k update
429//  C = alpha * A * Aᵀ + beta * C  if t == blas.NoTrans,
430//  C = alpha * Aᵀ * A + beta * C  if t == blas.Trans,
431// where C is an n×n symmetric matrix, A is an n×k matrix if t == blas.NoTrans
432// and a k×n matrix otherwise, and alpha and beta are scalars.
433func Syrk(t blas.Transpose, alpha complex128, a General, beta complex128, c Symmetric) {
434	var n, k int
435	if t == blas.NoTrans {
436		n, k = a.Rows, a.Cols
437	} else {
438		n, k = a.Cols, a.Rows
439	}
440	cblas128.Zsyrk(c.Uplo, t, n, k, alpha, a.Data, a.Stride, beta, c.Data, c.Stride)
441}
442
443// Syr2k performs a symmetric rank-2k update
444//  C = alpha * A * Bᵀ + alpha * B * Aᵀ + beta * C  if t == blas.NoTrans,
445//  C = alpha * Aᵀ * B + alpha * Bᵀ * A + beta * C  if t == blas.Trans,
446// where C is an n×n symmetric matrix, A and B are n×k matrices if
447// t == blas.NoTrans and k×n otherwise, and alpha and beta are scalars.
448func Syr2k(t blas.Transpose, alpha complex128, a, b General, beta complex128, c Symmetric) {
449	var n, k int
450	if t == blas.NoTrans {
451		n, k = a.Rows, a.Cols
452	} else {
453		n, k = a.Cols, a.Rows
454	}
455	cblas128.Zsyr2k(c.Uplo, t, n, k, alpha, a.Data, a.Stride, b.Data, b.Stride, beta, c.Data, c.Stride)
456}
457
458// Trmm performs
459//  B = alpha * A * B   if tA == blas.NoTrans and s == blas.Left,
460//  B = alpha * Aᵀ * B  if tA == blas.Trans and s == blas.Left,
461//  B = alpha * Aᴴ * B  if tA == blas.ConjTrans and s == blas.Left,
462//  B = alpha * B * A   if tA == blas.NoTrans and s == blas.Right,
463//  B = alpha * B * Aᵀ  if tA == blas.Trans and s == blas.Right,
464//  B = alpha * B * Aᴴ  if tA == blas.ConjTrans and s == blas.Right,
465// where A is an n×n or m×m triangular matrix, B is an m×n matrix, and alpha is
466// a scalar.
467func Trmm(s blas.Side, tA blas.Transpose, alpha complex128, a Triangular, b General) {
468	cblas128.Ztrmm(s, a.Uplo, tA, a.Diag, b.Rows, b.Cols, alpha, a.Data, a.Stride, b.Data, b.Stride)
469}
470
471// Trsm solves
472//  A * X = alpha * B   if tA == blas.NoTrans and s == blas.Left,
473//  Aᵀ * X = alpha * B  if tA == blas.Trans and s == blas.Left,
474//  Aᴴ * X = alpha * B  if tA == blas.ConjTrans and s == blas.Left,
475//  X * A = alpha * B   if tA == blas.NoTrans and s == blas.Right,
476//  X * Aᵀ = alpha * B  if tA == blas.Trans and s == blas.Right,
477//  X * Aᴴ = alpha * B  if tA == blas.ConjTrans and s == blas.Right,
478// where A is an n×n or m×m triangular matrix, X and B are m×n matrices, and
479// alpha is a scalar.
480//
481// At entry to the function, b contains the values of B, and the result is
482// stored in-place into b.
483//
484// No check is made that A is invertible.
485func Trsm(s blas.Side, tA blas.Transpose, alpha complex128, a Triangular, b General) {
486	cblas128.Ztrsm(s, a.Uplo, tA, a.Diag, b.Rows, b.Cols, alpha, a.Data, a.Stride, b.Data, b.Stride)
487}
488
489// Hemm performs
490//  C = alpha * A * B + beta * C  if s == blas.Left,
491//  C = alpha * B * A + beta * C  if s == blas.Right,
492// where A is an n×n or m×m Hermitian matrix, B and C are m×n matrices, and
493// alpha and beta are scalars.
494func Hemm(s blas.Side, alpha complex128, a Hermitian, b General, beta complex128, c General) {
495	var m, n int
496	if s == blas.Left {
497		m, n = a.N, b.Cols
498	} else {
499		m, n = b.Rows, a.N
500	}
501	cblas128.Zhemm(s, a.Uplo, m, n, alpha, a.Data, a.Stride, b.Data, b.Stride, beta, c.Data, c.Stride)
502}
503
504// Herk performs the Hermitian rank-k update
505//  C = alpha * A * Aᴴ + beta*C  if t == blas.NoTrans,
506//  C = alpha * Aᴴ * A + beta*C  if t == blas.ConjTrans,
507// where C is an n×n Hermitian matrix, A is an n×k matrix if t == blas.NoTrans
508// and a k×n matrix otherwise, and alpha and beta are scalars.
509func Herk(t blas.Transpose, alpha float64, a General, beta float64, c Hermitian) {
510	var n, k int
511	if t == blas.NoTrans {
512		n, k = a.Rows, a.Cols
513	} else {
514		n, k = a.Cols, a.Rows
515	}
516	cblas128.Zherk(c.Uplo, t, n, k, alpha, a.Data, a.Stride, beta, c.Data, c.Stride)
517}
518
519// Her2k performs the Hermitian rank-2k update
520//  C = alpha * A * Bᴴ + conj(alpha) * B * Aᴴ + beta * C  if t == blas.NoTrans,
521//  C = alpha * Aᴴ * B + conj(alpha) * Bᴴ * A + beta * C  if t == blas.ConjTrans,
522// where C is an n×n Hermitian matrix, A and B are n×k matrices if t == NoTrans
523// and k×n matrices otherwise, and alpha and beta are scalars.
524func Her2k(t blas.Transpose, alpha complex128, a, b General, beta float64, c Hermitian) {
525	var n, k int
526	if t == blas.NoTrans {
527		n, k = a.Rows, a.Cols
528	} else {
529		n, k = a.Cols, a.Rows
530	}
531	cblas128.Zher2k(c.Uplo, t, n, k, alpha, a.Data, a.Stride, b.Data, b.Stride, beta, c.Data, c.Stride)
532}
533