1// Copyright ©2016 The Gonum Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package gonum
6
7import (
8	"math"
9
10	"gonum.org/v1/gonum/blas/blas64"
11)
12
13// Dlasy2 solves the Sylvester matrix equation where the matrices are of order 1
14// or 2. It computes the unknown n1×n2 matrix X so that
15//  TL*X   + sgn*X*TR   = scale*B,  if tranl == false and tranr == false,
16//  TL^T*X + sgn*X*TR   = scale*B,  if tranl == true  and tranr == false,
17//  TL*X   + sgn*X*TR^T = scale*B,  if tranl == false and tranr == true,
18//  TL^T*X + sgn*X*TR^T = scale*B,  if tranl == true  and tranr == true,
19// where TL is n1×n1, TR is n2×n2, B is n1×n2, and 1 <= n1,n2 <= 2.
20//
21// isgn must be 1 or -1, and n1 and n2 must be 0, 1, or 2, but these conditions
22// are not checked.
23//
24// Dlasy2 returns three values, a scale factor that is chosen less than or equal
25// to 1 to prevent the solution overflowing, the infinity norm of the solution,
26// and an indicator of success. If ok is false, TL and TR have eigenvalues that
27// are too close, so TL or TR is perturbed to get a non-singular equation.
28//
29// Dlasy2 is an internal routine. It is exported for testing purposes.
30func (impl Implementation) Dlasy2(tranl, tranr bool, isgn, n1, n2 int, tl []float64, ldtl int, tr []float64, ldtr int, b []float64, ldb int, x []float64, ldx int) (scale, xnorm float64, ok bool) {
31	// TODO(vladimir-ch): Add input validation checks conditionally skipped
32	// using the build tag mechanism.
33
34	ok = true
35	// Quick return if possible.
36	if n1 == 0 || n2 == 0 {
37		return scale, xnorm, ok
38	}
39
40	// Set constants to control overflow.
41	eps := dlamchP
42	smlnum := dlamchS / eps
43	sgn := float64(isgn)
44
45	if n1 == 1 && n2 == 1 {
46		// 1×1 case: TL11*X + sgn*X*TR11 = B11.
47		tau1 := tl[0] + sgn*tr[0]
48		bet := math.Abs(tau1)
49		if bet <= smlnum {
50			tau1 = smlnum
51			bet = smlnum
52			ok = false
53		}
54		scale = 1
55		gam := math.Abs(b[0])
56		if smlnum*gam > bet {
57			scale = 1 / gam
58		}
59		x[0] = b[0] * scale / tau1
60		xnorm = math.Abs(x[0])
61		return scale, xnorm, ok
62	}
63
64	if n1+n2 == 3 {
65		// 1×2 or 2×1 case.
66		var (
67			smin float64
68			tmp  [4]float64 // tmp is used as a 2×2 row-major matrix.
69			btmp [2]float64
70		)
71		if n1 == 1 && n2 == 2 {
72			// 1×2 case: TL11*[X11 X12] + sgn*[X11 X12]*op[TR11 TR12] = [B11 B12].
73			//                                            [TR21 TR22]
74			smin = math.Abs(tl[0])
75			smin = math.Max(smin, math.Max(math.Abs(tr[0]), math.Abs(tr[1])))
76			smin = math.Max(smin, math.Max(math.Abs(tr[ldtr]), math.Abs(tr[ldtr+1])))
77			smin = math.Max(eps*smin, smlnum)
78			tmp[0] = tl[0] + sgn*tr[0]
79			tmp[3] = tl[0] + sgn*tr[ldtr+1]
80			if tranr {
81				tmp[1] = sgn * tr[1]
82				tmp[2] = sgn * tr[ldtr]
83			} else {
84				tmp[1] = sgn * tr[ldtr]
85				tmp[2] = sgn * tr[1]
86			}
87			btmp[0] = b[0]
88			btmp[1] = b[1]
89		} else {
90			// 2×1 case: op[TL11 TL12]*[X11] + sgn*[X11]*TR11 = [B11].
91			//             [TL21 TL22]*[X21]       [X21]        [B21]
92			smin = math.Abs(tr[0])
93			smin = math.Max(smin, math.Max(math.Abs(tl[0]), math.Abs(tl[1])))
94			smin = math.Max(smin, math.Max(math.Abs(tl[ldtl]), math.Abs(tl[ldtl+1])))
95			smin = math.Max(eps*smin, smlnum)
96			tmp[0] = tl[0] + sgn*tr[0]
97			tmp[3] = tl[ldtl+1] + sgn*tr[0]
98			if tranl {
99				tmp[1] = tl[ldtl]
100				tmp[2] = tl[1]
101			} else {
102				tmp[1] = tl[1]
103				tmp[2] = tl[ldtl]
104			}
105			btmp[0] = b[0]
106			btmp[1] = b[ldb]
107		}
108
109		// Solve 2×2 system using complete pivoting.
110		// Set pivots less than smin to smin.
111
112		bi := blas64.Implementation()
113		ipiv := bi.Idamax(len(tmp), tmp[:], 1)
114		// Compute the upper triangular matrix [u11 u12].
115		//                                     [  0 u22]
116		u11 := tmp[ipiv]
117		if math.Abs(u11) <= smin {
118			ok = false
119			u11 = smin
120		}
121		locu12 := [4]int{1, 0, 3, 2} // Index in tmp of the element on the same row as the pivot.
122		u12 := tmp[locu12[ipiv]]
123		locl21 := [4]int{2, 3, 0, 1} // Index in tmp of the element on the same column as the pivot.
124		l21 := tmp[locl21[ipiv]] / u11
125		locu22 := [4]int{3, 2, 1, 0} // Index in tmp of the remaining element.
126		u22 := tmp[locu22[ipiv]] - l21*u12
127		if math.Abs(u22) <= smin {
128			ok = false
129			u22 = smin
130		}
131		if ipiv&0x2 != 0 { // true for ipiv equal to 2 and 3.
132			// The pivot was in the second row, swap the elements of
133			// the right-hand side.
134			btmp[0], btmp[1] = btmp[1], btmp[0]-l21*btmp[1]
135		} else {
136			btmp[1] -= l21 * btmp[0]
137		}
138		scale = 1
139		if 2*smlnum*math.Abs(btmp[1]) > math.Abs(u22) || 2*smlnum*math.Abs(btmp[0]) > math.Abs(u11) {
140			scale = 0.5 / math.Max(math.Abs(btmp[0]), math.Abs(btmp[1]))
141			btmp[0] *= scale
142			btmp[1] *= scale
143		}
144		// Solve the system [u11 u12] [x21] = [ btmp[0] ].
145		//                  [  0 u22] [x22]   [ btmp[1] ]
146		x22 := btmp[1] / u22
147		x21 := btmp[0]/u11 - (u12/u11)*x22
148		if ipiv&0x1 != 0 { // true for ipiv equal to 1 and 3.
149			// The pivot was in the second column, swap the elements
150			// of the solution.
151			x21, x22 = x22, x21
152		}
153		x[0] = x21
154		if n1 == 1 {
155			x[1] = x22
156			xnorm = math.Abs(x[0]) + math.Abs(x[1])
157		} else {
158			x[ldx] = x22
159			xnorm = math.Max(math.Abs(x[0]), math.Abs(x[ldx]))
160		}
161		return scale, xnorm, ok
162	}
163
164	// 2×2 case: op[TL11 TL12]*[X11 X12] + SGN*[X11 X12]*op[TR11 TR12] = [B11 B12].
165	//             [TL21 TL22] [X21 X22]       [X21 X22]   [TR21 TR22]   [B21 B22]
166	//
167	// Solve equivalent 4×4 system using complete pivoting.
168	// Set pivots less than smin to smin.
169
170	smin := math.Max(math.Abs(tr[0]), math.Abs(tr[1]))
171	smin = math.Max(smin, math.Max(math.Abs(tr[ldtr]), math.Abs(tr[ldtr+1])))
172	smin = math.Max(smin, math.Max(math.Abs(tl[0]), math.Abs(tl[1])))
173	smin = math.Max(smin, math.Max(math.Abs(tl[ldtl]), math.Abs(tl[ldtl+1])))
174	smin = math.Max(eps*smin, smlnum)
175
176	var t [4][4]float64
177	t[0][0] = tl[0] + sgn*tr[0]
178	t[1][1] = tl[0] + sgn*tr[ldtr+1]
179	t[2][2] = tl[ldtl+1] + sgn*tr[0]
180	t[3][3] = tl[ldtl+1] + sgn*tr[ldtr+1]
181	if tranl {
182		t[0][2] = tl[ldtl]
183		t[1][3] = tl[ldtl]
184		t[2][0] = tl[1]
185		t[3][1] = tl[1]
186	} else {
187		t[0][2] = tl[1]
188		t[1][3] = tl[1]
189		t[2][0] = tl[ldtl]
190		t[3][1] = tl[ldtl]
191	}
192	if tranr {
193		t[0][1] = sgn * tr[1]
194		t[1][0] = sgn * tr[ldtr]
195		t[2][3] = sgn * tr[1]
196		t[3][2] = sgn * tr[ldtr]
197	} else {
198		t[0][1] = sgn * tr[ldtr]
199		t[1][0] = sgn * tr[1]
200		t[2][3] = sgn * tr[ldtr]
201		t[3][2] = sgn * tr[1]
202	}
203
204	var btmp [4]float64
205	btmp[0] = b[0]
206	btmp[1] = b[1]
207	btmp[2] = b[ldb]
208	btmp[3] = b[ldb+1]
209
210	// Perform elimination.
211	var jpiv [4]int // jpiv records any column swaps for pivoting.
212	for i := 0; i < 3; i++ {
213		var (
214			xmax       float64
215			ipsv, jpsv int
216		)
217		for ip := i; ip < 4; ip++ {
218			for jp := i; jp < 4; jp++ {
219				if math.Abs(t[ip][jp]) >= xmax {
220					xmax = math.Abs(t[ip][jp])
221					ipsv = ip
222					jpsv = jp
223				}
224			}
225		}
226		if ipsv != i {
227			// The pivot is not in the top row of the unprocessed
228			// block, swap rows ipsv and i of t and btmp.
229			t[ipsv], t[i] = t[i], t[ipsv]
230			btmp[ipsv], btmp[i] = btmp[i], btmp[ipsv]
231		}
232		if jpsv != i {
233			// The pivot is not in the left column of the
234			// unprocessed block, swap columns jpsv and i of t.
235			for k := 0; k < 4; k++ {
236				t[k][jpsv], t[k][i] = t[k][i], t[k][jpsv]
237			}
238		}
239		jpiv[i] = jpsv
240		if math.Abs(t[i][i]) < smin {
241			ok = false
242			t[i][i] = smin
243		}
244		for k := i + 1; k < 4; k++ {
245			t[k][i] /= t[i][i]
246			btmp[k] -= t[k][i] * btmp[i]
247			for j := i + 1; j < 4; j++ {
248				t[k][j] -= t[k][i] * t[i][j]
249			}
250		}
251	}
252	if math.Abs(t[3][3]) < smin {
253		ok = false
254		t[3][3] = smin
255	}
256	scale = 1
257	if 8*smlnum*math.Abs(btmp[0]) > math.Abs(t[0][0]) ||
258		8*smlnum*math.Abs(btmp[1]) > math.Abs(t[1][1]) ||
259		8*smlnum*math.Abs(btmp[2]) > math.Abs(t[2][2]) ||
260		8*smlnum*math.Abs(btmp[3]) > math.Abs(t[3][3]) {
261
262		maxbtmp := math.Max(math.Abs(btmp[0]), math.Abs(btmp[1]))
263		maxbtmp = math.Max(maxbtmp, math.Max(math.Abs(btmp[2]), math.Abs(btmp[3])))
264		scale = 1 / 8 / maxbtmp
265		btmp[0] *= scale
266		btmp[1] *= scale
267		btmp[2] *= scale
268		btmp[3] *= scale
269	}
270	// Compute the solution of the upper triangular system t * tmp = btmp.
271	var tmp [4]float64
272	for i := 3; i >= 0; i-- {
273		temp := 1 / t[i][i]
274		tmp[i] = btmp[i] * temp
275		for j := i + 1; j < 4; j++ {
276			tmp[i] -= temp * t[i][j] * tmp[j]
277		}
278	}
279	for i := 2; i >= 0; i-- {
280		if jpiv[i] != i {
281			tmp[i], tmp[jpiv[i]] = tmp[jpiv[i]], tmp[i]
282		}
283	}
284	x[0] = tmp[0]
285	x[1] = tmp[1]
286	x[ldx] = tmp[2]
287	x[ldx+1] = tmp[3]
288	xnorm = math.Max(math.Abs(tmp[0])+math.Abs(tmp[1]), math.Abs(tmp[2])+math.Abs(tmp[3]))
289	return scale, xnorm, ok
290}
291