1// Copyright ©2015 The Gonum Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package testlapack
6
7import (
8	"fmt"
9	"math"
10	"math/cmplx"
11	"testing"
12
13	"golang.org/x/exp/rand"
14
15	"gonum.org/v1/gonum/blas"
16	"gonum.org/v1/gonum/blas/blas64"
17	"gonum.org/v1/gonum/lapack"
18)
19
// Machine parameters for IEEE 754 double precision, named after the
// corresponding LAPACK DLAMCH queries.
const (
	// dlamchB is the base of the floating-point number system.
	dlamchB = 2
	// dlamchE is the machine epsilon. For IEEE this is 2^{-53}.
	dlamchE = 0x1p-53
	// dlamchP is the machine epsilon scaled by the base, 2^{-52} for IEEE.
	dlamchP = dlamchB * dlamchE
	// dlamchS is the smallest normal number. For IEEE this is 2^{-1022}.
	dlamchS = 0x1p-1022
)
28
// max returns the larger of a and b.
func max(a, b int) int {
	if b > a {
		return b
	}
	return a
}
35
// min returns the smaller of a and b.
func min(a, b int) int {
	if b < a {
		return b
	}
	return a
}
42
// worklen selects how much workspace a test should use: minimumWork,
// mediumWork or optimumWork.
type worklen int

const (
	minimumWork worklen = 0
	mediumWork  worklen = 1
	optimumWork worklen = 2
)
51
52func (wl worklen) String() string {
53	switch wl {
54	case minimumWork:
55		return "minimum"
56	case mediumWork:
57		return "medium"
58	case optimumWork:
59		return "optimum"
60	}
61	return ""
62}
63
64func normToString(norm lapack.MatrixNorm) string {
65	switch norm {
66	case lapack.MaxAbs:
67		return "MaxAbs"
68	case lapack.MaxRowSum:
69		return "MaxRowSum"
70	case lapack.MaxColumnSum:
71		return "MaxColSum"
72	case lapack.Frobenius:
73		return "Frobenius"
74	default:
75		panic("invalid norm")
76	}
77}
78
79func uploToString(uplo blas.Uplo) string {
80	switch uplo {
81	case blas.Lower:
82		return "Lower"
83	case blas.Upper:
84		return "Upper"
85	default:
86		panic("invalid uplo")
87	}
88}
89
90func diagToString(diag blas.Diag) string {
91	switch diag {
92	case blas.NonUnit:
93		return "NonUnit"
94	case blas.Unit:
95		return "Unit"
96	default:
97		panic("invalid diag")
98	}
99}
100
101func sideToString(side blas.Side) string {
102	switch side {
103	case blas.Left:
104		return "Left"
105	case blas.Right:
106		return "Right"
107	default:
108		panic("invalid side")
109	}
110}
111
112func transToString(trans blas.Transpose) string {
113	switch trans {
114	case blas.NoTrans:
115		return "NoTrans"
116	case blas.Trans:
117		return "Trans"
118	case blas.ConjTrans:
119		return "ConjTrans"
120	default:
121		panic("invalid trans")
122	}
123}
124
// nanSlice returns a freshly allocated slice of n elements, each set to NaN.
func nanSlice(n int) []float64 {
	nan := math.NaN()
	s := make([]float64, n)
	for i := 0; i < len(s); i++ {
		s[i] = nan
	}
	return s
}
133
134// randomSlice allocates a new slice of length n filled with random values.
135func randomSlice(n int, rnd *rand.Rand) []float64 {
136	s := make([]float64, n)
137	for i := range s {
138		s[i] = rnd.NormFloat64()
139	}
140	return s
141}
142
143// nanGeneral allocates a new r×c general matrix filled with NaN values.
144func nanGeneral(r, c, stride int) blas64.General {
145	if r < 0 || c < 0 {
146		panic("bad matrix size")
147	}
148	if r == 0 || c == 0 {
149		return blas64.General{Stride: max(1, stride)}
150	}
151	if stride < c {
152		panic("bad stride")
153	}
154	return blas64.General{
155		Rows:   r,
156		Cols:   c,
157		Stride: stride,
158		Data:   nanSlice((r-1)*stride + c),
159	}
160}
161
162// randomGeneral allocates a new r×c general matrix filled with random
163// numbers. Out-of-range elements are filled with NaN values.
164func randomGeneral(r, c, stride int, rnd *rand.Rand) blas64.General {
165	ans := nanGeneral(r, c, stride)
166	for i := 0; i < r; i++ {
167		for j := 0; j < c; j++ {
168			ans.Data[i*ans.Stride+j] = rnd.NormFloat64()
169		}
170	}
171	return ans
172}
173
174// randomHessenberg allocates a new n×n Hessenberg matrix filled with zeros
175// under the first subdiagonal and with random numbers elsewhere. Out-of-range
176// elements are filled with NaN values.
177func randomHessenberg(n, stride int, rnd *rand.Rand) blas64.General {
178	ans := nanGeneral(n, n, stride)
179	for i := 0; i < n; i++ {
180		for j := 0; j < i-1; j++ {
181			ans.Data[i*ans.Stride+j] = 0
182		}
183		for j := max(0, i-1); j < n; j++ {
184			ans.Data[i*ans.Stride+j] = rnd.NormFloat64()
185		}
186	}
187	return ans
188}
189
190// randomSchurCanonical returns a random, general matrix in Schur canonical
191// form, that is, block upper triangular with 1×1 and 2×2 diagonal blocks where
192// each 2×2 diagonal block has its diagonal elements equal and its off-diagonal
193// elements of opposite sign. bad controls whether the returned matrix will have
194// zero or tiny eigenvalues.
195func randomSchurCanonical(n, stride int, bad bool, rnd *rand.Rand) (t blas64.General, wr, wi []float64) {
196	t = randomGeneral(n, n, stride, rnd)
197	// Zero out the lower triangle including the diagonal which will be set later.
198	for i := 0; i < t.Rows; i++ {
199		for j := 0; j <= i; j++ {
200			t.Data[i*t.Stride+j] = 0
201		}
202	}
203	// Randomly create 2×2 diagonal blocks.
204	for i := 0; i < t.Rows; {
205		a := rnd.NormFloat64()
206		if bad && rnd.Float64() < 0.5 {
207			if rnd.Float64() < 0.5 {
208				// A quarter of real parts of eigenvalues will be tiny.
209				a = dlamchS
210			} else {
211				// A quarter of them will be zero.
212				a = 0
213			}
214		}
215
216		// A half of eigenvalues will be real.
217		if rnd.Float64() < 0.5 || i == t.Rows-1 {
218			// Store 1×1 block at the diagonal of T.
219			t.Data[i*t.Stride+i] = a
220			wr = append(wr, a)
221			wi = append(wi, 0)
222			i++
223			continue
224		}
225
226		// Diagonal elements are equal.
227		d := a
228		// Element under the diagonal is "normal".
229		c := rnd.NormFloat64()
230		// Element above the diagonal cannot be zero.
231		var b float64
232		if bad && rnd.Float64() < 0.5 {
233			b = dlamchS
234		} else {
235			b = rnd.NormFloat64()
236		}
237		// Make sure off-diagonal elements are of opposite sign.
238		if math.Signbit(b) == math.Signbit(c) {
239			c *= -1
240		}
241
242		// Store 2×2 block at the diagonal of T.
243		t.Data[i*t.Stride+i], t.Data[i*t.Stride+i+1] = a, b
244		t.Data[(i+1)*t.Stride+i], t.Data[(i+1)*t.Stride+i+1] = c, d
245
246		wr = append(wr, a, a)
247		im := math.Sqrt(math.Abs(b)) * math.Sqrt(math.Abs(c))
248		wi = append(wi, im, -im)
249		i += 2
250	}
251	return t, wr, wi
252}
253
254// blockedUpperTriGeneral returns a normal random, general matrix in the form
255//
256//            c-k-l  k    l
257//  A =    k [  0   A12  A13 ] if r-k-l >= 0;
258//         l [  0    0   A23 ]
259//     r-k-l [  0    0    0  ]
260//
261//          c-k-l  k    l
262//  A =  k [  0   A12  A13 ] if r-k-l < 0;
263//     r-k [  0    0   A23 ]
264//
265// where the k×k matrix A12 and l×l matrix is non-singular
266// upper triangular. A23 is l×l upper triangular if r-k-l >= 0,
267// otherwise A23 is (r-k)×l upper trapezoidal.
268func blockedUpperTriGeneral(r, c, k, l, stride int, kblock bool, rnd *rand.Rand) blas64.General {
269	t := l
270	if kblock {
271		t += k
272	}
273	ans := zeros(r, c, stride)
274	for i := 0; i < min(r, t); i++ {
275		var v float64
276		for v == 0 {
277			v = rnd.NormFloat64()
278		}
279		ans.Data[i*ans.Stride+i+(c-t)] = v
280	}
281	for i := 0; i < min(r, t); i++ {
282		for j := i + (c - t) + 1; j < c; j++ {
283			ans.Data[i*ans.Stride+j] = rnd.NormFloat64()
284		}
285	}
286	return ans
287}
288
289// nanTriangular allocates a new r×c triangular matrix filled with NaN values.
290func nanTriangular(uplo blas.Uplo, n, stride int) blas64.Triangular {
291	if n < 0 {
292		panic("bad matrix size")
293	}
294	if n == 0 {
295		return blas64.Triangular{
296			Stride: max(1, stride),
297			Uplo:   uplo,
298			Diag:   blas.NonUnit,
299		}
300	}
301	if stride < n {
302		panic("bad stride")
303	}
304	return blas64.Triangular{
305		N:      n,
306		Stride: stride,
307		Data:   nanSlice((n-1)*stride + n),
308		Uplo:   uplo,
309		Diag:   blas.NonUnit,
310	}
311}
312
313// generalOutsideAllNaN returns whether all out-of-range elements have NaN
314// values.
315func generalOutsideAllNaN(a blas64.General) bool {
316	// Check after last column.
317	for i := 0; i < a.Rows-1; i++ {
318		for _, v := range a.Data[i*a.Stride+a.Cols : i*a.Stride+a.Stride] {
319			if !math.IsNaN(v) {
320				return false
321			}
322		}
323	}
324	// Check after last element.
325	last := (a.Rows-1)*a.Stride + a.Cols
326	if a.Rows == 0 || a.Cols == 0 {
327		last = 0
328	}
329	for _, v := range a.Data[last:] {
330		if !math.IsNaN(v) {
331			return false
332		}
333	}
334	return true
335}
336
337// triangularOutsideAllNaN returns whether all out-of-triangle elements have NaN
338// values.
339func triangularOutsideAllNaN(a blas64.Triangular) bool {
340	if a.Uplo == blas.Upper {
341		// Check below diagonal.
342		for i := 0; i < a.N; i++ {
343			for _, v := range a.Data[i*a.Stride : i*a.Stride+i] {
344				if !math.IsNaN(v) {
345					return false
346				}
347			}
348		}
349		// Check after last column.
350		for i := 0; i < a.N-1; i++ {
351			for _, v := range a.Data[i*a.Stride+a.N : i*a.Stride+a.Stride] {
352				if !math.IsNaN(v) {
353					return false
354				}
355			}
356		}
357	} else {
358		// Check above diagonal.
359		for i := 0; i < a.N-1; i++ {
360			for _, v := range a.Data[i*a.Stride+i+1 : i*a.Stride+a.Stride] {
361				if !math.IsNaN(v) {
362					return false
363				}
364			}
365		}
366	}
367	// Check after last element.
368	for _, v := range a.Data[max(0, a.N-1)*a.Stride+a.N:] {
369		if !math.IsNaN(v) {
370			return false
371		}
372	}
373	return true
374}
375
376// transposeGeneral returns a new general matrix that is the transpose of the
377// input. Nothing is done with data outside the {rows, cols} limit of the general.
378func transposeGeneral(a blas64.General) blas64.General {
379	ans := blas64.General{
380		Rows:   a.Cols,
381		Cols:   a.Rows,
382		Stride: a.Rows,
383		Data:   make([]float64, a.Cols*a.Rows),
384	}
385	for i := 0; i < a.Rows; i++ {
386		for j := 0; j < a.Cols; j++ {
387			ans.Data[j*ans.Stride+i] = a.Data[i*a.Stride+j]
388		}
389	}
390	return ans
391}
392
393// columnNorms returns the column norms of a.
394func columnNorms(m, n int, a []float64, lda int) []float64 {
395	bi := blas64.Implementation()
396	norms := make([]float64, n)
397	for j := 0; j < n; j++ {
398		norms[j] = bi.Dnrm2(m, a[j:], lda)
399	}
400	return norms
401}
402
403// extractVMat collects the single reflectors from a into a matrix.
404func extractVMat(m, n int, a []float64, lda int, direct lapack.Direct, store lapack.StoreV) blas64.General {
405	k := min(m, n)
406	switch {
407	default:
408		panic("not implemented")
409	case direct == lapack.Forward && store == lapack.ColumnWise:
410		v := blas64.General{
411			Rows:   m,
412			Cols:   k,
413			Stride: k,
414			Data:   make([]float64, m*k),
415		}
416		for i := 0; i < k; i++ {
417			for j := 0; j < i; j++ {
418				v.Data[j*v.Stride+i] = 0
419			}
420			v.Data[i*v.Stride+i] = 1
421			for j := i + 1; j < m; j++ {
422				v.Data[j*v.Stride+i] = a[j*lda+i]
423			}
424		}
425		return v
426	case direct == lapack.Forward && store == lapack.RowWise:
427		v := blas64.General{
428			Rows:   k,
429			Cols:   n,
430			Stride: n,
431			Data:   make([]float64, k*n),
432		}
433		for i := 0; i < k; i++ {
434			for j := 0; j < i; j++ {
435				v.Data[i*v.Stride+j] = 0
436			}
437			v.Data[i*v.Stride+i] = 1
438			for j := i + 1; j < n; j++ {
439				v.Data[i*v.Stride+j] = a[i*lda+j]
440			}
441		}
442		return v
443	}
444}
445
446// constructBidiagonal constructs a bidiagonal matrix with the given diagonal
447// and off-diagonal elements.
448func constructBidiagonal(uplo blas.Uplo, n int, d, e []float64) blas64.General {
449	bMat := blas64.General{
450		Rows:   n,
451		Cols:   n,
452		Stride: n,
453		Data:   make([]float64, n*n),
454	}
455
456	for i := 0; i < n-1; i++ {
457		bMat.Data[i*bMat.Stride+i] = d[i]
458		if uplo == blas.Upper {
459			bMat.Data[i*bMat.Stride+i+1] = e[i]
460		} else {
461			bMat.Data[(i+1)*bMat.Stride+i] = e[i]
462		}
463	}
464	bMat.Data[(n-1)*bMat.Stride+n-1] = d[n-1]
465	return bMat
466}
467
468// constructVMat transforms the v matrix based on the storage.
469func constructVMat(vMat blas64.General, store lapack.StoreV, direct lapack.Direct) blas64.General {
470	m := vMat.Rows
471	k := vMat.Cols
472	switch {
473	default:
474		panic("not implemented")
475	case store == lapack.ColumnWise && direct == lapack.Forward:
476		ldv := k
477		v := make([]float64, m*k)
478		for i := 0; i < m; i++ {
479			for j := 0; j < k; j++ {
480				if j > i {
481					v[i*ldv+j] = 0
482				} else if j == i {
483					v[i*ldv+i] = 1
484				} else {
485					v[i*ldv+j] = vMat.Data[i*vMat.Stride+j]
486				}
487			}
488		}
489		return blas64.General{
490			Rows:   m,
491			Cols:   k,
492			Stride: k,
493			Data:   v,
494		}
495	case store == lapack.RowWise && direct == lapack.Forward:
496		ldv := m
497		v := make([]float64, m*k)
498		for i := 0; i < m; i++ {
499			for j := 0; j < k; j++ {
500				if j > i {
501					v[j*ldv+i] = 0
502				} else if j == i {
503					v[j*ldv+i] = 1
504				} else {
505					v[j*ldv+i] = vMat.Data[i*vMat.Stride+j]
506				}
507			}
508		}
509		return blas64.General{
510			Rows:   k,
511			Cols:   m,
512			Stride: m,
513			Data:   v,
514		}
515	case store == lapack.ColumnWise && direct == lapack.Backward:
516		rowsv := m
517		ldv := k
518		v := make([]float64, m*k)
519		for i := 0; i < m; i++ {
520			for j := 0; j < k; j++ {
521				vrow := rowsv - i - 1
522				vcol := k - j - 1
523				if j > i {
524					v[vrow*ldv+vcol] = 0
525				} else if j == i {
526					v[vrow*ldv+vcol] = 1
527				} else {
528					v[vrow*ldv+vcol] = vMat.Data[i*vMat.Stride+j]
529				}
530			}
531		}
532		return blas64.General{
533			Rows:   rowsv,
534			Cols:   ldv,
535			Stride: ldv,
536			Data:   v,
537		}
538	case store == lapack.RowWise && direct == lapack.Backward:
539		rowsv := k
540		ldv := m
541		v := make([]float64, m*k)
542		for i := 0; i < m; i++ {
543			for j := 0; j < k; j++ {
544				vcol := ldv - i - 1
545				vrow := k - j - 1
546				if j > i {
547					v[vrow*ldv+vcol] = 0
548				} else if j == i {
549					v[vrow*ldv+vcol] = 1
550				} else {
551					v[vrow*ldv+vcol] = vMat.Data[i*vMat.Stride+j]
552				}
553			}
554		}
555		return blas64.General{
556			Rows:   rowsv,
557			Cols:   ldv,
558			Stride: ldv,
559			Data:   v,
560		}
561	}
562}
563
564func constructH(tau []float64, v blas64.General, store lapack.StoreV, direct lapack.Direct) blas64.General {
565	m := v.Rows
566	k := v.Cols
567	if store == lapack.RowWise {
568		m, k = k, m
569	}
570	h := blas64.General{
571		Rows:   m,
572		Cols:   m,
573		Stride: m,
574		Data:   make([]float64, m*m),
575	}
576	for i := 0; i < m; i++ {
577		h.Data[i*m+i] = 1
578	}
579	for i := 0; i < k; i++ {
580		vecData := make([]float64, m)
581		if store == lapack.ColumnWise {
582			for j := 0; j < m; j++ {
583				vecData[j] = v.Data[j*v.Cols+i]
584			}
585		} else {
586			for j := 0; j < m; j++ {
587				vecData[j] = v.Data[i*v.Cols+j]
588			}
589		}
590		vec := blas64.Vector{
591			Inc:  1,
592			Data: vecData,
593		}
594
595		hi := blas64.General{
596			Rows:   m,
597			Cols:   m,
598			Stride: m,
599			Data:   make([]float64, m*m),
600		}
601		for i := 0; i < m; i++ {
602			hi.Data[i*m+i] = 1
603		}
604		// hi = I - tau * v * vᵀ
605		blas64.Ger(-tau[i], vec, vec, hi)
606
607		hcopy := blas64.General{
608			Rows:   m,
609			Cols:   m,
610			Stride: m,
611			Data:   make([]float64, m*m),
612		}
613		copy(hcopy.Data, h.Data)
614		if direct == lapack.Forward {
615			// H = H * H_I in forward mode
616			blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, hcopy, hi, 0, h)
617		} else {
618			// H = H_I * H in backward mode
619			blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, hi, hcopy, 0, h)
620		}
621	}
622	return h
623}
624
625// constructQ constructs the Q matrix from the result of dgeqrf and dgeqr2.
626func constructQ(kind string, m, n int, a []float64, lda int, tau []float64) blas64.General {
627	k := min(m, n)
628	return constructQK(kind, m, n, k, a, lda, tau)
629}
630
631// constructQK constructs the Q matrix from the result of dgeqrf and dgeqr2 using
632// the first k reflectors.
633func constructQK(kind string, m, n, k int, a []float64, lda int, tau []float64) blas64.General {
634	var sz int
635	switch kind {
636	case "QR":
637		sz = m
638	case "LQ", "RQ":
639		sz = n
640	}
641
642	q := blas64.General{
643		Rows:   sz,
644		Cols:   sz,
645		Stride: max(1, sz),
646		Data:   make([]float64, sz*sz),
647	}
648	for i := 0; i < sz; i++ {
649		q.Data[i*sz+i] = 1
650	}
651	qCopy := blas64.General{
652		Rows:   q.Rows,
653		Cols:   q.Cols,
654		Stride: q.Stride,
655		Data:   make([]float64, len(q.Data)),
656	}
657	for i := 0; i < k; i++ {
658		h := blas64.General{
659			Rows:   sz,
660			Cols:   sz,
661			Stride: max(1, sz),
662			Data:   make([]float64, sz*sz),
663		}
664		for j := 0; j < sz; j++ {
665			h.Data[j*sz+j] = 1
666		}
667		vVec := blas64.Vector{
668			Inc:  1,
669			Data: make([]float64, sz),
670		}
671		switch kind {
672		case "QR":
673			vVec.Data[i] = 1
674			for j := i + 1; j < sz; j++ {
675				vVec.Data[j] = a[lda*j+i]
676			}
677		case "LQ":
678			vVec.Data[i] = 1
679			for j := i + 1; j < sz; j++ {
680				vVec.Data[j] = a[i*lda+j]
681			}
682		case "RQ":
683			for j := 0; j < n-k+i; j++ {
684				vVec.Data[j] = a[(m-k+i)*lda+j]
685			}
686			vVec.Data[n-k+i] = 1
687		}
688		blas64.Ger(-tau[i], vVec, vVec, h)
689		copy(qCopy.Data, q.Data)
690		// Multiply q by the new h.
691		switch kind {
692		case "QR", "RQ":
693			blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, qCopy, h, 0, q)
694		case "LQ":
695			blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, h, qCopy, 0, q)
696		}
697	}
698	return q
699}
700
701// checkBidiagonal checks the bidiagonal decomposition from dlabrd and dgebd2.
702// The input to this function is the answer returned from the routines, stored
703// in a, d, e, tauP, and tauQ. The data of original A matrix (before
704// decomposition) is input in aCopy.
705//
706// checkBidiagonal constructs the V and U matrices, and from them constructs Q
707// and P. Using these constructions, it checks that Qᵀ * A * P and checks that
708// the result is bidiagonal.
709func checkBidiagonal(t *testing.T, m, n, nb int, a []float64, lda int, d, e, tauP, tauQ, aCopy []float64) {
710	// Check the answer.
711	// Construct V and U.
712	qMat := constructQPBidiagonal(lapack.ApplyQ, m, n, nb, a, lda, tauQ)
713	pMat := constructQPBidiagonal(lapack.ApplyP, m, n, nb, a, lda, tauP)
714
715	// Compute Qᵀ * A * P.
716	aMat := blas64.General{
717		Rows:   m,
718		Cols:   n,
719		Stride: lda,
720		Data:   make([]float64, len(aCopy)),
721	}
722	copy(aMat.Data, aCopy)
723
724	tmp1 := blas64.General{
725		Rows:   m,
726		Cols:   n,
727		Stride: n,
728		Data:   make([]float64, m*n),
729	}
730	blas64.Gemm(blas.Trans, blas.NoTrans, 1, qMat, aMat, 0, tmp1)
731	tmp2 := blas64.General{
732		Rows:   m,
733		Cols:   n,
734		Stride: n,
735		Data:   make([]float64, m*n),
736	}
737	blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, tmp1, pMat, 0, tmp2)
738
739	// Check that the first nb rows and cols of tm2 are upper bidiagonal
740	// if m >= n, and lower bidiagonal otherwise.
741	correctDiag := true
742	matchD := true
743	matchE := true
744	for i := 0; i < m; i++ {
745		for j := 0; j < n; j++ {
746			if i >= nb && j >= nb {
747				continue
748			}
749			v := tmp2.Data[i*tmp2.Stride+j]
750			if i == j {
751				if math.Abs(d[i]-v) > 1e-12 {
752					matchD = false
753				}
754				continue
755			}
756			if m >= n && i == j-1 {
757				if math.Abs(e[j-1]-v) > 1e-12 {
758					matchE = false
759				}
760				continue
761			}
762			if m < n && i-1 == j {
763				if math.Abs(e[i-1]-v) > 1e-12 {
764					matchE = false
765				}
766				continue
767			}
768			if math.Abs(v) > 1e-12 {
769				correctDiag = false
770			}
771		}
772	}
773	if !correctDiag {
774		t.Errorf("Updated A not bi-diagonal")
775	}
776	if !matchD {
777		fmt.Println("d = ", d)
778		t.Errorf("D Mismatch")
779	}
780	if !matchE {
781		t.Errorf("E mismatch")
782	}
783}
784
785// constructQPBidiagonal constructs Q or P from the Bidiagonal decomposition
786// computed by dlabrd and bgebd2.
787func constructQPBidiagonal(vect lapack.ApplyOrtho, m, n, nb int, a []float64, lda int, tau []float64) blas64.General {
788	sz := n
789	if vect == lapack.ApplyQ {
790		sz = m
791	}
792
793	var ldv int
794	var v blas64.General
795	if vect == lapack.ApplyQ {
796		ldv = nb
797		v = blas64.General{
798			Rows:   m,
799			Cols:   nb,
800			Stride: ldv,
801			Data:   make([]float64, m*ldv),
802		}
803	} else {
804		ldv = n
805		v = blas64.General{
806			Rows:   nb,
807			Cols:   n,
808			Stride: ldv,
809			Data:   make([]float64, m*ldv),
810		}
811	}
812
813	if vect == lapack.ApplyQ {
814		if m >= n {
815			for i := 0; i < m; i++ {
816				for j := 0; j <= min(nb-1, i); j++ {
817					if i == j {
818						v.Data[i*ldv+j] = 1
819						continue
820					}
821					v.Data[i*ldv+j] = a[i*lda+j]
822				}
823			}
824		} else {
825			for i := 1; i < m; i++ {
826				for j := 0; j <= min(nb-1, i-1); j++ {
827					if i-1 == j {
828						v.Data[i*ldv+j] = 1
829						continue
830					}
831					v.Data[i*ldv+j] = a[i*lda+j]
832				}
833			}
834		}
835	} else {
836		if m < n {
837			for i := 0; i < nb; i++ {
838				for j := i; j < n; j++ {
839					if i == j {
840						v.Data[i*ldv+j] = 1
841						continue
842					}
843					v.Data[i*ldv+j] = a[i*lda+j]
844				}
845			}
846		} else {
847			for i := 0; i < nb; i++ {
848				for j := i + 1; j < n; j++ {
849					if j-1 == i {
850						v.Data[i*ldv+j] = 1
851						continue
852					}
853					v.Data[i*ldv+j] = a[i*lda+j]
854				}
855			}
856		}
857	}
858
859	// The variable name is a computation of Q, but the algorithm is mostly the
860	// same for computing P (just with different data).
861	qMat := blas64.General{
862		Rows:   sz,
863		Cols:   sz,
864		Stride: sz,
865		Data:   make([]float64, sz*sz),
866	}
867	hMat := blas64.General{
868		Rows:   sz,
869		Cols:   sz,
870		Stride: sz,
871		Data:   make([]float64, sz*sz),
872	}
873	// set Q to I
874	for i := 0; i < sz; i++ {
875		qMat.Data[i*qMat.Stride+i] = 1
876	}
877	for i := 0; i < nb; i++ {
878		qCopy := blas64.General{Rows: qMat.Rows, Cols: qMat.Cols, Stride: qMat.Stride, Data: make([]float64, len(qMat.Data))}
879		copy(qCopy.Data, qMat.Data)
880
881		// Set g and h to I
882		for i := 0; i < sz; i++ {
883			for j := 0; j < sz; j++ {
884				if i == j {
885					hMat.Data[i*sz+j] = 1
886				} else {
887					hMat.Data[i*sz+j] = 0
888				}
889			}
890		}
891		var vi blas64.Vector
892		// H -= tauQ[i] * v[i] * v[i]^t
893		if vect == lapack.ApplyQ {
894			vi = blas64.Vector{
895				Inc:  v.Stride,
896				Data: v.Data[i:],
897			}
898		} else {
899			vi = blas64.Vector{
900				Inc:  1,
901				Data: v.Data[i*v.Stride:],
902			}
903		}
904		blas64.Ger(-tau[i], vi, vi, hMat)
905		// Q = Q * G[1]
906		blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, qCopy, hMat, 0, qMat)
907	}
908	return qMat
909}
910
// printRowise prints the m×n row-major matrix stored in a with leading
// dimension lda, one row per line, which is useful for debugging. If beyond
// is true, each line shows all lda elements of the row rather than only the
// first n.
//
//nolint:deadcode,unused
func printRowise(a []float64, m, n, lda int, beyond bool) {
	width := n
	if beyond {
		width = lda
	}
	for i := 0; i < m; i++ {
		start := i * lda
		fmt.Println(a[start : start+width])
	}
}
924
925func copyGeneral(dst, src blas64.General) {
926	r := min(dst.Rows, src.Rows)
927	c := min(dst.Cols, src.Cols)
928	for i := 0; i < r; i++ {
929		copy(dst.Data[i*dst.Stride:i*dst.Stride+c], src.Data[i*src.Stride:i*src.Stride+c])
930	}
931}
932
933// cloneGeneral allocates and returns an exact copy of the given general matrix.
934func cloneGeneral(a blas64.General) blas64.General {
935	c := a
936	c.Data = make([]float64, len(a.Data))
937	copy(c.Data, a.Data)
938	return c
939}
940
941// equalGeneral returns whether the general matrices a and b are equal.
942func equalGeneral(a, b blas64.General) bool {
943	if a.Rows != b.Rows || a.Cols != b.Cols {
944		panic("bad input")
945	}
946	for i := 0; i < a.Rows; i++ {
947		for j := 0; j < a.Cols; j++ {
948			if a.Data[i*a.Stride+j] != b.Data[i*b.Stride+j] {
949				return false
950			}
951		}
952	}
953	return true
954}
955
956// equalApproxGeneral returns whether the general matrices a and b are
957// approximately equal within given tolerance.
958func equalApproxGeneral(a, b blas64.General, tol float64) bool {
959	if a.Rows != b.Rows || a.Cols != b.Cols {
960		panic("bad input")
961	}
962	for i := 0; i < a.Rows; i++ {
963		for j := 0; j < a.Cols; j++ {
964			diff := a.Data[i*a.Stride+j] - b.Data[i*b.Stride+j]
965			if math.IsNaN(diff) || math.Abs(diff) > tol {
966				return false
967			}
968		}
969	}
970	return true
971}
972
973// randSymBand returns an n×n random symmetric positive definite band matrix
974// with kd diagonals.
975func randSymBand(uplo blas.Uplo, n, kd, ldab int, rnd *rand.Rand) []float64 {
976	// Allocate a triangular band matrix U or L and fill it with random numbers.
977	var ab []float64
978	if n > 0 {
979		ab = make([]float64, (n-1)*ldab+kd+1)
980	}
981	for i := range ab {
982		ab[i] = rnd.NormFloat64()
983	}
984	// Make sure that the matrix U or L has a sufficiently positive diagonal.
985	switch uplo {
986	case blas.Upper:
987		for i := 0; i < n; i++ {
988			ab[i*ldab] = float64(n) + rnd.Float64()
989		}
990	case blas.Lower:
991		for i := 0; i < n; i++ {
992			ab[i*ldab+kd] = float64(n) + rnd.Float64()
993		}
994	}
995	// Compute Uᵀ*U or L*Lᵀ. The resulting (symmetric) matrix A will be
996	// positive definite and well-conditioned.
997	dsbmm(uplo, n, kd, ab, ldab)
998	return ab
999}
1000
1001// distSymBand returns the max-norm distance between the symmetric band matrices
1002// A and B.
1003func distSymBand(uplo blas.Uplo, n, kd int, a []float64, lda int, b []float64, ldb int) float64 {
1004	var dist float64
1005	switch uplo {
1006	case blas.Upper:
1007		for i := 0; i < n; i++ {
1008			for j := 0; j < min(kd+1, n-i); j++ {
1009				dist = math.Max(dist, math.Abs(a[i*lda+j]-b[i*ldb+j]))
1010			}
1011		}
1012	case blas.Lower:
1013		for i := 0; i < n; i++ {
1014			for j := max(0, kd-i); j < kd+1; j++ {
1015				dist = math.Max(dist, math.Abs(a[i*lda+j]-b[i*ldb+j]))
1016			}
1017		}
1018	}
1019	return dist
1020}
1021
1022// eye returns an identity matrix of given order and stride.
1023func eye(n, stride int) blas64.General {
1024	ans := nanGeneral(n, n, stride)
1025	for i := 0; i < n; i++ {
1026		for j := 0; j < n; j++ {
1027			ans.Data[i*ans.Stride+j] = 0
1028		}
1029		ans.Data[i*ans.Stride+i] = 1
1030	}
1031	return ans
1032}
1033
1034// zeros returns an m×n matrix with given stride filled with zeros.
1035func zeros(m, n, stride int) blas64.General {
1036	a := nanGeneral(m, n, stride)
1037	for i := 0; i < m; i++ {
1038		for j := 0; j < n; j++ {
1039			a.Data[i*a.Stride+j] = 0
1040		}
1041	}
1042	return a
1043}
1044
// extract2x2Block returns the elements at [0,0], [0,1], [1,0] and [1,1] of
// the row-major matrix T stored in t with leading dimension ldt.
func extract2x2Block(t []float64, ldt int) (a, b, c, d float64) {
	a, b = t[0], t[1]
	c, d = t[ldt], t[ldt+1]
	return a, b, c, d
}
1049
// isSchurCanonical reports whether the 2×2 matrix [a b; c d] is in Schur
// canonical form. That holds when it is upper triangular (c == 0), or when
// it has equal diagonal elements and nonzero off-diagonal elements of
// opposite sign.
func isSchurCanonical(a, b, c, d float64) bool {
	if c == 0 {
		return true
	}
	return b != 0 && a == d && math.Signbit(b) != math.Signbit(c)
}
1055
1056// isSchurCanonicalGeneral returns whether T is block upper triangular with 1×1
1057// and 2×2 diagonal blocks, each 2×2 block in Schur canonical form. The function
1058// checks only along the diagonal and the first subdiagonal, otherwise the lower
1059// triangle is not accessed.
1060func isSchurCanonicalGeneral(t blas64.General) bool {
1061	n := t.Cols
1062	if t.Rows != n {
1063		panic("invalid matrix")
1064	}
1065	for j := 0; j < n-1; {
1066		if t.Data[(j+1)*t.Stride+j] == 0 {
1067			// 1×1 block.
1068			for i := j + 1; i < n; i++ {
1069				if t.Data[i*t.Stride+j] != 0 {
1070					return false
1071				}
1072			}
1073			j++
1074			continue
1075		}
1076		// 2×2 block.
1077		a, b, c, d := extract2x2Block(t.Data[j*t.Stride+j:], t.Stride)
1078		if !isSchurCanonical(a, b, c, d) {
1079			return false
1080		}
1081		for i := j + 2; i < n; i++ {
1082			if t.Data[i*t.Stride+j] != 0 {
1083				return false
1084			}
1085		}
1086		for i := j + 2; i < n; i++ {
1087			if t.Data[i*t.Stride+j+1] != 0 {
1088				return false
1089			}
1090		}
1091		j += 2
1092	}
1093	return true
1094}
1095
1096// schurBlockEigenvalues returns the two eigenvalues of the 2×2 matrix [a b; c d]
1097// that must be in Schur canonical form.
1098func schurBlockEigenvalues(a, b, c, d float64) (ev1, ev2 complex128) {
1099	if !isSchurCanonical(a, b, c, d) {
1100		panic("block not in Schur canonical form")
1101	}
1102	if c == 0 {
1103		return complex(a, 0), complex(d, 0)
1104	}
1105	im := math.Sqrt(math.Abs(b)) * math.Sqrt(math.Abs(c))
1106	return complex(a, im), complex(a, -im)
1107}
1108
1109// schurBlockSize returns the size of the diagonal block at i-th row in the
1110// upper quasi-triangular matrix t in Schur canonical form, and whether i points
1111// to the first row of the block. For zero-sized matrices the function returns 0
1112// and true.
1113func schurBlockSize(t blas64.General, i int) (size int, first bool) {
1114	if t.Rows != t.Cols {
1115		panic("matrix not square")
1116	}
1117	if t.Rows == 0 {
1118		return 0, true
1119	}
1120	if i < 0 || t.Rows <= i {
1121		panic("index out of range")
1122	}
1123
1124	first = true
1125	if i > 0 && t.Data[i*t.Stride+i-1] != 0 {
1126		// There is a non-zero element to the left, therefore i must
1127		// point to the second row in a 2×2 diagonal block.
1128		first = false
1129		i--
1130	}
1131	size = 1
1132	if i+1 < t.Rows && t.Data[(i+1)*t.Stride+i] != 0 {
1133		// There is a non-zero element below, this must be a 2×2
1134		// diagonal block.
1135		size = 2
1136	}
1137	return size, first
1138}
1139
// containsComplex reports whether some element of v lies within distance tol
// (strictly) of z. If so, it also returns the index of the first such
// element; otherwise it returns -1.
func containsComplex(v []complex128, z complex128, tol float64) (found bool, index int) {
	index = -1
	for i, w := range v {
		if cmplx.Abs(w-z) < tol {
			index = i
			break
		}
	}
	return index >= 0, index
}
1150
// isAllNaN reports whether every element of x is NaN. It returns true for an
// empty x.
func isAllNaN(x []float64) bool {
	for i := 0; i < len(x); i++ {
		if !math.IsNaN(x[i]) {
			return false
		}
	}
	return true
}
1160
1161// isUpperHessenberg returns whether h contains only zeros below the
1162// subdiagonal.
1163func isUpperHessenberg(h blas64.General) bool {
1164	if h.Rows != h.Cols {
1165		panic("matrix not square")
1166	}
1167	n := h.Rows
1168	for i := 0; i < n; i++ {
1169		for j := 0; j < n; j++ {
1170			if i > j+1 && h.Data[i*h.Stride+j] != 0 {
1171				return false
1172			}
1173		}
1174	}
1175	return true
1176}
1177
1178// isUpperTriangular returns whether a contains only zeros below the diagonal.
1179func isUpperTriangular(a blas64.General) bool {
1180	n := a.Rows
1181	for i := 1; i < n; i++ {
1182		for j := 0; j < i; j++ {
1183			if a.Data[i*a.Stride+j] != 0 {
1184				return false
1185			}
1186		}
1187	}
1188	return true
1189}
1190
1191// unbalancedSparseGeneral returns an m×n dense matrix with a random sparse
1192// structure consisting of nz nonzero elements. The matrix will be unbalanced by
1193// multiplying each element randomly by its row or column index.
1194func unbalancedSparseGeneral(m, n, stride int, nonzeros int, rnd *rand.Rand) blas64.General {
1195	a := zeros(m, n, stride)
1196	for k := 0; k < nonzeros; k++ {
1197		i := rnd.Intn(n)
1198		j := rnd.Intn(n)
1199		if rnd.Float64() < 0.5 {
1200			a.Data[i*stride+j] = float64(i+1) * rnd.NormFloat64()
1201		} else {
1202			a.Data[i*stride+j] = float64(j+1) * rnd.NormFloat64()
1203		}
1204	}
1205	return a
1206}
1207
// rootsOfUnity returns the n complex numbers whose n-th power is equal to 1,
// namely exp(2πik/n) for k = 0, ..., n-1, in order of increasing angle.
func rootsOfUnity(n int) []complex128 {
	roots := make([]complex128, 0, n)
	for k := 0; k < n; k++ {
		theta := math.Pi * float64(2*k) / float64(n)
		roots = append(roots, complex(math.Cos(theta), math.Sin(theta)))
	}
	return roots
}
1217
// constructGSVDresults returns the matrices [ 0 R ], D1 and D2 described
// in the documentation of Dtgsja and Dggsvd3, and the result matrix in
// the documentation for Dggsvp3.
//
// NOTE(review): only three matrices are returned, and the Dggsvp3 result
// matrices appear to be built by constructGSVPresults instead. Confirm whether
// the Dggsvp3 reference above is stale.
//
// The dimensions follow from the allocations and indexing below:
//   - a is read from its trailing k+l columns, so it needs at least n
//     columns and min(m, k+l) rows (presumably m×n, as passed to Dggsvd3;
//     TODO confirm).
//   - b is read only when m < k+l, from rows m-k through l-1 of its trailing
//     k+l-m columns (presumably p×n; TODO confirm).
//   - alpha and beta are read at indices below min(m, k+l).
//
// The returned matrices are:
//   - zeroR, a (k+l)×n matrix that is zero except in its trailing k+l
//     columns. Its first min(m, k+l) rows are copied there from the
//     trailing k+l columns of a. When m < k+l, the remaining k+l-m rows get
//     their trailing k+l-m columns from b, starting at row m-k of b.
//   - d1, an m×(k+l) matrix with ones in its first k diagonal positions,
//     followed by alpha[k:min(m, k+l)] on the diagonal.
//   - d2, a p×(k+l) matrix with d2[i, k+i] = beta[k+i] for
//     0 <= i < min(l, m-k), and d2[i, k+i] = 1 for m-k <= i < l.
func constructGSVDresults(n, p, m, k, l int, a, b blas64.General, alpha, beta []float64) (zeroR, d1, d2 blas64.General) {
	// [ 0 R ]: copy the leading min(m, k+l) rows of the trailing k+l
	// columns of A into the trailing k+l columns of zeroR.
	zeroR = zeros(k+l, n, n)
	dst := zeroR
	dst.Rows = min(m, k+l)
	dst.Cols = k + l
	dst.Data = zeroR.Data[n-k-l:]
	src := a
	src.Rows = min(m, k+l)
	src.Cols = k + l
	src.Data = a.Data[n-k-l:]
	copyGeneral(dst, src)
	if m < k+l {
		// [ 0 R ]: when m < k+l the last k+l-m rows of R are not held in A.
		// Copy the trailing (k+l-m)×(k+l-m) block of those rows from B,
		// starting at row m-k of B.
		dst.Rows = k + l - m
		dst.Cols = k + l - m
		dst.Data = zeroR.Data[m*zeroR.Stride+n-(k+l-m):]
		src = b
		src.Rows = k + l - m
		src.Cols = k + l - m
		src.Data = b.Data[(m-k)*b.Stride+n+m-k-l:]
		copyGeneral(dst, src)
	}

	// D1: unit diagonal in the leading k×k block, then alpha on the
	// diagonal for rows k through min(m, k+l)-1.
	d1 = zeros(m, k+l, k+l)
	for i := 0; i < k; i++ {
		d1.Data[i*d1.Stride+i] = 1
	}
	for i := k; i < min(m, k+l); i++ {
		d1.Data[i*d1.Stride+i] = alpha[i]
	}

	// D2: beta on the diagonal shifted right by k columns...
	d2 = zeros(p, k+l, k+l)
	for i := 0; i < min(l, m-k); i++ {
		d2.Data[i*d2.Stride+i+k] = beta[k+i]
	}
	// ...followed by unit entries in rows m-k through l-1. This range is
	// non-empty only when m < k+l.
	for i := m - k; i < l; i++ {
		d2.Data[i*d2.Stride+i+k] = 1
	}

	return zeroR, d1, d2
}
1265
1266func constructGSVPresults(n, p, m, k, l int, a, b blas64.General) (zeroA, zeroB blas64.General) {
1267	zeroA = zeros(m, n, n)
1268	dst := zeroA
1269	dst.Rows = min(m, k+l)
1270	dst.Cols = k + l
1271	dst.Data = zeroA.Data[n-k-l:]
1272	src := a
1273	dst.Rows = min(m, k+l)
1274	src.Cols = k + l
1275	src.Data = a.Data[n-k-l:]
1276	copyGeneral(dst, src)
1277
1278	zeroB = zeros(p, n, n)
1279	dst = zeroB
1280	dst.Rows = l
1281	dst.Cols = l
1282	dst.Data = zeroB.Data[n-l:]
1283	src = b
1284	dst.Rows = l
1285	src.Cols = l
1286	src.Data = b.Data[n-l:]
1287	copyGeneral(dst, src)
1288
1289	return zeroA, zeroB
1290}
1291
// distFromIdentity returns the largest absolute elementwise difference between
// the n×n matrix A and the identity matrix. A is stored row-major with leading
// dimension lda. If A contains a NaN element, the result is +Inf.
func distFromIdentity(n int, a []float64, lda int) float64 {
	worst := 0.0
	for r := 0; r < n; r++ {
		for c := 0; c < n; c++ {
			v := a[r*lda+c]
			if math.IsNaN(v) {
				return math.Inf(1)
			}
			if r == c {
				// Diagonal entries are measured against 1, the rest against 0.
				v -= 1
			}
			worst = math.Max(worst, math.Abs(v))
		}
	}
	return worst
}
1311
// sameFloat64 reports whether a and b are equal. Unlike ==, it treats any two
// NaN values as equal to each other. Zero and negative zero compare equal,
// as with ==.
func sameFloat64(a, b float64) bool {
	if math.IsNaN(a) {
		return math.IsNaN(b)
	}
	return a == b
}
1315
1316// sameLowerTri returns whether n×n matrices A and B are same under the diagonal.
1317func sameLowerTri(n int, a []float64, lda int, b []float64, ldb int) bool {
1318	for i := 1; i < n; i++ {
1319		for j := 0; j < i; j++ {
1320			aij := a[i*lda+j]
1321			bij := b[i*ldb+j]
1322			if !sameFloat64(aij, bij) {
1323				return false
1324			}
1325		}
1326	}
1327	return true
1328}
1329
1330// sameUpperTri returns whether n×n matrices A and B are same above the diagonal.
1331func sameUpperTri(n int, a []float64, lda int, b []float64, ldb int) bool {
1332	for i := 0; i < n-1; i++ {
1333		for j := i + 1; j < n; j++ {
1334			aij := a[i*lda+j]
1335			bij := b[i*ldb+j]
1336			if !sameFloat64(aij, bij) {
1337				return false
1338			}
1339		}
1340	}
1341	return true
1342}
1343
1344// svdJobString returns a string representation of job.
1345func svdJobString(job lapack.SVDJob) string {
1346	switch job {
1347	case lapack.SVDAll:
1348		return "All"
1349	case lapack.SVDStore:
1350		return "Store"
1351	case lapack.SVDOverwrite:
1352		return "Overwrite"
1353	case lapack.SVDNone:
1354		return "None"
1355	}
1356	return "unknown SVD job"
1357}
1358
1359// residualOrthogonal returns the residual
1360//  |I - Q * Qᵀ|  if m < n or (m == n and rowwise == true),
1361//  |I - Qᵀ * Q|  otherwise.
1362// It can be used to check that the matrix Q is orthogonal.
1363func residualOrthogonal(q blas64.General, rowwise bool) float64 {
1364	m, n := q.Rows, q.Cols
1365	if m == 0 || n == 0 {
1366		return 0
1367	}
1368	var transq blas.Transpose
1369	if m < n || (m == n && rowwise) {
1370		transq = blas.NoTrans
1371	} else {
1372		transq = blas.Trans
1373	}
1374	minmn := min(m, n)
1375
1376	// Set work = I.
1377	work := blas64.Symmetric{
1378		Uplo:   blas.Upper,
1379		N:      minmn,
1380		Data:   make([]float64, minmn*minmn),
1381		Stride: minmn,
1382	}
1383	for i := 0; i < minmn; i++ {
1384		work.Data[i*work.Stride+i] = 1
1385	}
1386
1387	// Compute
1388	//  work = work - Q * Qᵀ = I - Q * Qᵀ
1389	// or
1390	//  work = work - Qᵀ * Q = I - Qᵀ * Q
1391	blas64.Syrk(transq, -1, q, 1, work)
1392	return dlansy(lapack.MaxColumnSum, blas.Upper, work.N, work.Data, work.Stride)
1393}
1394