1// Copyright ©2020 The Gonum Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package gonum
6
7import (
8	"math"
9
10	"gonum.org/v1/gonum/blas"
11	"gonum.org/v1/gonum/lapack"
12)
13
14// Dlantb returns the value of the given norm of an n×n triangular band matrix A
15// with k+1 diagonals.
16//
17// When norm is lapack.MaxColumnSum, the length of work must be at least n.
18func (impl Implementation) Dlantb(norm lapack.MatrixNorm, uplo blas.Uplo, diag blas.Diag, n, k int, a []float64, lda int, work []float64) float64 {
19	switch {
20	case norm != lapack.MaxAbs && norm != lapack.MaxRowSum && norm != lapack.MaxColumnSum && norm != lapack.Frobenius:
21		panic(badNorm)
22	case uplo != blas.Upper && uplo != blas.Lower:
23		panic(badUplo)
24	case n < 0:
25		panic(nLT0)
26	case k < 0:
27		panic(kdLT0)
28	case lda < k+1:
29		panic(badLdA)
30	}
31
32	// Quick return if possible.
33	if n == 0 {
34		return 0
35	}
36
37	switch {
38	case len(a) < (n-1)*lda+k+1:
39		panic(shortAB)
40	case len(work) < n && norm == lapack.MaxColumnSum:
41		panic(shortWork)
42	}
43
44	var value float64
45	switch norm {
46	case lapack.MaxAbs:
47		if uplo == blas.Upper {
48			var jfirst int
49			if diag == blas.Unit {
50				value = 1
51				jfirst = 1
52			}
53			for i := 0; i < n; i++ {
54				for _, aij := range a[i*lda+jfirst : i*lda+min(n-i, k+1)] {
55					if math.IsNaN(aij) {
56						return aij
57					}
58					aij = math.Abs(aij)
59					if aij > value {
60						value = aij
61					}
62				}
63			}
64		} else {
65			jlast := k + 1
66			if diag == blas.Unit {
67				value = 1
68				jlast = k
69			}
70			for i := 0; i < n; i++ {
71				for _, aij := range a[i*lda+max(0, k-i) : i*lda+jlast] {
72					if math.IsNaN(aij) {
73						return math.NaN()
74					}
75					aij = math.Abs(aij)
76					if aij > value {
77						value = aij
78					}
79				}
80			}
81		}
82	case lapack.MaxRowSum:
83		var sum float64
84		if uplo == blas.Upper {
85			var jfirst int
86			if diag == blas.Unit {
87				jfirst = 1
88			}
89			for i := 0; i < n; i++ {
90				sum = 0
91				if diag == blas.Unit {
92					sum = 1
93				}
94				for _, aij := range a[i*lda+jfirst : i*lda+min(n-i, k+1)] {
95					sum += math.Abs(aij)
96				}
97				if math.IsNaN(sum) {
98					return math.NaN()
99				}
100				if sum > value {
101					value = sum
102				}
103			}
104		} else {
105			jlast := k + 1
106			if diag == blas.Unit {
107				jlast = k
108			}
109			for i := 0; i < n; i++ {
110				sum = 0
111				if diag == blas.Unit {
112					sum = 1
113				}
114				for _, aij := range a[i*lda+max(0, k-i) : i*lda+jlast] {
115					sum += math.Abs(aij)
116				}
117				if math.IsNaN(sum) {
118					return math.NaN()
119				}
120				if sum > value {
121					value = sum
122				}
123			}
124		}
125	case lapack.MaxColumnSum:
126		work = work[:n]
127		if diag == blas.Unit {
128			for i := range work {
129				work[i] = 1
130			}
131		} else {
132			for i := range work {
133				work[i] = 0
134			}
135		}
136		if uplo == blas.Upper {
137			var jfirst int
138			if diag == blas.Unit {
139				jfirst = 1
140			}
141			for i := 0; i < n; i++ {
142				for j, aij := range a[i*lda+jfirst : i*lda+min(n-i, k+1)] {
143					work[i+jfirst+j] += math.Abs(aij)
144				}
145			}
146		} else {
147			jlast := k + 1
148			if diag == blas.Unit {
149				jlast = k
150			}
151			for i := 0; i < n; i++ {
152				off := max(0, k-i)
153				for j, aij := range a[i*lda+off : i*lda+jlast] {
154					work[i+j+off-k] += math.Abs(aij)
155				}
156			}
157		}
158		for _, wi := range work {
159			if math.IsNaN(wi) {
160				return math.NaN()
161			}
162			if wi > value {
163				value = wi
164			}
165		}
166	case lapack.Frobenius:
167		var scale, ssq float64
168		switch uplo {
169		case blas.Upper:
170			if diag == blas.Unit {
171				scale = 1
172				ssq = float64(n)
173				if k > 0 {
174					for i := 0; i < n-1; i++ {
175						ilen := min(n-i-1, k)
176						rowscale, rowssq := impl.Dlassq(ilen, a[i*lda+1:], 1, 0, 1)
177						scale, ssq = impl.Dcombssq(scale, ssq, rowscale, rowssq)
178					}
179				}
180			} else {
181				scale = 0
182				ssq = 1
183				for i := 0; i < n; i++ {
184					ilen := min(n-i, k+1)
185					rowscale, rowssq := impl.Dlassq(ilen, a[i*lda:], 1, 0, 1)
186					scale, ssq = impl.Dcombssq(scale, ssq, rowscale, rowssq)
187				}
188			}
189		case blas.Lower:
190			if diag == blas.Unit {
191				scale = 1
192				ssq = float64(n)
193				if k > 0 {
194					for i := 1; i < n; i++ {
195						ilen := min(i, k)
196						rowscale, rowssq := impl.Dlassq(ilen, a[i*lda+k-ilen:], 1, 0, 1)
197						scale, ssq = impl.Dcombssq(scale, ssq, rowscale, rowssq)
198					}
199				}
200			} else {
201				scale = 0
202				ssq = 1
203				for i := 0; i < n; i++ {
204					ilen := min(i, k) + 1
205					rowscale, rowssq := impl.Dlassq(ilen, a[i*lda+k+1-ilen:], 1, 0, 1)
206					scale, ssq = impl.Dcombssq(scale, ssq, rowscale, rowssq)
207				}
208			}
209		}
210		value = scale * math.Sqrt(ssq)
211	}
212	return value
213}
214