1// Copyright ©2017 The Gonum Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package mat
6
7import (
8	"gonum.org/v1/gonum/blas"
9	"gonum.org/v1/gonum/blas/blas64"
10)
11
12var (
13	symBandDense *SymBandDense
14	_            Matrix           = symBandDense
15	_            Symmetric        = symBandDense
16	_            Banded           = symBandDense
17	_            SymBanded        = symBandDense
18	_            RawSymBander     = symBandDense
19	_            MutableSymBanded = symBandDense
20
21	_ NonZeroDoer    = symBandDense
22	_ RowNonZeroDoer = symBandDense
23	_ ColNonZeroDoer = symBandDense
24)
25
// SymBandDense represents a symmetric band matrix in dense storage format.
// Only the upper triangle of the band is stored (see NewSymBandDense for the
// layout).
type SymBandDense struct {
	// mat holds the band storage. Its Uplo is blas.Upper for matrices built by
	// NewSymBandDense, and SetRawSymBand rejects any other Uplo.
	mat blas64.SymmetricBand
}
30
31// SymBanded is a symmetric band matrix interface type.
32type SymBanded interface {
33	Banded
34
35	// Symmetric returns the number of rows/columns in the matrix.
36	Symmetric() int
37
38	// SymBand returns the number of rows/columns in the matrix, and the size of
39	// the bandwidth.
40	SymBand() (n, k int)
41}
42
// MutableSymBanded is a symmetric band matrix interface type that allows elements
// to be altered.
type MutableSymBanded interface {
	SymBanded
	// SetSymBand sets the element at row i, column j to v. Implementations are
	// presumably expected to keep (j, i) consistent since the matrix is
	// symmetric — confirm against concrete implementations.
	SetSymBand(i, j int, v float64)
}
49
// A RawSymBander can return a blas64.SymmetricBand representation of the receiver.
// Changes to the blas64.SymmetricBand.Data slice will be reflected in the original
// matrix, changes to the N, K, Stride and Uplo fields will not.
type RawSymBander interface {
	// RawSymBand returns the receiver's band storage; see the type comment
	// for which modifications propagate back to the receiver.
	RawSymBand() blas64.SymmetricBand
}
56
57// NewSymBandDense creates a new SymBand matrix with n rows and columns. If data == nil,
58// a new slice is allocated for the backing slice. If len(data) == n*(k+1),
59// data is used as the backing slice, and changes to the elements of the returned
60// SymBandDense will be reflected in data. If neither of these is true, NewSymBandDense
61// will panic. k must be at least zero and less than n, otherwise NewSymBandDense will panic.
62//
63// The data must be arranged in row-major order constructed by removing the zeros
64// from the rows outside the band and aligning the diagonals. SymBandDense matrices
65// are stored in the upper triangle. For example, the matrix
66//    1  2  3  0  0  0
67//    2  4  5  6  0  0
68//    3  5  7  8  9  0
69//    0  6  8 10 11 12
70//    0  0  9 11 13 14
71//    0  0  0 12 14 15
72// becomes (* entries are never accessed)
73//     1  2  3
74//     4  5  6
75//     7  8  9
76//    10 11 12
77//    13 14  *
78//    15  *  *
79// which is passed to NewSymBandDense as []float64{1, 2, ..., 15, *, *, *} with k=2.
80// Only the values in the band portion of the matrix are used.
81func NewSymBandDense(n, k int, data []float64) *SymBandDense {
82	if n <= 0 || k < 0 {
83		if n == 0 {
84			panic(ErrZeroLength)
85		}
86		panic("mat: negative dimension")
87	}
88	if k+1 > n {
89		panic("mat: band out of range")
90	}
91	bc := k + 1
92	if data != nil && len(data) != n*bc {
93		panic(ErrShape)
94	}
95	if data == nil {
96		data = make([]float64, n*bc)
97	}
98	return &SymBandDense{
99		mat: blas64.SymmetricBand{
100			N:      n,
101			K:      k,
102			Stride: bc,
103			Uplo:   blas.Upper,
104			Data:   data,
105		},
106	}
107}
108
109// Dims returns the number of rows and columns in the matrix.
110func (s *SymBandDense) Dims() (r, c int) {
111	return s.mat.N, s.mat.N
112}
113
114// Symmetric returns the size of the receiver.
115func (s *SymBandDense) Symmetric() int {
116	return s.mat.N
117}
118
119// Bandwidth returns the bandwidths of the matrix.
120func (s *SymBandDense) Bandwidth() (kl, ku int) {
121	return s.mat.K, s.mat.K
122}
123
124// SymBand returns the number of rows/columns in the matrix, and the size of
125// the bandwidth.
126func (s *SymBandDense) SymBand() (n, k int) {
127	return s.mat.N, s.mat.K
128}
129
130// T implements the Matrix interface. Symmetric matrices, by definition, are
131// equal to their transpose, and this is a no-op.
132func (s *SymBandDense) T() Matrix {
133	return s
134}
135
136// TBand implements the Banded interface.
137func (s *SymBandDense) TBand() Banded {
138	return s
139}
140
141// RawSymBand returns the underlying blas64.SymBand used by the receiver.
142// Changes to elements in the receiver following the call will be reflected
143// in returned blas64.SymBand.
144func (s *SymBandDense) RawSymBand() blas64.SymmetricBand {
145	return s.mat
146}
147
148// SetRawSymBand sets the underlying blas64.SymmetricBand used by the receiver.
149// Changes to elements in the receiver following the call will be reflected
150// in the input.
151//
152// The supplied SymmetricBand must use blas.Upper storage format.
153func (s *SymBandDense) SetRawSymBand(mat blas64.SymmetricBand) {
154	if mat.Uplo != blas.Upper {
155		panic("mat: blas64.SymmetricBand does not have blas.Upper storage")
156	}
157	s.mat = mat
158}
159
160// Zero sets all of the matrix elements to zero.
161func (s *SymBandDense) Zero() {
162	for i := 0; i < s.mat.N; i++ {
163		u := min(1+s.mat.K, s.mat.N-i)
164		zero(s.mat.Data[i*s.mat.Stride : i*s.mat.Stride+u])
165	}
166}
167
168// DiagView returns the diagonal as a matrix backed by the original data.
169func (s *SymBandDense) DiagView() Diagonal {
170	n := s.mat.N
171	return &DiagDense{
172		mat: blas64.Vector{
173			N:    n,
174			Inc:  s.mat.Stride,
175			Data: s.mat.Data[:(n-1)*s.mat.Stride+1],
176		},
177	}
178}
179
180// DoNonZero calls the function fn for each of the non-zero elements of s. The function fn
181// takes a row/column index and the element value of s at (i, j).
182func (s *SymBandDense) DoNonZero(fn func(i, j int, v float64)) {
183	for i := 0; i < s.mat.N; i++ {
184		for j := max(0, i-s.mat.K); j < min(s.mat.N, i+s.mat.K+1); j++ {
185			v := s.at(i, j)
186			if v != 0 {
187				fn(i, j, v)
188			}
189		}
190	}
191}
192
193// DoRowNonZero calls the function fn for each of the non-zero elements of row i of s. The function fn
194// takes a row/column index and the element value of s at (i, j).
195func (s *SymBandDense) DoRowNonZero(i int, fn func(i, j int, v float64)) {
196	if i < 0 || s.mat.N <= i {
197		panic(ErrRowAccess)
198	}
199	for j := max(0, i-s.mat.K); j < min(s.mat.N, i+s.mat.K+1); j++ {
200		v := s.at(i, j)
201		if v != 0 {
202			fn(i, j, v)
203		}
204	}
205}
206
207// DoColNonZero calls the function fn for each of the non-zero elements of column j of s. The function fn
208// takes a row/column index and the element value of s at (i, j).
209func (s *SymBandDense) DoColNonZero(j int, fn func(i, j int, v float64)) {
210	if j < 0 || s.mat.N <= j {
211		panic(ErrColAccess)
212	}
213	for i := 0; i < s.mat.N; i++ {
214		if i-s.mat.K <= j && j < i+s.mat.K+1 {
215			v := s.at(i, j)
216			if v != 0 {
217				fn(i, j, v)
218			}
219		}
220	}
221}
222