1// Copyright ©2015 The Gonum Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package gonum
6
7import (
8	"math"
9
10	"gonum.org/v1/gonum/blas"
11	"gonum.org/v1/gonum/blas/blas64"
12	"gonum.org/v1/gonum/lapack"
13)
14
15// Dbdsqr performs a singular value decomposition of a real n×n bidiagonal matrix.
16//
17// The SVD of the bidiagonal matrix B is
18//  B = Q * S * P^T
19// where S is a diagonal matrix of singular values, Q is an orthogonal matrix of
20// left singular vectors, and P is an orthogonal matrix of right singular vectors.
21//
22// Q and P are only computed if requested. If left singular vectors are requested,
23// this routine returns U * Q instead of Q, and if right singular vectors are
24// requested P^T * VT is returned instead of P^T.
25//
26// Frequently Dbdsqr is used in conjunction with Dgebrd which reduces a general
27// matrix A into bidiagonal form. In this case, the SVD of A is
28//  A = (U * Q) * S * (P^T * VT)
29//
30// This routine may also compute Q^T * C.
31//
32// d and e contain the elements of the bidiagonal matrix b. d must have length at
33// least n, and e must have length at least n-1. Dbdsqr will panic if there is
34// insufficient length. On exit, D contains the singular values of B in decreasing
35// order.
36//
37// VT is a matrix of size n×ncvt whose elements are stored in vt. The elements
38// of vt are modified to contain P^T * VT on exit. VT is not used if ncvt == 0.
39//
40// U is a matrix of size nru×n whose elements are stored in u. The elements
41// of u are modified to contain U * Q on exit. U is not used if nru == 0.
42//
43// C is a matrix of size n×ncc whose elements are stored in c. The elements
44// of c are modified to contain Q^T * C on exit. C is not used if ncc == 0.
45//
46// work contains temporary storage and must have length at least 4*(n-1). Dbdsqr
47// will panic if there is insufficient working memory.
48//
49// Dbdsqr returns whether the decomposition was successful.
50//
51// Dbdsqr is an internal routine. It is exported for testing purposes.
52func (impl Implementation) Dbdsqr(uplo blas.Uplo, n, ncvt, nru, ncc int, d, e, vt []float64, ldvt int, u []float64, ldu int, c []float64, ldc int, work []float64) (ok bool) {
53	switch {
54	case uplo != blas.Upper && uplo != blas.Lower:
55		panic(badUplo)
56	case n < 0:
57		panic(nLT0)
58	case ncvt < 0:
59		panic(ncvtLT0)
60	case nru < 0:
61		panic(nruLT0)
62	case ncc < 0:
63		panic(nccLT0)
64	case ldvt < max(1, ncvt):
65		panic(badLdVT)
66	case (ldu < max(1, n) && nru > 0) || (ldu < 1 && nru == 0):
67		panic(badLdU)
68	case ldc < max(1, ncc):
69		panic(badLdC)
70	}
71
72	// Quick return if possible.
73	if n == 0 {
74		return true
75	}
76
77	if len(vt) < (n-1)*ldvt+ncvt && ncvt != 0 {
78		panic(shortVT)
79	}
80	if len(u) < (nru-1)*ldu+n && nru != 0 {
81		panic(shortU)
82	}
83	if len(c) < (n-1)*ldc+ncc && ncc != 0 {
84		panic(shortC)
85	}
86	if len(d) < n {
87		panic(shortD)
88	}
89	if len(e) < n-1 {
90		panic(shortE)
91	}
92	if len(work) < 4*(n-1) {
93		panic(shortWork)
94	}
95
96	var info int
97	bi := blas64.Implementation()
98	const maxIter = 6
99
100	if n != 1 {
101		// If the singular vectors do not need to be computed, use qd algorithm.
102		if !(ncvt > 0 || nru > 0 || ncc > 0) {
103			info = impl.Dlasq1(n, d, e, work)
104			// If info is 2 dqds didn't finish, and so try to.
105			if info != 2 {
106				return info == 0
107			}
108		}
109		nm1 := n - 1
110		nm12 := nm1 + nm1
111		nm13 := nm12 + nm1
112		idir := 0
113
114		eps := dlamchE
115		unfl := dlamchS
116		lower := uplo == blas.Lower
117		var cs, sn, r float64
118		if lower {
119			for i := 0; i < n-1; i++ {
120				cs, sn, r = impl.Dlartg(d[i], e[i])
121				d[i] = r
122				e[i] = sn * d[i+1]
123				d[i+1] *= cs
124				work[i] = cs
125				work[nm1+i] = sn
126			}
127			if nru > 0 {
128				impl.Dlasr(blas.Right, lapack.Variable, lapack.Forward, nru, n, work, work[n-1:], u, ldu)
129			}
130			if ncc > 0 {
131				impl.Dlasr(blas.Left, lapack.Variable, lapack.Forward, n, ncc, work, work[n-1:], c, ldc)
132			}
133		}
134		// Compute singular values to a relative accuracy of tol. If tol is negative
135		// the values will be computed to an absolute accuracy of math.Abs(tol) * norm(b)
136		tolmul := math.Max(10, math.Min(100, math.Pow(eps, -1.0/8)))
137		tol := tolmul * eps
138		var smax float64
139		for i := 0; i < n; i++ {
140			smax = math.Max(smax, math.Abs(d[i]))
141		}
142		for i := 0; i < n-1; i++ {
143			smax = math.Max(smax, math.Abs(e[i]))
144		}
145
146		var sminl float64
147		var thresh float64
148		if tol >= 0 {
149			sminoa := math.Abs(d[0])
150			if sminoa != 0 {
151				mu := sminoa
152				for i := 1; i < n; i++ {
153					mu = math.Abs(d[i]) * (mu / (mu + math.Abs(e[i-1])))
154					sminoa = math.Min(sminoa, mu)
155					if sminoa == 0 {
156						break
157					}
158				}
159			}
160			sminoa = sminoa / math.Sqrt(float64(n))
161			thresh = math.Max(tol*sminoa, float64(maxIter*n*n)*unfl)
162		} else {
163			thresh = math.Max(math.Abs(tol)*smax, float64(maxIter*n*n)*unfl)
164		}
165		// Prepare for the main iteration loop for the singular values.
166		maxIt := maxIter * n * n
167		iter := 0
168		oldl2 := -1
169		oldm := -1
170		// m points to the last element of unconverged part of matrix.
171		m := n
172
173	Outer:
174		for m > 1 {
175			if iter > maxIt {
176				info = 0
177				for i := 0; i < n-1; i++ {
178					if e[i] != 0 {
179						info++
180					}
181				}
182				return info == 0
183			}
184			// Find diagonal block of matrix to work on.
185			if tol < 0 && math.Abs(d[m-1]) <= thresh {
186				d[m-1] = 0
187			}
188			smax = math.Abs(d[m-1])
189			smin := smax
190			var l2 int
191			var broke bool
192			for l3 := 0; l3 < m-1; l3++ {
193				l2 = m - l3 - 2
194				abss := math.Abs(d[l2])
195				abse := math.Abs(e[l2])
196				if tol < 0 && abss <= thresh {
197					d[l2] = 0
198				}
199				if abse <= thresh {
200					broke = true
201					break
202				}
203				smin = math.Min(smin, abss)
204				smax = math.Max(math.Max(smax, abss), abse)
205			}
206			if broke {
207				e[l2] = 0
208				if l2 == m-2 {
209					// Convergence of bottom singular value, return to top.
210					m--
211					continue
212				}
213				l2++
214			} else {
215				l2 = 0
216			}
217			// e[ll] through e[m-2] are nonzero, e[ll-1] is zero
218			if l2 == m-2 {
219				// Handle 2×2 block separately.
220				var sinr, cosr, sinl, cosl float64
221				d[m-1], d[m-2], sinr, cosr, sinl, cosl = impl.Dlasv2(d[m-2], e[m-2], d[m-1])
222				e[m-2] = 0
223				if ncvt > 0 {
224					bi.Drot(ncvt, vt[(m-2)*ldvt:], 1, vt[(m-1)*ldvt:], 1, cosr, sinr)
225				}
226				if nru > 0 {
227					bi.Drot(nru, u[m-2:], ldu, u[m-1:], ldu, cosl, sinl)
228				}
229				if ncc > 0 {
230					bi.Drot(ncc, c[(m-2)*ldc:], 1, c[(m-1)*ldc:], 1, cosl, sinl)
231				}
232				m -= 2
233				continue
234			}
235			// If working on a new submatrix, choose shift direction from larger end
236			// diagonal element toward smaller.
237			if l2 > oldm-1 || m-1 < oldl2 {
238				if math.Abs(d[l2]) >= math.Abs(d[m-1]) {
239					idir = 1
240				} else {
241					idir = 2
242				}
243			}
244			// Apply convergence tests.
245			// TODO(btracey): There is a lot of similar looking code here. See
246			// if there is a better way to de-duplicate.
247			if idir == 1 {
248				// Run convergence test in forward direction.
249				// First apply standard test to bottom of matrix.
250				if math.Abs(e[m-2]) <= math.Abs(tol)*math.Abs(d[m-1]) || (tol < 0 && math.Abs(e[m-2]) <= thresh) {
251					e[m-2] = 0
252					continue
253				}
254				if tol >= 0 {
255					// If relative accuracy desired, apply convergence criterion forward.
256					mu := math.Abs(d[l2])
257					sminl = mu
258					for l3 := l2; l3 < m-1; l3++ {
259						if math.Abs(e[l3]) <= tol*mu {
260							e[l3] = 0
261							continue Outer
262						}
263						mu = math.Abs(d[l3+1]) * (mu / (mu + math.Abs(e[l3])))
264						sminl = math.Min(sminl, mu)
265					}
266				}
267			} else {
268				// Run convergence test in backward direction.
269				// First apply standard test to top of matrix.
270				if math.Abs(e[l2]) <= math.Abs(tol)*math.Abs(d[l2]) || (tol < 0 && math.Abs(e[l2]) <= thresh) {
271					e[l2] = 0
272					continue
273				}
274				if tol >= 0 {
275					// If relative accuracy desired, apply convergence criterion backward.
276					mu := math.Abs(d[m-1])
277					sminl = mu
278					for l3 := m - 2; l3 >= l2; l3-- {
279						if math.Abs(e[l3]) <= tol*mu {
280							e[l3] = 0
281							continue Outer
282						}
283						mu = math.Abs(d[l3]) * (mu / (mu + math.Abs(e[l3])))
284						sminl = math.Min(sminl, mu)
285					}
286				}
287			}
288			oldl2 = l2
289			oldm = m
290			// Compute shift. First, test if shifting would ruin relative accuracy,
291			// and if so set the shift to zero.
292			var shift float64
293			if tol >= 0 && float64(n)*tol*(sminl/smax) <= math.Max(eps, (1.0/100)*tol) {
294				shift = 0
295			} else {
296				var sl2 float64
297				if idir == 1 {
298					sl2 = math.Abs(d[l2])
299					shift, _ = impl.Dlas2(d[m-2], e[m-2], d[m-1])
300				} else {
301					sl2 = math.Abs(d[m-1])
302					shift, _ = impl.Dlas2(d[l2], e[l2], d[l2+1])
303				}
304				// Test if shift is negligible
305				if sl2 > 0 {
306					if (shift/sl2)*(shift/sl2) < eps {
307						shift = 0
308					}
309				}
310			}
311			iter += m - l2 + 1
312			// If no shift, do simplified QR iteration.
313			if shift == 0 {
314				if idir == 1 {
315					cs := 1.0
316					oldcs := 1.0
317					var sn, r, oldsn float64
318					for i := l2; i < m-1; i++ {
319						cs, sn, r = impl.Dlartg(d[i]*cs, e[i])
320						if i > l2 {
321							e[i-1] = oldsn * r
322						}
323						oldcs, oldsn, d[i] = impl.Dlartg(oldcs*r, d[i+1]*sn)
324						work[i-l2] = cs
325						work[i-l2+nm1] = sn
326						work[i-l2+nm12] = oldcs
327						work[i-l2+nm13] = oldsn
328					}
329					h := d[m-1] * cs
330					d[m-1] = h * oldcs
331					e[m-2] = h * oldsn
332					if ncvt > 0 {
333						impl.Dlasr(blas.Left, lapack.Variable, lapack.Forward, m-l2, ncvt, work, work[n-1:], vt[l2*ldvt:], ldvt)
334					}
335					if nru > 0 {
336						impl.Dlasr(blas.Right, lapack.Variable, lapack.Forward, nru, m-l2, work[nm12:], work[nm13:], u[l2:], ldu)
337					}
338					if ncc > 0 {
339						impl.Dlasr(blas.Left, lapack.Variable, lapack.Forward, m-l2, ncc, work[nm12:], work[nm13:], c[l2*ldc:], ldc)
340					}
341					if math.Abs(e[m-2]) < thresh {
342						e[m-2] = 0
343					}
344				} else {
345					cs := 1.0
346					oldcs := 1.0
347					var sn, r, oldsn float64
348					for i := m - 1; i >= l2+1; i-- {
349						cs, sn, r = impl.Dlartg(d[i]*cs, e[i-1])
350						if i < m-1 {
351							e[i] = oldsn * r
352						}
353						oldcs, oldsn, d[i] = impl.Dlartg(oldcs*r, d[i-1]*sn)
354						work[i-l2-1] = cs
355						work[i-l2+nm1-1] = -sn
356						work[i-l2+nm12-1] = oldcs
357						work[i-l2+nm13-1] = -oldsn
358					}
359					h := d[l2] * cs
360					d[l2] = h * oldcs
361					e[l2] = h * oldsn
362					if ncvt > 0 {
363						impl.Dlasr(blas.Left, lapack.Variable, lapack.Backward, m-l2, ncvt, work[nm12:], work[nm13:], vt[l2*ldvt:], ldvt)
364					}
365					if nru > 0 {
366						impl.Dlasr(blas.Right, lapack.Variable, lapack.Backward, nru, m-l2, work, work[n-1:], u[l2:], ldu)
367					}
368					if ncc > 0 {
369						impl.Dlasr(blas.Left, lapack.Variable, lapack.Backward, m-l2, ncc, work, work[n-1:], c[l2*ldc:], ldc)
370					}
371					if math.Abs(e[l2]) <= thresh {
372						e[l2] = 0
373					}
374				}
375			} else {
376				// Use nonzero shift.
377				if idir == 1 {
378					// Chase bulge from top to bottom. Save cosines and sines for
379					// later singular vector updates.
380					f := (math.Abs(d[l2]) - shift) * (math.Copysign(1, d[l2]) + shift/d[l2])
381					g := e[l2]
382					var cosl, sinl float64
383					for i := l2; i < m-1; i++ {
384						cosr, sinr, r := impl.Dlartg(f, g)
385						if i > l2 {
386							e[i-1] = r
387						}
388						f = cosr*d[i] + sinr*e[i]
389						e[i] = cosr*e[i] - sinr*d[i]
390						g = sinr * d[i+1]
391						d[i+1] *= cosr
392						cosl, sinl, r = impl.Dlartg(f, g)
393						d[i] = r
394						f = cosl*e[i] + sinl*d[i+1]
395						d[i+1] = cosl*d[i+1] - sinl*e[i]
396						if i < m-2 {
397							g = sinl * e[i+1]
398							e[i+1] = cosl * e[i+1]
399						}
400						work[i-l2] = cosr
401						work[i-l2+nm1] = sinr
402						work[i-l2+nm12] = cosl
403						work[i-l2+nm13] = sinl
404					}
405					e[m-2] = f
406					if ncvt > 0 {
407						impl.Dlasr(blas.Left, lapack.Variable, lapack.Forward, m-l2, ncvt, work, work[n-1:], vt[l2*ldvt:], ldvt)
408					}
409					if nru > 0 {
410						impl.Dlasr(blas.Right, lapack.Variable, lapack.Forward, nru, m-l2, work[nm12:], work[nm13:], u[l2:], ldu)
411					}
412					if ncc > 0 {
413						impl.Dlasr(blas.Left, lapack.Variable, lapack.Forward, m-l2, ncc, work[nm12:], work[nm13:], c[l2*ldc:], ldc)
414					}
415					if math.Abs(e[m-2]) <= thresh {
416						e[m-2] = 0
417					}
418				} else {
419					// Chase bulge from top to bottom. Save cosines and sines for
420					// later singular vector updates.
421					f := (math.Abs(d[m-1]) - shift) * (math.Copysign(1, d[m-1]) + shift/d[m-1])
422					g := e[m-2]
423					for i := m - 1; i > l2; i-- {
424						cosr, sinr, r := impl.Dlartg(f, g)
425						if i < m-1 {
426							e[i] = r
427						}
428						f = cosr*d[i] + sinr*e[i-1]
429						e[i-1] = cosr*e[i-1] - sinr*d[i]
430						g = sinr * d[i-1]
431						d[i-1] *= cosr
432						cosl, sinl, r := impl.Dlartg(f, g)
433						d[i] = r
434						f = cosl*e[i-1] + sinl*d[i-1]
435						d[i-1] = cosl*d[i-1] - sinl*e[i-1]
436						if i > l2+1 {
437							g = sinl * e[i-2]
438							e[i-2] *= cosl
439						}
440						work[i-l2-1] = cosr
441						work[i-l2+nm1-1] = -sinr
442						work[i-l2+nm12-1] = cosl
443						work[i-l2+nm13-1] = -sinl
444					}
445					e[l2] = f
446					if math.Abs(e[l2]) <= thresh {
447						e[l2] = 0
448					}
449					if ncvt > 0 {
450						impl.Dlasr(blas.Left, lapack.Variable, lapack.Backward, m-l2, ncvt, work[nm12:], work[nm13:], vt[l2*ldvt:], ldvt)
451					}
452					if nru > 0 {
453						impl.Dlasr(blas.Right, lapack.Variable, lapack.Backward, nru, m-l2, work, work[n-1:], u[l2:], ldu)
454					}
455					if ncc > 0 {
456						impl.Dlasr(blas.Left, lapack.Variable, lapack.Backward, m-l2, ncc, work, work[n-1:], c[l2*ldc:], ldc)
457					}
458				}
459			}
460		}
461	}
462
463	// All singular values converged, make them positive.
464	for i := 0; i < n; i++ {
465		if d[i] < 0 {
466			d[i] *= -1
467			if ncvt > 0 {
468				bi.Dscal(ncvt, -1, vt[i*ldvt:], 1)
469			}
470		}
471	}
472
473	// Sort the singular values in decreasing order.
474	for i := 0; i < n-1; i++ {
475		isub := 0
476		smin := d[0]
477		for j := 1; j < n-i; j++ {
478			if d[j] <= smin {
479				isub = j
480				smin = d[j]
481			}
482		}
483		if isub != n-i {
484			// Swap singular values and vectors.
485			d[isub] = d[n-i-1]
486			d[n-i-1] = smin
487			if ncvt > 0 {
488				bi.Dswap(ncvt, vt[isub*ldvt:], 1, vt[(n-i-1)*ldvt:], 1)
489			}
490			if nru > 0 {
491				bi.Dswap(nru, u[isub:], ldu, u[n-i-1:], ldu)
492			}
493			if ncc > 0 {
494				bi.Dswap(ncc, c[isub*ldc:], 1, c[(n-i-1)*ldc:], 1)
495			}
496		}
497	}
498	info = 0
499	for i := 0; i < n-1; i++ {
500		if e[i] != 0 {
501			info++
502		}
503	}
504	return info == 0
505}
506