// Copyright ©2018 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package mat

import (
	"gonum.org/v1/gonum/blas"
	"gonum.org/v1/gonum/blas/blas64"
)

// Compile-time assertions that *DiagDense and Diagonal can be assigned to
// each of the listed interface types.
var (
	diagDense *DiagDense
	_         Matrix          = diagDense
	_         Diagonal        = diagDense
	_         MutableDiagonal = diagDense
	_         Triangular      = diagDense
	_         TriBanded       = diagDense
	_         Symmetric       = diagDense
	_         SymBanded       = diagDense
	_         Banded          = diagDense
	_         RawBander       = diagDense
	_         RawSymBander    = diagDense

	// Any Diagonal must itself be usable as each of the following.
	diag Diagonal
	_    Matrix     = diag
	_    Diagonal   = diag
	_    Triangular = diag
	_    TriBanded  = diag
	_    Symmetric  = diag
	_    SymBanded  = diag
	_    Banded     = diag
)

// Diagonal represents a diagonal matrix, that is a square matrix that only
// has non-zero terms on the diagonal.
type Diagonal interface {
	Matrix
	// Diag returns the number of rows/columns in the matrix.
	Diag() int

	// Bandwidth and TBand are included in the Diagonal interface
	// to allow the use of Diagonal types in banded functions.
	// Bandwidth will always return (0, 0).
	Bandwidth() (kl, ku int)
	TBand() Banded

	// Triangle and TTri are included in the Diagonal interface
	// to allow the use of Diagonal types in triangular functions.
	Triangle() (int, TriKind)
	TTri() Triangular

	// Symmetric and SymBand are included in the Diagonal interface
	// to allow the use of Diagonal types in symmetric and banded symmetric
	// functions respectively.
	Symmetric() int
	SymBand() (n, k int)

	// TriBand and TTriBand are included in the Diagonal interface
	// to allow the use of Diagonal types in triangular banded functions.
	TriBand() (n, k int, kind TriKind)
	TTriBand() TriBanded
}

// MutableDiagonal is a Diagonal matrix whose elements can be set.
66type MutableDiagonal interface { 67 Diagonal 68 SetDiag(i int, v float64) 69} 70 71// DiagDense represents a diagonal matrix in dense storage format. 72type DiagDense struct { 73 mat blas64.Vector 74} 75 76// NewDiagDense creates a new Diagonal matrix with n rows and n columns. 77// The length of data must be n or data must be nil, otherwise NewDiagDense 78// will panic. NewDiagDense will panic if n is zero. 79func NewDiagDense(n int, data []float64) *DiagDense { 80 if n <= 0 { 81 if n == 0 { 82 panic(ErrZeroLength) 83 } 84 panic("mat: negative dimension") 85 } 86 if data == nil { 87 data = make([]float64, n) 88 } 89 if len(data) != n { 90 panic(ErrShape) 91 } 92 return &DiagDense{ 93 mat: blas64.Vector{N: n, Data: data, Inc: 1}, 94 } 95} 96 97// Diag returns the dimension of the receiver. 98func (d *DiagDense) Diag() int { 99 return d.mat.N 100} 101 102// Dims returns the dimensions of the matrix. 103func (d *DiagDense) Dims() (r, c int) { 104 return d.mat.N, d.mat.N 105} 106 107// T returns the transpose of the matrix. 108func (d *DiagDense) T() Matrix { 109 return d 110} 111 112// TTri returns the transpose of the matrix. Note that Diagonal matrices are 113// Upper by default. 114func (d *DiagDense) TTri() Triangular { 115 return TransposeTri{d} 116} 117 118// TBand performs an implicit transpose by returning the receiver inside a 119// TransposeBand. 120func (d *DiagDense) TBand() Banded { 121 return TransposeBand{d} 122} 123 124// TTriBand performs an implicit transpose by returning the receiver inside a 125// TransposeTriBand. Note that Diagonal matrices are Upper by default. 126func (d *DiagDense) TTriBand() TriBanded { 127 return TransposeTriBand{d} 128} 129 130// Bandwidth returns the upper and lower bandwidths of the matrix. 131// These values are always zero for diagonal matrices. 132func (d *DiagDense) Bandwidth() (kl, ku int) { 133 return 0, 0 134} 135 136// Symmetric implements the Symmetric interface. 
137func (d *DiagDense) Symmetric() int { 138 return d.mat.N 139} 140 141// SymBand returns the number of rows/columns in the matrix, and the size of 142// the bandwidth. 143func (d *DiagDense) SymBand() (n, k int) { 144 return d.mat.N, 0 145} 146 147// Triangle implements the Triangular interface. 148func (d *DiagDense) Triangle() (int, TriKind) { 149 return d.mat.N, Upper 150} 151 152// TriBand returns the number of rows/columns in the matrix, the 153// size of the bandwidth, and the orientation. Note that Diagonal matrices are 154// Upper by default. 155func (d *DiagDense) TriBand() (n, k int, kind TriKind) { 156 return d.mat.N, 0, Upper 157} 158 159// Reset zeros the length of the matrix so that it can be reused as the 160// receiver of a dimensionally restricted operation. 161// 162// See the Reseter interface for more information. 163func (d *DiagDense) Reset() { 164 // No change of Inc or n to 0 may be 165 // made unless both are set to 0. 166 d.mat.Inc = 0 167 d.mat.N = 0 168 d.mat.Data = d.mat.Data[:0] 169} 170 171// Zero sets all of the matrix elements to zero. 172func (d *DiagDense) Zero() { 173 for i := 0; i < d.mat.N; i++ { 174 d.mat.Data[d.mat.Inc*i] = 0 175 } 176} 177 178// DiagView returns the diagonal as a matrix backed by the original data. 179func (d *DiagDense) DiagView() Diagonal { 180 return d 181} 182 183// DiagFrom copies the diagonal of m into the receiver. The receiver must 184// be min(r, c) long or zero. Otherwise DiagFrom will panic. 
185func (d *DiagDense) DiagFrom(m Matrix) { 186 n := min(m.Dims()) 187 d.reuseAs(n) 188 189 var vec blas64.Vector 190 switch r := m.(type) { 191 case *DiagDense: 192 vec = r.mat 193 case RawBander: 194 mat := r.RawBand() 195 vec = blas64.Vector{ 196 N: n, 197 Inc: mat.Stride, 198 Data: mat.Data[mat.KL : (n-1)*mat.Stride+mat.KL+1], 199 } 200 case RawMatrixer: 201 mat := r.RawMatrix() 202 vec = blas64.Vector{ 203 N: n, 204 Inc: mat.Stride + 1, 205 Data: mat.Data[:(n-1)*mat.Stride+n], 206 } 207 case RawSymBander: 208 mat := r.RawSymBand() 209 vec = blas64.Vector{ 210 N: n, 211 Inc: mat.Stride, 212 Data: mat.Data[:(n-1)*mat.Stride+1], 213 } 214 case RawSymmetricer: 215 mat := r.RawSymmetric() 216 vec = blas64.Vector{ 217 N: n, 218 Inc: mat.Stride + 1, 219 Data: mat.Data[:(n-1)*mat.Stride+n], 220 } 221 case RawTriBander: 222 mat := r.RawTriBand() 223 data := mat.Data 224 if mat.Uplo == blas.Lower { 225 data = data[mat.K:] 226 } 227 vec = blas64.Vector{ 228 N: n, 229 Inc: mat.Stride, 230 Data: data[:(n-1)*mat.Stride+1], 231 } 232 case RawTriangular: 233 mat := r.RawTriangular() 234 if mat.Diag == blas.Unit { 235 for i := 0; i < n; i += d.mat.Inc { 236 d.mat.Data[i] = 1 237 } 238 return 239 } 240 vec = blas64.Vector{ 241 N: n, 242 Inc: mat.Stride + 1, 243 Data: mat.Data[:(n-1)*mat.Stride+n], 244 } 245 case RawVectorer: 246 d.mat.Data[0] = r.RawVector().Data[0] 247 return 248 default: 249 for i := 0; i < n; i++ { 250 d.setDiag(i, m.At(i, i)) 251 } 252 return 253 } 254 blas64.Copy(vec, d.mat) 255} 256 257// RawBand returns the underlying data used by the receiver represented 258// as a blas64.Band. 259// Changes to elements in the receiver following the call will be reflected 260// in returned blas64.Band. 
261func (d *DiagDense) RawBand() blas64.Band { 262 return blas64.Band{ 263 Rows: d.mat.N, 264 Cols: d.mat.N, 265 KL: 0, 266 KU: 0, 267 Stride: d.mat.Inc, 268 Data: d.mat.Data, 269 } 270} 271 272// RawSymBand returns the underlying data used by the receiver represented 273// as a blas64.SymmetricBand. 274// Changes to elements in the receiver following the call will be reflected 275// in returned blas64.Band. 276func (d *DiagDense) RawSymBand() blas64.SymmetricBand { 277 return blas64.SymmetricBand{ 278 N: d.mat.N, 279 K: 0, 280 Stride: d.mat.Inc, 281 Uplo: blas.Upper, 282 Data: d.mat.Data, 283 } 284} 285 286// reuseAs resizes an empty diagonal to a r×r diagonal, 287// or checks that a non-empty matrix is r×r. 288func (d *DiagDense) reuseAs(r int) { 289 if r == 0 { 290 panic(ErrZeroLength) 291 } 292 if d.IsZero() { 293 d.mat = blas64.Vector{ 294 Inc: 1, 295 Data: use(d.mat.Data, r), 296 } 297 d.mat.N = r 298 return 299 } 300 if r != d.mat.N { 301 panic(ErrShape) 302 } 303} 304 305// IsZero returns whether the receiver is zero-sized. Zero-sized vectors can be the 306// receiver for size-restricted operations. DiagDenses can be zeroed using Reset. 307func (d *DiagDense) IsZero() bool { 308 // It must be the case that d.Dims() returns 309 // zeros in this case. See comment in Reset(). 310 return d.mat.Inc == 0 311} 312