1// Copyright ©2018 The Gonum Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package mat
6
7import (
8	"gonum.org/v1/gonum/blas"
9	"gonum.org/v1/gonum/blas/blas64"
10)
11
12var (
13	diagDense *DiagDense
14	_         Matrix          = diagDense
15	_         Diagonal        = diagDense
16	_         MutableDiagonal = diagDense
17	_         Triangular      = diagDense
18	_         TriBanded       = diagDense
19	_         Symmetric       = diagDense
20	_         SymBanded       = diagDense
21	_         Banded          = diagDense
22	_         RawBander       = diagDense
23	_         RawSymBander    = diagDense
24
25	diag Diagonal
26	_    Matrix     = diag
27	_    Diagonal   = diag
28	_    Triangular = diag
29	_    TriBanded  = diag
30	_    Symmetric  = diag
31	_    SymBanded  = diag
32	_    Banded     = diag
33)
34
35// Diagonal represents a diagonal matrix, that is a square matrix that only
36// has non-zero terms on the diagonal.
37type Diagonal interface {
38	Matrix
39	// Diag returns the number of rows/columns in the matrix.
40	Diag() int
41
42	// Bandwidth and TBand are included in the Diagonal interface
43	// to allow the use of Diagonal types in banded functions.
44	// Bandwidth will always return (0, 0).
45	Bandwidth() (kl, ku int)
46	TBand() Banded
47
48	// Triangle and TTri are included in the Diagonal interface
49	// to allow the use of Diagonal types in triangular functions.
50	Triangle() (int, TriKind)
51	TTri() Triangular
52
53	// Symmetric and SymBand are included in the Diagonal interface
54	// to allow the use of Diagonal types in symmetric and banded symmetric
55	// functions respectively.
56	Symmetric() int
57	SymBand() (n, k int)
58
59	// TriBand and TTriBand are included in the Diagonal interface
60	// to allow the use of Diagonal types in triangular banded functions.
61	TriBand() (n, k int, kind TriKind)
62	TTriBand() TriBanded
63}
64
65// MutableDiagonal is a Diagonal matrix whose elements can be set.
66type MutableDiagonal interface {
67	Diagonal
68	SetDiag(i int, v float64)
69}
70
// DiagDense represents a diagonal matrix in dense storage format.
// Only the diagonal elements are stored.
type DiagDense struct {
	// mat holds the diagonal; element i is at mat.Data[i*mat.Inc].
	// An Inc of zero marks an empty matrix (see Reset and IsZero).
	mat blas64.Vector
}
75
76// NewDiagDense creates a new Diagonal matrix with n rows and n columns.
77// The length of data must be n or data must be nil, otherwise NewDiagDense
78// will panic. NewDiagDense will panic if n is zero.
79func NewDiagDense(n int, data []float64) *DiagDense {
80	if n <= 0 {
81		if n == 0 {
82			panic(ErrZeroLength)
83		}
84		panic("mat: negative dimension")
85	}
86	if data == nil {
87		data = make([]float64, n)
88	}
89	if len(data) != n {
90		panic(ErrShape)
91	}
92	return &DiagDense{
93		mat: blas64.Vector{N: n, Data: data, Inc: 1},
94	}
95}
96
97// Diag returns the dimension of the receiver.
98func (d *DiagDense) Diag() int {
99	return d.mat.N
100}
101
102// Dims returns the dimensions of the matrix.
103func (d *DiagDense) Dims() (r, c int) {
104	return d.mat.N, d.mat.N
105}
106
107// T returns the transpose of the matrix.
108func (d *DiagDense) T() Matrix {
109	return d
110}
111
112// TTri returns the transpose of the matrix. Note that Diagonal matrices are
113// Upper by default.
114func (d *DiagDense) TTri() Triangular {
115	return TransposeTri{d}
116}
117
118// TBand performs an implicit transpose by returning the receiver inside a
119// TransposeBand.
120func (d *DiagDense) TBand() Banded {
121	return TransposeBand{d}
122}
123
124// TTriBand performs an implicit transpose by returning the receiver inside a
125// TransposeTriBand. Note that Diagonal matrices are Upper by default.
126func (d *DiagDense) TTriBand() TriBanded {
127	return TransposeTriBand{d}
128}
129
130// Bandwidth returns the upper and lower bandwidths of the matrix.
131// These values are always zero for diagonal matrices.
132func (d *DiagDense) Bandwidth() (kl, ku int) {
133	return 0, 0
134}
135
136// Symmetric implements the Symmetric interface.
137func (d *DiagDense) Symmetric() int {
138	return d.mat.N
139}
140
141// SymBand returns the number of rows/columns in the matrix, and the size of
142// the bandwidth.
143func (d *DiagDense) SymBand() (n, k int) {
144	return d.mat.N, 0
145}
146
147// Triangle implements the Triangular interface.
148func (d *DiagDense) Triangle() (int, TriKind) {
149	return d.mat.N, Upper
150}
151
152// TriBand returns the number of rows/columns in the matrix, the
153// size of the bandwidth, and the orientation. Note that Diagonal matrices are
154// Upper by default.
155func (d *DiagDense) TriBand() (n, k int, kind TriKind) {
156	return d.mat.N, 0, Upper
157}
158
159// Reset zeros the length of the matrix so that it can be reused as the
160// receiver of a dimensionally restricted operation.
161//
162// See the Reseter interface for more information.
163func (d *DiagDense) Reset() {
164	// No change of Inc or n to 0 may be
165	// made unless both are set to 0.
166	d.mat.Inc = 0
167	d.mat.N = 0
168	d.mat.Data = d.mat.Data[:0]
169}
170
171// Zero sets all of the matrix elements to zero.
172func (d *DiagDense) Zero() {
173	for i := 0; i < d.mat.N; i++ {
174		d.mat.Data[d.mat.Inc*i] = 0
175	}
176}
177
178// DiagView returns the diagonal as a matrix backed by the original data.
179func (d *DiagDense) DiagView() Diagonal {
180	return d
181}
182
183// DiagFrom copies the diagonal of m into the receiver. The receiver must
184// be min(r, c) long or zero. Otherwise DiagFrom will panic.
185func (d *DiagDense) DiagFrom(m Matrix) {
186	n := min(m.Dims())
187	d.reuseAs(n)
188
189	var vec blas64.Vector
190	switch r := m.(type) {
191	case *DiagDense:
192		vec = r.mat
193	case RawBander:
194		mat := r.RawBand()
195		vec = blas64.Vector{
196			N:    n,
197			Inc:  mat.Stride,
198			Data: mat.Data[mat.KL : (n-1)*mat.Stride+mat.KL+1],
199		}
200	case RawMatrixer:
201		mat := r.RawMatrix()
202		vec = blas64.Vector{
203			N:    n,
204			Inc:  mat.Stride + 1,
205			Data: mat.Data[:(n-1)*mat.Stride+n],
206		}
207	case RawSymBander:
208		mat := r.RawSymBand()
209		vec = blas64.Vector{
210			N:    n,
211			Inc:  mat.Stride,
212			Data: mat.Data[:(n-1)*mat.Stride+1],
213		}
214	case RawSymmetricer:
215		mat := r.RawSymmetric()
216		vec = blas64.Vector{
217			N:    n,
218			Inc:  mat.Stride + 1,
219			Data: mat.Data[:(n-1)*mat.Stride+n],
220		}
221	case RawTriBander:
222		mat := r.RawTriBand()
223		data := mat.Data
224		if mat.Uplo == blas.Lower {
225			data = data[mat.K:]
226		}
227		vec = blas64.Vector{
228			N:    n,
229			Inc:  mat.Stride,
230			Data: data[:(n-1)*mat.Stride+1],
231		}
232	case RawTriangular:
233		mat := r.RawTriangular()
234		if mat.Diag == blas.Unit {
235			for i := 0; i < n; i += d.mat.Inc {
236				d.mat.Data[i] = 1
237			}
238			return
239		}
240		vec = blas64.Vector{
241			N:    n,
242			Inc:  mat.Stride + 1,
243			Data: mat.Data[:(n-1)*mat.Stride+n],
244		}
245	case RawVectorer:
246		d.mat.Data[0] = r.RawVector().Data[0]
247		return
248	default:
249		for i := 0; i < n; i++ {
250			d.setDiag(i, m.At(i, i))
251		}
252		return
253	}
254	blas64.Copy(vec, d.mat)
255}
256
257// RawBand returns the underlying data used by the receiver represented
258// as a blas64.Band.
259// Changes to elements in the receiver following the call will be reflected
260// in returned blas64.Band.
261func (d *DiagDense) RawBand() blas64.Band {
262	return blas64.Band{
263		Rows:   d.mat.N,
264		Cols:   d.mat.N,
265		KL:     0,
266		KU:     0,
267		Stride: d.mat.Inc,
268		Data:   d.mat.Data,
269	}
270}
271
272// RawSymBand returns the underlying data used by the receiver represented
273// as a blas64.SymmetricBand.
274// Changes to elements in the receiver following the call will be reflected
275// in returned blas64.Band.
276func (d *DiagDense) RawSymBand() blas64.SymmetricBand {
277	return blas64.SymmetricBand{
278		N:      d.mat.N,
279		K:      0,
280		Stride: d.mat.Inc,
281		Uplo:   blas.Upper,
282		Data:   d.mat.Data,
283	}
284}
285
286// reuseAs resizes an empty diagonal to a r×r diagonal,
287// or checks that a non-empty matrix is r×r.
288func (d *DiagDense) reuseAs(r int) {
289	if r == 0 {
290		panic(ErrZeroLength)
291	}
292	if d.IsZero() {
293		d.mat = blas64.Vector{
294			Inc:  1,
295			Data: use(d.mat.Data, r),
296		}
297		d.mat.N = r
298		return
299	}
300	if r != d.mat.N {
301		panic(ErrShape)
302	}
303}
304
305// IsZero returns whether the receiver is zero-sized. Zero-sized vectors can be the
306// receiver for size-restricted operations. DiagDenses can be zeroed using Reset.
307func (d *DiagDense) IsZero() bool {
308	// It must be the case that d.Dims() returns
309	// zeros in this case. See comment in Reset().
310	return d.mat.Inc == 0
311}
312