1// Copyright ©2017 The Gonum Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package testlapack
6
7import (
8	"testing"
9
10	"golang.org/x/exp/rand"
11
12	"gonum.org/v1/gonum/blas"
13	"gonum.org/v1/gonum/blas/blas64"
14	"gonum.org/v1/gonum/lapack"
15)
16
17type Dggsvp3er interface {
18	Dlanger
19	Dggsvp3(jobU, jobV, jobQ lapack.GSVDJob, m, p, n int, a []float64, lda int, b []float64, ldb int, tola, tolb float64, u []float64, ldu int, v []float64, ldv int, q []float64, ldq int, iwork []int, tau, work []float64, lwork int) (k, l int)
20}
21
22func Dggsvp3Test(t *testing.T, impl Dggsvp3er) {
23	const tol = 1e-14
24
25	rnd := rand.New(rand.NewSource(1))
26	for cas, test := range []struct {
27		m, p, n, lda, ldb, ldu, ldv, ldq int
28	}{
29		{m: 3, p: 3, n: 5, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0},
30		{m: 5, p: 5, n: 5, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0},
31		{m: 5, p: 5, n: 5, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0},
32		{m: 5, p: 5, n: 10, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0},
33		{m: 5, p: 5, n: 10, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0},
34		{m: 5, p: 5, n: 10, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0},
35		{m: 10, p: 5, n: 5, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0},
36		{m: 10, p: 5, n: 5, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0},
37		{m: 10, p: 10, n: 10, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0},
38		{m: 10, p: 10, n: 10, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0},
39		{m: 5, p: 5, n: 5, lda: 10, ldb: 10, ldu: 10, ldv: 10, ldq: 10},
40		{m: 5, p: 5, n: 5, lda: 10, ldb: 10, ldu: 10, ldv: 10, ldq: 10},
41		{m: 5, p: 5, n: 10, lda: 20, ldb: 20, ldu: 10, ldv: 10, ldq: 20},
42		{m: 5, p: 5, n: 10, lda: 20, ldb: 20, ldu: 10, ldv: 10, ldq: 20},
43		{m: 5, p: 5, n: 10, lda: 20, ldb: 20, ldu: 10, ldv: 10, ldq: 20},
44		{m: 10, p: 5, n: 5, lda: 10, ldb: 10, ldu: 20, ldv: 10, ldq: 10},
45		{m: 10, p: 5, n: 5, lda: 10, ldb: 10, ldu: 20, ldv: 10, ldq: 10},
46		{m: 10, p: 10, n: 10, lda: 20, ldb: 20, ldu: 20, ldv: 20, ldq: 20},
47		{m: 10, p: 10, n: 10, lda: 20, ldb: 20, ldu: 20, ldv: 20, ldq: 20},
48	} {
49		m := test.m
50		p := test.p
51		n := test.n
52		lda := test.lda
53		if lda == 0 {
54			lda = n
55		}
56		ldb := test.ldb
57		if ldb == 0 {
58			ldb = n
59		}
60		ldu := test.ldu
61		if ldu == 0 {
62			ldu = m
63		}
64		ldv := test.ldv
65		if ldv == 0 {
66			ldv = p
67		}
68		ldq := test.ldq
69		if ldq == 0 {
70			ldq = n
71		}
72
73		a := randomGeneral(m, n, lda, rnd)
74		aCopy := cloneGeneral(a)
75		b := randomGeneral(p, n, ldb, rnd)
76		bCopy := cloneGeneral(b)
77
78		tola := float64(max(m, n)) * impl.Dlange(lapack.Frobenius, m, n, a.Data, a.Stride, nil) * dlamchE
79		tolb := float64(max(p, n)) * impl.Dlange(lapack.Frobenius, p, n, b.Data, b.Stride, nil) * dlamchE
80
81		u := nanGeneral(m, m, ldu)
82		v := nanGeneral(p, p, ldv)
83		q := nanGeneral(n, n, ldq)
84
85		iwork := make([]int, n)
86		tau := make([]float64, n)
87
88		work := []float64{0}
89		impl.Dggsvp3(lapack.GSVDU, lapack.GSVDV, lapack.GSVDQ,
90			m, p, n,
91			a.Data, a.Stride,
92			b.Data, b.Stride,
93			tola, tolb,
94			u.Data, u.Stride,
95			v.Data, v.Stride,
96			q.Data, q.Stride,
97			iwork, tau,
98			work, -1)
99
100		lwork := int(work[0])
101		work = make([]float64, lwork)
102
103		k, l := impl.Dggsvp3(lapack.GSVDU, lapack.GSVDV, lapack.GSVDQ,
104			m, p, n,
105			a.Data, a.Stride,
106			b.Data, b.Stride,
107			tola, tolb,
108			u.Data, u.Stride,
109			v.Data, v.Stride,
110			q.Data, q.Stride,
111			iwork, tau,
112			work, lwork)
113
114		// Check orthogonality of U, V and Q.
115		if resid := residualOrthogonal(u, false); resid > tol {
116			t.Errorf("Case %v: U is not orthogonal; resid=%v, want<=%v", cas, resid, tol)
117		}
118		if resid := residualOrthogonal(v, false); resid > tol {
119			t.Errorf("Case %v: V is not orthogonal; resid=%v, want<=%v", cas, resid, tol)
120		}
121		if resid := residualOrthogonal(q, false); resid > tol {
122			t.Errorf("Case %v: Q is not orthogonal; resid=%v, want<=%v", cas, resid, tol)
123		}
124
125		zeroA, zeroB := constructGSVPresults(n, p, m, k, l, a, b)
126
127		// Check Uᵀ*A*Q = [ 0 RA ].
128		uTmp := nanGeneral(m, n, n)
129		blas64.Gemm(blas.Trans, blas.NoTrans, 1, u, aCopy, 0, uTmp)
130		uAns := nanGeneral(m, n, n)
131		blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, uTmp, q, 0, uAns)
132
133		if !equalApproxGeneral(uAns, zeroA, tol) {
134			t.Errorf("test %d: Uᵀ*A*Q != [ 0 RA ]\nUᵀ*A*Q:\n%+v\n[ 0 RA ]:\n%+v",
135				cas, uAns, zeroA)
136		}
137
138		// Check Vᵀ*B*Q = [ 0 RB ].
139		vTmp := nanGeneral(p, n, n)
140		blas64.Gemm(blas.Trans, blas.NoTrans, 1, v, bCopy, 0, vTmp)
141		vAns := nanGeneral(p, n, n)
142		blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, vTmp, q, 0, vAns)
143
144		if !equalApproxGeneral(vAns, zeroB, tol) {
145			t.Errorf("test %d: Vᵀ*B*Q != [ 0 RB ]\nVᵀ*B*Q:\n%+v\n[ 0 RB ]:\n%+v",
146				cas, vAns, zeroB)
147		}
148	}
149}
150