1// Copyright ©2017 The Gonum Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package testlapack
6
7import (
8	"testing"
9
10	"golang.org/x/exp/rand"
11
12	"gonum.org/v1/gonum/blas"
13	"gonum.org/v1/gonum/blas/blas64"
14	"gonum.org/v1/gonum/floats/scalar"
15	"gonum.org/v1/gonum/lapack"
16)
17
18type Dggsvd3er interface {
19	Dggsvd3(jobU, jobV, jobQ lapack.GSVDJob, m, n, p int, a []float64, lda int, b []float64, ldb int, alpha, beta, u []float64, ldu int, v []float64, ldv int, q []float64, ldq int, work []float64, lwork int, iwork []int) (k, l int, ok bool)
20}
21
22func Dggsvd3Test(t *testing.T, impl Dggsvd3er) {
23	const tol = 1e-13
24
25	rnd := rand.New(rand.NewSource(1))
26	for cas, test := range []struct {
27		m, p, n, lda, ldb, ldu, ldv, ldq int
28
29		ok bool
30	}{
31		{m: 3, p: 3, n: 5, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true},
32		{m: 5, p: 5, n: 5, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true},
33		{m: 5, p: 5, n: 5, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true},
34		{m: 5, p: 5, n: 10, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true},
35		{m: 5, p: 5, n: 10, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true},
36		{m: 5, p: 5, n: 10, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true},
37		{m: 10, p: 5, n: 5, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true},
38		{m: 10, p: 5, n: 5, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true},
39		{m: 10, p: 10, n: 10, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true},
40		{m: 10, p: 10, n: 10, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true},
41		{m: 5, p: 5, n: 5, lda: 10, ldb: 10, ldu: 10, ldv: 10, ldq: 10, ok: true},
42		{m: 5, p: 5, n: 5, lda: 10, ldb: 10, ldu: 10, ldv: 10, ldq: 10, ok: true},
43		{m: 5, p: 5, n: 10, lda: 20, ldb: 20, ldu: 10, ldv: 10, ldq: 20, ok: true},
44		{m: 5, p: 5, n: 10, lda: 20, ldb: 20, ldu: 10, ldv: 10, ldq: 20, ok: true},
45		{m: 5, p: 5, n: 10, lda: 20, ldb: 20, ldu: 10, ldv: 10, ldq: 20, ok: true},
46		{m: 10, p: 5, n: 5, lda: 10, ldb: 10, ldu: 20, ldv: 10, ldq: 10, ok: true},
47		{m: 10, p: 5, n: 5, lda: 10, ldb: 10, ldu: 20, ldv: 10, ldq: 10, ok: true},
48		{m: 10, p: 10, n: 10, lda: 20, ldb: 20, ldu: 20, ldv: 20, ldq: 20, ok: true},
49		{m: 10, p: 10, n: 10, lda: 20, ldb: 20, ldu: 20, ldv: 20, ldq: 20, ok: true},
50	} {
51		m := test.m
52		p := test.p
53		n := test.n
54		lda := test.lda
55		if lda == 0 {
56			lda = n
57		}
58		ldb := test.ldb
59		if ldb == 0 {
60			ldb = n
61		}
62		ldu := test.ldu
63		if ldu == 0 {
64			ldu = m
65		}
66		ldv := test.ldv
67		if ldv == 0 {
68			ldv = p
69		}
70		ldq := test.ldq
71		if ldq == 0 {
72			ldq = n
73		}
74
75		a := randomGeneral(m, n, lda, rnd)
76		aCopy := cloneGeneral(a)
77		b := randomGeneral(p, n, ldb, rnd)
78		bCopy := cloneGeneral(b)
79
80		alpha := make([]float64, n)
81		beta := make([]float64, n)
82
83		u := nanGeneral(m, m, ldu)
84		v := nanGeneral(p, p, ldv)
85		q := nanGeneral(n, n, ldq)
86
87		iwork := make([]int, n)
88
89		work := []float64{0}
90		impl.Dggsvd3(lapack.GSVDU, lapack.GSVDV, lapack.GSVDQ,
91			m, n, p,
92			a.Data, a.Stride,
93			b.Data, b.Stride,
94			alpha, beta,
95			u.Data, u.Stride,
96			v.Data, v.Stride,
97			q.Data, q.Stride,
98			work, -1, iwork)
99
100		lwork := int(work[0])
101		work = make([]float64, lwork)
102
103		k, l, ok := impl.Dggsvd3(lapack.GSVDU, lapack.GSVDV, lapack.GSVDQ,
104			m, n, p,
105			a.Data, a.Stride,
106			b.Data, b.Stride,
107			alpha, beta,
108			u.Data, u.Stride,
109			v.Data, v.Stride,
110			q.Data, q.Stride,
111			work, lwork, iwork)
112
113		if !ok {
114			if test.ok {
115				t.Errorf("test %d unexpectedly did not converge", cas)
116			}
117			continue
118		}
119
120		// Check orthogonality of U, V and Q.
121		if resid := residualOrthogonal(u, false); resid > tol {
122			t.Errorf("Case %v: U is not orthogonal; resid=%v, want<=%v", cas, resid, tol)
123		}
124		if resid := residualOrthogonal(v, false); resid > tol {
125			t.Errorf("Case %v: V is not orthogonal; resid=%v, want<=%v", cas, resid, tol)
126		}
127		if resid := residualOrthogonal(q, false); resid > tol {
128			t.Errorf("Case %v: Q is not orthogonal; resid=%v, want<=%v", cas, resid, tol)
129		}
130
131		// Check C^2 + S^2 = I.
132		var elements []float64
133		if m-k-l >= 0 {
134			elements = alpha[k : k+l]
135		} else {
136			elements = alpha[k:m]
137		}
138		for i := range elements {
139			i += k
140			d := alpha[i]*alpha[i] + beta[i]*beta[i]
141			if !scalar.EqualWithinAbsOrRel(d, 1, tol, tol) {
142				t.Errorf("test %d: alpha_%d^2 + beta_%d^2 != 1: got: %v", cas, i, i, d)
143			}
144		}
145
146		zeroR, d1, d2 := constructGSVDresults(n, p, m, k, l, a, b, alpha, beta)
147
148		// Check Uᵀ*A*Q = D1*[ 0 R ].
149		uTmp := nanGeneral(m, n, n)
150		blas64.Gemm(blas.Trans, blas.NoTrans, 1, u, aCopy, 0, uTmp)
151		uAns := nanGeneral(m, n, n)
152		blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, uTmp, q, 0, uAns)
153
154		d10r := nanGeneral(m, n, n)
155		blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, d1, zeroR, 0, d10r)
156
157		if !equalApproxGeneral(uAns, d10r, tol) {
158			t.Errorf("test %d: Uᵀ*A*Q != D1*[ 0 R ]\nUᵀ*A*Q:\n%+v\nD1*[ 0 R ]:\n%+v",
159				cas, uAns, d10r)
160		}
161
162		// Check Vᵀ*B*Q = D2*[ 0 R ].
163		vTmp := nanGeneral(p, n, n)
164		blas64.Gemm(blas.Trans, blas.NoTrans, 1, v, bCopy, 0, vTmp)
165		vAns := nanGeneral(p, n, n)
166		blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, vTmp, q, 0, vAns)
167
168		d20r := nanGeneral(p, n, n)
169		blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, d2, zeroR, 0, d20r)
170
171		if !equalApproxGeneral(vAns, d20r, tol) {
172			t.Errorf("test %d: Vᵀ*B*Q != D2*[ 0 R ]\nVᵀ*B*Q:\n%+v\nD2*[ 0 R ]:\n%+v",
173				cas, vAns, d20r)
174		}
175	}
176}
177