1// Copyright ©2017 The Gonum Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package gonum
6
7import (
8	"math"
9
10	"gonum.org/v1/gonum/blas"
11	"gonum.org/v1/gonum/lapack"
12)
13
14// Dggsvp3 computes orthogonal matrices U, V and Q such that
15//
16//                  n-k-l  k    l
17//  Uᵀ*A*Q =     k [ 0    A12  A13 ] if m-k-l >= 0;
18//               l [ 0     0   A23 ]
19//           m-k-l [ 0     0    0  ]
20//
21//                  n-k-l  k    l
22//  Uᵀ*A*Q =     k [ 0    A12  A13 ] if m-k-l < 0;
23//             m-k [ 0     0   A23 ]
24//
25//                  n-k-l  k    l
26//  Vᵀ*B*Q =     l [ 0     0   B13 ]
27//             p-l [ 0     0    0  ]
28//
29// where the k×k matrix A12 and l×l matrix B13 are non-singular
30// upper triangular. A23 is l×l upper triangular if m-k-l >= 0,
31// otherwise A23 is (m-k)×l upper trapezoidal.
32//
33// Dggsvp3 returns k and l, the dimensions of the sub-blocks. k+l
34// is the effective numerical rank of the (m+p)×n matrix [ Aᵀ Bᵀ ]ᵀ.
35//
36// jobU, jobV and jobQ are options for computing the orthogonal matrices. The behavior
37// is as follows
38//  jobU == lapack.GSVDU        Compute orthogonal matrix U
39//  jobU == lapack.GSVDNone     Do not compute orthogonal matrix.
40// The behavior is the same for jobV and jobQ with the exception that instead of
41// lapack.GSVDU these accept lapack.GSVDV and lapack.GSVDQ respectively.
42// The matrices U, V and Q must be m×m, p×p and n×n respectively unless the
43// relevant job parameter is lapack.GSVDNone.
44//
45// tola and tolb are the convergence criteria for the Jacobi-Kogbetliantz
46// iteration procedure. Generally, they are the same as used in the preprocessing
47// step, for example,
48//  tola = max(m, n)*norm(A)*eps,
49//  tolb = max(p, n)*norm(B)*eps.
50// Where eps is the machine epsilon.
51//
52// iwork must have length n, work must have length at least max(1, lwork), and
53// lwork must be -1 or greater than zero, otherwise Dggsvp3 will panic.
54//
55// Dggsvp3 is an internal routine. It is exported for testing purposes.
56func (impl Implementation) Dggsvp3(jobU, jobV, jobQ lapack.GSVDJob, m, p, n int, a []float64, lda int, b []float64, ldb int, tola, tolb float64, u []float64, ldu int, v []float64, ldv int, q []float64, ldq int, iwork []int, tau, work []float64, lwork int) (k, l int) {
57	wantu := jobU == lapack.GSVDU
58	wantv := jobV == lapack.GSVDV
59	wantq := jobQ == lapack.GSVDQ
60	switch {
61	case !wantu && jobU != lapack.GSVDNone:
62		panic(badGSVDJob + "U")
63	case !wantv && jobV != lapack.GSVDNone:
64		panic(badGSVDJob + "V")
65	case !wantq && jobQ != lapack.GSVDNone:
66		panic(badGSVDJob + "Q")
67	case m < 0:
68		panic(mLT0)
69	case p < 0:
70		panic(pLT0)
71	case n < 0:
72		panic(nLT0)
73	case lda < max(1, n):
74		panic(badLdA)
75	case ldb < max(1, n):
76		panic(badLdB)
77	case ldu < 1, wantu && ldu < m:
78		panic(badLdU)
79	case ldv < 1, wantv && ldv < p:
80		panic(badLdV)
81	case ldq < 1, wantq && ldq < n:
82		panic(badLdQ)
83	case len(iwork) != n:
84		panic(shortWork)
85	case lwork < 1 && lwork != -1:
86		panic(badLWork)
87	case len(work) < max(1, lwork):
88		panic(shortWork)
89	}
90
91	var lwkopt int
92	impl.Dgeqp3(p, n, b, ldb, iwork, tau, work, -1)
93	lwkopt = int(work[0])
94	if wantv {
95		lwkopt = max(lwkopt, p)
96	}
97	lwkopt = max(lwkopt, min(n, p))
98	lwkopt = max(lwkopt, m)
99	if wantq {
100		lwkopt = max(lwkopt, n)
101	}
102	impl.Dgeqp3(m, n, a, lda, iwork, tau, work, -1)
103	lwkopt = max(lwkopt, int(work[0]))
104	lwkopt = max(1, lwkopt)
105	if lwork == -1 {
106		work[0] = float64(lwkopt)
107		return 0, 0
108	}
109
110	switch {
111	case len(a) < (m-1)*lda+n:
112		panic(shortA)
113	case len(b) < (p-1)*ldb+n:
114		panic(shortB)
115	case wantu && len(u) < (m-1)*ldu+m:
116		panic(shortU)
117	case wantv && len(v) < (p-1)*ldv+p:
118		panic(shortV)
119	case wantq && len(q) < (n-1)*ldq+n:
120		panic(shortQ)
121	case len(tau) < n:
122		// tau check must come after lwkopt query since
123		// the Dggsvd3 call for lwkopt query may have
124		// lwork == -1, and tau is provided by work.
125		panic(shortTau)
126	}
127
128	const forward = true
129
130	// QR with column pivoting of B: B*P = V*[ S11 S12 ].
131	//                                       [  0   0  ]
132	for i := range iwork[:n] {
133		iwork[i] = 0
134	}
135	impl.Dgeqp3(p, n, b, ldb, iwork, tau, work, lwork)
136
137	// Update A := A*P.
138	impl.Dlapmt(forward, m, n, a, lda, iwork)
139
140	// Determine the effective rank of matrix B.
141	for i := 0; i < min(p, n); i++ {
142		if math.Abs(b[i*ldb+i]) > tolb {
143			l++
144		}
145	}
146
147	if wantv {
148		// Copy the details of V, and form V.
149		impl.Dlaset(blas.All, p, p, 0, 0, v, ldv)
150		if p > 1 {
151			impl.Dlacpy(blas.Lower, p-1, min(p, n), b[ldb:], ldb, v[ldv:], ldv)
152		}
153		impl.Dorg2r(p, p, min(p, n), v, ldv, tau, work)
154	}
155
156	// Clean up B.
157	for i := 1; i < l; i++ {
158		r := b[i*ldb : i*ldb+i]
159		for j := range r {
160			r[j] = 0
161		}
162	}
163	if p > l {
164		impl.Dlaset(blas.All, p-l, n, 0, 0, b[l*ldb:], ldb)
165	}
166
167	if wantq {
168		// Set Q = I and update Q := Q*P.
169		impl.Dlaset(blas.All, n, n, 0, 1, q, ldq)
170		impl.Dlapmt(forward, n, n, q, ldq, iwork)
171	}
172
173	if p >= l && n != l {
174		// RQ factorization of [ S11 S12 ]: [ S11 S12 ] = [ 0 S12 ]*Z.
175		impl.Dgerq2(l, n, b, ldb, tau, work)
176
177		// Update A := A*Zᵀ.
178		impl.Dormr2(blas.Right, blas.Trans, m, n, l, b, ldb, tau, a, lda, work)
179
180		if wantq {
181			// Update Q := Q*Zᵀ.
182			impl.Dormr2(blas.Right, blas.Trans, n, n, l, b, ldb, tau, q, ldq, work)
183		}
184
185		// Clean up B.
186		impl.Dlaset(blas.All, l, n-l, 0, 0, b, ldb)
187		for i := 1; i < l; i++ {
188			r := b[i*ldb+n-l : i*ldb+i+n-l]
189			for j := range r {
190				r[j] = 0
191			}
192		}
193	}
194
195	// Let              N-L     L
196	//            A = [ A11    A12 ] M,
197	//
198	// then the following does the complete QR decomposition of A11:
199	//
200	//          A11 = U*[  0  T12 ]*P1ᵀ.
201	//                  [  0   0  ]
202	for i := range iwork[:n-l] {
203		iwork[i] = 0
204	}
205	impl.Dgeqp3(m, n-l, a, lda, iwork[:n-l], tau, work, lwork)
206
207	// Determine the effective rank of A11.
208	for i := 0; i < min(m, n-l); i++ {
209		if math.Abs(a[i*lda+i]) > tola {
210			k++
211		}
212	}
213
214	// Update A12 := Uᵀ*A12, where A12 = A[0:m, n-l:n].
215	impl.Dorm2r(blas.Left, blas.Trans, m, l, min(m, n-l), a, lda, tau, a[n-l:], lda, work)
216
217	if wantu {
218		// Copy the details of U, and form U.
219		impl.Dlaset(blas.All, m, m, 0, 0, u, ldu)
220		if m > 1 {
221			impl.Dlacpy(blas.Lower, m-1, min(m, n-l), a[lda:], lda, u[ldu:], ldu)
222		}
223		impl.Dorg2r(m, m, min(m, n-l), u, ldu, tau, work)
224	}
225
226	if wantq {
227		// Update Q[0:n, 0:n-l] := Q[0:n, 0:n-l]*P1.
228		impl.Dlapmt(forward, n, n-l, q, ldq, iwork[:n-l])
229	}
230
231	// Clean up A: set the strictly lower triangular part of
232	// A[0:k, 0:k] = 0, and A[k:m, 0:n-l] = 0.
233	for i := 1; i < k; i++ {
234		r := a[i*lda : i*lda+i]
235		for j := range r {
236			r[j] = 0
237		}
238	}
239	if m > k {
240		impl.Dlaset(blas.All, m-k, n-l, 0, 0, a[k*lda:], lda)
241	}
242
243	if n-l > k {
244		// RQ factorization of [ T11 T12 ] = [ 0 T12 ]*Z1.
245		impl.Dgerq2(k, n-l, a, lda, tau, work)
246
247		if wantq {
248			// Update Q[0:n, 0:n-l] := Q[0:n, 0:n-l]*Z1ᵀ.
249			impl.Dorm2r(blas.Right, blas.Trans, n, n-l, k, a, lda, tau, q, ldq, work)
250		}
251
252		// Clean up A.
253		impl.Dlaset(blas.All, k, n-l-k, 0, 0, a, lda)
254		for i := 1; i < k; i++ {
255			r := a[i*lda+n-k-l : i*lda+i+n-k-l]
256			for j := range r {
257				a[j] = 0
258			}
259		}
260	}
261
262	if m > k {
263		// QR factorization of A[k:m, n-l:n].
264		impl.Dgeqr2(m-k, l, a[k*lda+n-l:], lda, tau, work)
265		if wantu {
266			// Update U[:, k:m) := U[:, k:m]*U1.
267			impl.Dorm2r(blas.Right, blas.NoTrans, m, m-k, min(m-k, l), a[k*lda+n-l:], lda, tau, u[k:], ldu, work)
268		}
269
270		// Clean up A.
271		for i := k + 1; i < m; i++ {
272			r := a[i*lda+n-l : i*lda+min(n-l+i-k, n)]
273			for j := range r {
274				r[j] = 0
275			}
276		}
277	}
278
279	work[0] = float64(lwkopt)
280	return k, l
281}
282