1// Copyright ©2016 The Gonum Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package testlapack
6
7import (
8	"fmt"
9	"math"
10	"testing"
11
12	"golang.org/x/exp/rand"
13
14	"gonum.org/v1/gonum/blas"
15	"gonum.org/v1/gonum/blas/blas64"
16	"gonum.org/v1/gonum/lapack"
17)
18
19type Dtrevc3er interface {
20	Dtrevc3(side lapack.EVSide, howmny lapack.EVHowMany, selected []bool, n int, t []float64, ldt int, vl []float64, ldvl int, vr []float64, ldvr int, mm int, work []float64, lwork int) int
21}
22
23func Dtrevc3Test(t *testing.T, impl Dtrevc3er) {
24	rnd := rand.New(rand.NewSource(1))
25
26	for _, side := range []lapack.EVSide{lapack.EVRight, lapack.EVLeft, lapack.EVBoth} {
27		var name string
28		switch side {
29		case lapack.EVRight:
30			name = "EVRigth"
31		case lapack.EVLeft:
32			name = "EVLeft"
33		case lapack.EVBoth:
34			name = "EVBoth"
35		}
36		t.Run(name, func(t *testing.T) {
37			runDtrevc3Test(t, impl, rnd, side)
38		})
39	}
40}
41
42func runDtrevc3Test(t *testing.T, impl Dtrevc3er, rnd *rand.Rand, side lapack.EVSide) {
43	for _, n := range []int{0, 1, 2, 3, 4, 5, 6, 7, 10, 34} {
44		for _, extra := range []int{0, 11} {
45			for _, optwork := range []bool{true, false} {
46				for cas := 0; cas < 10; cas++ {
47					dtrevc3Test(t, impl, side, n, extra, optwork, rnd)
48				}
49			}
50		}
51	}
52}
53
54// dtrevc3Test tests Dtrevc3 by generating a random matrix T in Schur canonical
55// form and performing the following checks:
56//  1. Compute all eigenvectors of T and check that they are indeed correctly
57//     normalized eigenvectors
58//  2. Compute selected eigenvectors and check that they are exactly equal to
59//     eigenvectors from check 1.
60//  3. Compute all eigenvectors multiplied into a matrix Q and check that the
61//     result is equal to eigenvectors from step 1 multiplied by Q and scaled
62//     appropriately.
63func dtrevc3Test(t *testing.T, impl Dtrevc3er, side lapack.EVSide, n, extra int, optwork bool, rnd *rand.Rand) {
64	const tol = 1e-15
65
66	name := fmt.Sprintf("n=%d,extra=%d,optwk=%v", n, extra, optwork)
67
68	right := side != lapack.EVLeft
69	left := side != lapack.EVRight
70
71	// Generate a random matrix in Schur canonical form possibly with tiny or zero eigenvalues.
72	// Zero elements of wi signify a real eigenvalue.
73	tmat, wr, wi := randomSchurCanonical(n, n+extra, true, rnd)
74	tmatCopy := cloneGeneral(tmat)
75
76	//  1. Compute all eigenvectors of T and check that they are indeed correctly
77	//     normalized eigenvectors
78
79	howmny := lapack.EVAll
80
81	var vr, vl blas64.General
82	if right {
83		// Fill VR and VL with NaN because they should be completely overwritten in Dtrevc3.
84		vr = nanGeneral(n, n, n+extra)
85	}
86	if left {
87		vl = nanGeneral(n, n, n+extra)
88	}
89
90	var work []float64
91	if optwork {
92		work = []float64{0}
93		impl.Dtrevc3(side, howmny, nil, n, tmat.Data, tmat.Stride,
94			vl.Data, max(1, vl.Stride), vr.Data, max(1, vr.Stride), n, work, -1)
95		work = make([]float64, int(work[0]))
96	} else {
97		work = make([]float64, max(1, 3*n))
98	}
99
100	mGot := impl.Dtrevc3(side, howmny, nil, n, tmat.Data, tmat.Stride,
101		vl.Data, max(1, vl.Stride), vr.Data, max(1, vr.Stride), n, work, len(work))
102
103	if !generalOutsideAllNaN(tmat) {
104		t.Errorf("%v: out-of-range write to T", name)
105	}
106	if !equalGeneral(tmat, tmatCopy) {
107		t.Errorf("%v: unexpected modification of T", name)
108	}
109	if !generalOutsideAllNaN(vr) {
110		t.Errorf("%v: out-of-range write to VR", name)
111	}
112	if !generalOutsideAllNaN(vl) {
113		t.Errorf("%v: out-of-range write to VL", name)
114	}
115
116	mWant := n
117	if mGot != mWant {
118		t.Errorf("%v: unexpected value of m=%d, want %d", name, mGot, mWant)
119	}
120
121	if right {
122		resid := residualRightEV(tmat, vr, wr, wi)
123		if resid > tol {
124			t.Errorf("%v: unexpected right eigenvectors; residual=%v, want<=%v", name, resid, tol)
125		}
126		resid = residualEVNormalization(vr, wi)
127		if resid > tol {
128			t.Errorf("%v: unexpected normalization of right eigenvectors; residual=%v, want<=%v", name, resid, tol)
129		}
130	}
131	if left {
132		resid := residualLeftEV(tmat, vl, wr, wi)
133		if resid > tol {
134			t.Errorf("%v: unexpected left eigenvectors; residual=%v, want<=%v", name, resid, tol)
135		}
136		resid = residualEVNormalization(vl, wi)
137		if resid > tol {
138			t.Errorf("%v: unexpected normalization of left eigenvectors; residual=%v, want<=%v", name, resid, tol)
139		}
140	}
141
142	//  2. Compute selected eigenvectors and check that they are exactly equal to
143	//     eigenvectors from check 1.
144
145	howmny = lapack.EVSelected
146
147	// Follow DCHKHS and select last max(1,n/4) real, max(1,n/4) complex
148	// eigenvectors instead of selecting them randomly.
149	selected := make([]bool, n)
150	selectedWant := make([]bool, n)
151	var nselr, nselc int
152	for j := n - 1; j > 0; {
153		if wi[j] == 0 {
154			if nselr < max(1, n/4) {
155				nselr++
156				selected[j] = true
157				selectedWant[j] = true
158			}
159			j--
160		} else {
161			if nselc < max(1, n/4) {
162				nselc++
163				// Select all columns to check that Dtrevc3 normalizes 'selected' correctly.
164				selected[j] = true
165				selected[j-1] = true
166				selectedWant[j] = false
167				selectedWant[j-1] = true
168			}
169			j -= 2
170		}
171	}
172	mWant = nselr + 2*nselc
173
174	var vrSel, vlSel blas64.General
175	if right {
176		vrSel = nanGeneral(n, mWant, n+extra)
177	}
178	if left {
179		vlSel = nanGeneral(n, mWant, n+extra)
180	}
181
182	if optwork {
183		// Reallocate optimal work in case it depends on howmny and selected.
184		work = []float64{0}
185		impl.Dtrevc3(side, howmny, selected, n, tmat.Data, tmat.Stride,
186			vlSel.Data, max(1, vlSel.Stride), vrSel.Data, max(1, vrSel.Stride), mWant, work, -1)
187		work = make([]float64, int(work[0]))
188	}
189
190	mGot = impl.Dtrevc3(side, howmny, selected, n, tmat.Data, tmat.Stride,
191		vlSel.Data, max(1, vlSel.Stride), vrSel.Data, max(1, vrSel.Stride), mWant, work, len(work))
192
193	if !generalOutsideAllNaN(tmat) {
194		t.Errorf("%v: out-of-range write to T", name)
195	}
196	if !equalGeneral(tmat, tmatCopy) {
197		t.Errorf("%v: unexpected modification of T", name)
198	}
199	if !generalOutsideAllNaN(vrSel) {
200		t.Errorf("%v: out-of-range write to selected VR", name)
201	}
202	if !generalOutsideAllNaN(vlSel) {
203		t.Errorf("%v: out-of-range write to selected VL", name)
204	}
205
206	if mGot != mWant {
207		t.Errorf("%v: unexpected value of selected m=%d, want %d", name, mGot, mWant)
208	}
209
210	for i := range selected {
211		if selected[i] != selectedWant[i] {
212			t.Errorf("%v: unexpected selected[%v]", name, i)
213		}
214	}
215
216	// Check that selected columns of vrSel are equal to the corresponding
217	// columns of vr.
218	var k int
219	match := true
220	if right {
221	loopVR:
222		for j := 0; j < n; j++ {
223			if selected[j] && wi[j] == 0 {
224				for i := 0; i < n; i++ {
225					if vrSel.Data[i*vrSel.Stride+k] != vr.Data[i*vr.Stride+j] {
226						match = false
227						break loopVR
228					}
229				}
230				k++
231			} else if selected[j] && wi[j] != 0 {
232				for i := 0; i < n; i++ {
233					if vrSel.Data[i*vrSel.Stride+k] != vr.Data[i*vr.Stride+j] ||
234						vrSel.Data[i*vrSel.Stride+k+1] != vr.Data[i*vr.Stride+j+1] {
235						match = false
236						break loopVR
237					}
238				}
239				k += 2
240			}
241		}
242	}
243	if !match {
244		t.Errorf("%v: unexpected selected VR", name)
245	}
246
247	// Check that selected columns of vlSel are equal to the corresponding
248	// columns of vl.
249	match = true
250	k = 0
251	if left {
252	loopVL:
253		for j := 0; j < n; j++ {
254			if selected[j] && wi[j] == 0 {
255				for i := 0; i < n; i++ {
256					if vlSel.Data[i*vlSel.Stride+k] != vl.Data[i*vl.Stride+j] {
257						match = false
258						break loopVL
259					}
260				}
261				k++
262			} else if selected[j] && wi[j] != 0 {
263				for i := 0; i < n; i++ {
264					if vlSel.Data[i*vlSel.Stride+k] != vl.Data[i*vl.Stride+j] ||
265						vlSel.Data[i*vlSel.Stride+k+1] != vl.Data[i*vl.Stride+j+1] {
266						match = false
267						break loopVL
268					}
269				}
270				k += 2
271			}
272		}
273	}
274	if !match {
275		t.Errorf("%v: unexpected selected VL", name)
276	}
277
278	//  3. Compute all eigenvectors multiplied into a matrix Q and check that the
279	//     result is equal to eigenvectors from step 1 multiplied by Q and scaled
280	//     appropriately.
281
282	howmny = lapack.EVAllMulQ
283
284	var vrMul, qr blas64.General
285	var vlMul, ql blas64.General
286	if right {
287		vrMul = randomGeneral(n, n, n+extra, rnd)
288		qr = cloneGeneral(vrMul)
289	}
290	if left {
291		vlMul = randomGeneral(n, n, n+extra, rnd)
292		ql = cloneGeneral(vlMul)
293	}
294
295	if optwork {
296		// Reallocate optimal work in case it depends on howmny and selected.
297		work = []float64{0}
298		impl.Dtrevc3(side, howmny, nil, n, tmat.Data, tmat.Stride,
299			vlMul.Data, max(1, vlMul.Stride), vrMul.Data, max(1, vrMul.Stride), n, work, -1)
300		work = make([]float64, int(work[0]))
301	}
302
303	mGot = impl.Dtrevc3(side, howmny, selected, n, tmat.Data, tmat.Stride,
304		vlMul.Data, max(1, vlMul.Stride), vrMul.Data, max(1, vrMul.Stride), n, work, len(work))
305
306	if !generalOutsideAllNaN(tmat) {
307		t.Errorf("%v: out-of-range write to T", name)
308	}
309	if !equalGeneral(tmat, tmatCopy) {
310		t.Errorf("%v: unexpected modification of T", name)
311	}
312	if !generalOutsideAllNaN(vrMul) {
313		t.Errorf("%v: out-of-range write to VRMul", name)
314	}
315	if !generalOutsideAllNaN(vlMul) {
316		t.Errorf("%v: out-of-range write to VLMul", name)
317	}
318
319	mWant = n
320	if mGot != mWant {
321		t.Errorf("%v: unexpected value of m=%d, want %d", name, mGot, mWant)
322	}
323
324	if right {
325		// Compute Q * VR explicitly and normalize to match Dtrevc3 output.
326		qvWant := zeros(n, n, n)
327		blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, qr, vr, 0, qvWant)
328		normalizeEV(qvWant, wi)
329
330		// Compute the difference between Dtrevc3 output and Q * VR.
331		r := zeros(n, n, n)
332		for i := 0; i < n; i++ {
333			for j := 0; j < n; j++ {
334				r.Data[i*r.Stride+j] = vrMul.Data[i*vrMul.Stride+j] - qvWant.Data[i*qvWant.Stride+j]
335			}
336		}
337		qvNorm := dlange(lapack.MaxColumnSum, n, n, qvWant.Data, qvWant.Stride)
338		resid := dlange(lapack.MaxColumnSum, n, n, r.Data, r.Stride) / qvNorm / float64(n)
339		if resid > tol {
340			t.Errorf("%v: unexpected VRMul; resid=%v, want <=%v", name, resid, tol)
341		}
342	}
343	if left {
344		// Compute Q * VL explicitly and normalize to match Dtrevc3 output.
345		qvWant := zeros(n, n, n)
346		blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, ql, vl, 0, qvWant)
347		normalizeEV(qvWant, wi)
348
349		// Compute the difference between Dtrevc3 output and Q * VL.
350		r := zeros(n, n, n)
351		for i := 0; i < n; i++ {
352			for j := 0; j < n; j++ {
353				r.Data[i*r.Stride+j] = vlMul.Data[i*vlMul.Stride+j] - qvWant.Data[i*qvWant.Stride+j]
354			}
355		}
356		qvNorm := dlange(lapack.MaxColumnSum, n, n, qvWant.Data, qvWant.Stride)
357		resid := dlange(lapack.MaxColumnSum, n, n, r.Data, r.Stride) / qvNorm / float64(n)
358		if resid > tol {
359			t.Errorf("%v: unexpected VLMul; resid=%v, want <=%v", name, resid, tol)
360		}
361	}
362}
363
364// residualEVNormalization returns the maximum normalization error in E:
365//  max |max-norm(E[:,j]) - 1|
366func residualEVNormalization(emat blas64.General, wi []float64) float64 {
367	n := emat.Rows
368	if n == 0 {
369		return 0
370	}
371	var (
372		e      = emat.Data
373		lde    = emat.Stride
374		enrmin = math.Inf(1)
375		enrmax float64
376		ipair  int
377	)
378	for j := 0; j < n; j++ {
379		if ipair == 0 && j < n-1 && wi[j] != 0 {
380			ipair = 1
381		}
382		var nrm float64
383		switch ipair {
384		case 0:
385			// Real eigenvector
386			for i := 0; i < n; i++ {
387				nrm = math.Max(nrm, math.Abs(e[i*lde+j]))
388			}
389			enrmin = math.Min(enrmin, nrm)
390			enrmax = math.Max(enrmax, nrm)
391		case 1:
392			// Complex eigenvector
393			for i := 0; i < n; i++ {
394				nrm = math.Max(nrm, math.Abs(e[i*lde+j])+math.Abs(e[i*lde+j+1]))
395			}
396			enrmin = math.Min(enrmin, nrm)
397			enrmax = math.Max(enrmax, nrm)
398			ipair = 2
399		case 2:
400			ipair = 0
401		}
402	}
403	return math.Max(math.Abs(enrmin-1), math.Abs(enrmin-1))
404}
405
406// normalizeEV normalizes eigenvectors in the columns of E so that the element
407// of largest magnitude has magnitude 1.
408func normalizeEV(emat blas64.General, wi []float64) {
409	n := emat.Rows
410	if n == 0 {
411		return
412	}
413	var (
414		bi    = blas64.Implementation()
415		e     = emat.Data
416		lde   = emat.Stride
417		ipair int
418	)
419	for j := 0; j < n; j++ {
420		if ipair == 0 && j < n-1 && wi[j] != 0 {
421			ipair = 1
422		}
423		switch ipair {
424		case 0:
425			// Real eigenvector
426			ii := bi.Idamax(n, e[j:], lde)
427			remax := 1 / math.Abs(e[ii*lde+j])
428			bi.Dscal(n, remax, e[j:], lde)
429		case 1:
430			// Complex eigenvector
431			var emax float64
432			for i := 0; i < n; i++ {
433				emax = math.Max(emax, math.Abs(e[i*lde+j])+math.Abs(e[i*lde+j+1]))
434			}
435			bi.Dscal(n, 1/emax, e[j:], lde)
436			bi.Dscal(n, 1/emax, e[j+1:], lde)
437			ipair = 2
438		case 2:
439			ipair = 0
440		}
441	}
442}
443