1// Copyright ©2013 The Gonum Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5package mat 6 7import ( 8 "math" 9 10 "gonum.org/v1/gonum/blas" 11 "gonum.org/v1/gonum/blas/blas64" 12 "gonum.org/v1/gonum/floats/scalar" 13 "gonum.org/v1/gonum/lapack" 14 "gonum.org/v1/gonum/lapack/lapack64" 15) 16 17// Matrix is the basic matrix interface type. 18type Matrix interface { 19 // Dims returns the dimensions of a Matrix. 20 Dims() (r, c int) 21 22 // At returns the value of a matrix element at row i, column j. 23 // It will panic if i or j are out of bounds for the matrix. 24 At(i, j int) float64 25 26 // T returns the transpose of the Matrix. Whether T returns a copy of the 27 // underlying data is implementation dependent. 28 // This method may be implemented using the Transpose type, which 29 // provides an implicit matrix transpose. 30 T() Matrix 31} 32 33// allMatrix represents the extra set of methods that all mat Matrix types 34// should satisfy. This is used to enforce compile-time consistency between the 35// Dense types, especially helpful when adding new features. 36type allMatrix interface { 37 Reseter 38 IsEmpty() bool 39 Zero() 40} 41 42// denseMatrix represents the extra set of methods that all Dense Matrix types 43// should satisfy. This is used to enforce compile-time consistency between the 44// Dense types, especially helpful when adding new features. 45type denseMatrix interface { 46 DiagView() Diagonal 47 Tracer 48} 49 50var ( 51 _ Matrix = Transpose{} 52 _ Untransposer = Transpose{} 53) 54 55// Transpose is a type for performing an implicit matrix transpose. It implements 56// the Matrix interface, returning values from the transpose of the matrix within. 57type Transpose struct { 58 Matrix Matrix 59} 60 61// At returns the value of the element at row i and column j of the transposed 62// matrix, that is, row j and column i of the Matrix field. 63func (t Transpose) At(i, j int) float64 { 64 return t.Matrix.At(j, i) 65} 66 67// Dims returns the dimensions of the transposed matrix. The number of rows returned 68// is the number of columns in the Matrix field, and the number of columns is 69// the number of rows in the Matrix field. 70func (t Transpose) Dims() (r, c int) { 71 c, r = t.Matrix.Dims() 72 return r, c 73} 74 75// T performs an implicit transpose by returning the Matrix field. 76func (t Transpose) T() Matrix { 77 return t.Matrix 78} 79 80// Untranspose returns the Matrix field. 81func (t Transpose) Untranspose() Matrix { 82 return t.Matrix 83} 84 85// Untransposer is a type that can undo an implicit transpose. 86type Untransposer interface { 87 // Note: This interface is needed to unify all of the Transpose types. In 88 // the mat methods, we need to test if the Matrix has been implicitly 89 // transposed. If this is checked by testing for the specific Transpose type 90 // then the behavior will be different if the user uses T() or TTri() for a 91 // triangular matrix. 92 93 // Untranspose returns the underlying Matrix stored for the implicit transpose. 94 Untranspose() Matrix 95} 96 97// UntransposeBander is a type that can undo an implicit band transpose. 98type UntransposeBander interface { 99 // Untranspose returns the underlying Banded stored for the implicit transpose. 100 UntransposeBand() Banded 101} 102 103// UntransposeTrier is a type that can undo an implicit triangular transpose. 104type UntransposeTrier interface { 105 // Untranspose returns the underlying Triangular stored for the implicit transpose. 106 UntransposeTri() Triangular 107} 108 109// UntransposeTriBander is a type that can undo an implicit triangular banded 110// transpose. 111type UntransposeTriBander interface { 112 // Untranspose returns the underlying Triangular stored for the implicit transpose. 113 UntransposeTriBand() TriBanded 114} 115 116// Mutable is a matrix interface type that allows elements to be altered. 117type Mutable interface { 118 // Set alters the matrix element at row i, column j to v. 119 // It will panic if i or j are out of bounds for the matrix. 120 Set(i, j int, v float64) 121 122 Matrix 123} 124 125// A RowViewer can return a Vector reflecting a row that is backed by the matrix 126// data. The Vector returned will have length equal to the number of columns. 127type RowViewer interface { 128 RowView(i int) Vector 129} 130 131// A RawRowViewer can return a slice of float64 reflecting a row that is backed by the matrix 132// data. 133type RawRowViewer interface { 134 RawRowView(i int) []float64 135} 136 137// A ColViewer can return a Vector reflecting a column that is backed by the matrix 138// data. The Vector returned will have length equal to the number of rows. 139type ColViewer interface { 140 ColView(j int) Vector 141} 142 143// A RawColViewer can return a slice of float64 reflecting a column that is backed by the matrix 144// data. 145type RawColViewer interface { 146 RawColView(j int) []float64 147} 148 149// A ClonerFrom can make a copy of a into the receiver, overwriting the previous value of the 150// receiver. The clone operation does not make any restriction on shape and will not cause 151// shadowing. 152type ClonerFrom interface { 153 CloneFrom(a Matrix) 154} 155 156// A Reseter can reset the matrix so that it can be reused as the receiver of a dimensionally 157// restricted operation. This is commonly used when the matrix is being used as a workspace 158// or temporary matrix. 159// 160// If the matrix is a view, using Reset may result in data corruption in elements outside 161// the view. Similarly, if the matrix shares backing data with another variable, using 162// Reset may lead to unexpected changes in data values. 163type Reseter interface { 164 Reset() 165} 166 167// A Copier can make a copy of elements of a into the receiver. The submatrix copied 168// starts at row and column 0 and has dimensions equal to the minimum dimensions of 169// the two matrices. The number of row and columns copied is returned. 170// Copy will copy from a source that aliases the receiver unless the source is transposed; 171// an aliasing transpose copy will panic with the exception for a special case when 172// the source data has a unitary increment or stride. 173type Copier interface { 174 Copy(a Matrix) (r, c int) 175} 176 177// A Grower can grow the size of the represented matrix by the given number of rows and columns. 178// Growing beyond the size given by the Caps method will result in the allocation of a new 179// matrix and copying of the elements. If Grow is called with negative increments it will 180// panic with ErrIndexOutOfRange. 181type Grower interface { 182 Caps() (r, c int) 183 Grow(r, c int) Matrix 184} 185 186// A BandWidther represents a banded matrix and can return the left and right half-bandwidths, k1 and 187// k2. 188type BandWidther interface { 189 BandWidth() (k1, k2 int) 190} 191 192// A RawMatrixSetter can set the underlying blas64.General used by the receiver. There is no restriction 193// on the shape of the receiver. Changes to the receiver's elements will be reflected in the blas64.General.Data. 194type RawMatrixSetter interface { 195 SetRawMatrix(a blas64.General) 196} 197 198// A RawMatrixer can return a blas64.General representation of the receiver. Changes to the blas64.General.Data 199// slice will be reflected in the original matrix, changes to the Rows, Cols and Stride fields will not. 200type RawMatrixer interface { 201 RawMatrix() blas64.General 202} 203 204// A RawVectorer can return a blas64.Vector representation of the receiver. Changes to the blas64.Vector.Data 205// slice will be reflected in the original matrix, changes to the Inc field will not. 206type RawVectorer interface { 207 RawVector() blas64.Vector 208} 209 210// A NonZeroDoer can call a function for each non-zero element of the receiver. 211// The parameters of the function are the element indices and its value. 212type NonZeroDoer interface { 213 DoNonZero(func(i, j int, v float64)) 214} 215 216// A RowNonZeroDoer can call a function for each non-zero element of a row of the receiver. 217// The parameters of the function are the element indices and its value. 218type RowNonZeroDoer interface { 219 DoRowNonZero(i int, fn func(i, j int, v float64)) 220} 221 222// A ColNonZeroDoer can call a function for each non-zero element of a column of the receiver. 223// The parameters of the function are the element indices and its value. 224type ColNonZeroDoer interface { 225 DoColNonZero(j int, fn func(i, j int, v float64)) 226} 227 228// untranspose untransposes a matrix if applicable. If a is an Untransposer, then 229// untranspose returns the underlying matrix and true. If it is not, then it returns 230// the input matrix and false. 231func untranspose(a Matrix) (Matrix, bool) { 232 if ut, ok := a.(Untransposer); ok { 233 return ut.Untranspose(), true 234 } 235 return a, false 236} 237 238// untransposeExtract returns an untransposed matrix in a built-in matrix type. 239// 240// The untransposed matrix is returned unaltered if it is a built-in matrix type. 241// Otherwise, if it implements a Raw method, an appropriate built-in type value 242// is returned holding the raw matrix value of the input. If neither of these 243// is possible, the untransposed matrix is returned. 244func untransposeExtract(a Matrix) (Matrix, bool) { 245 ut, trans := untranspose(a) 246 switch m := ut.(type) { 247 case *DiagDense, *SymBandDense, *TriBandDense, *BandDense, *TriDense, *SymDense, *Dense, *VecDense: 248 return m, trans 249 // TODO(btracey): Add here if we ever have an equivalent of RawDiagDense. 250 case RawSymBander: 251 rsb := m.RawSymBand() 252 if rsb.Uplo != blas.Upper { 253 return ut, trans 254 } 255 var sb SymBandDense 256 sb.SetRawSymBand(rsb) 257 return &sb, trans 258 case RawTriBander: 259 rtb := m.RawTriBand() 260 if rtb.Diag == blas.Unit { 261 return ut, trans 262 } 263 var tb TriBandDense 264 tb.SetRawTriBand(rtb) 265 return &tb, trans 266 case RawBander: 267 var b BandDense 268 b.SetRawBand(m.RawBand()) 269 return &b, trans 270 case RawTriangular: 271 rt := m.RawTriangular() 272 if rt.Diag == blas.Unit { 273 return ut, trans 274 } 275 var t TriDense 276 t.SetRawTriangular(rt) 277 return &t, trans 278 case RawSymmetricer: 279 rs := m.RawSymmetric() 280 if rs.Uplo != blas.Upper { 281 return ut, trans 282 } 283 var s SymDense 284 s.SetRawSymmetric(rs) 285 return &s, trans 286 case RawMatrixer: 287 var d Dense 288 d.SetRawMatrix(m.RawMatrix()) 289 return &d, trans 290 case RawVectorer: 291 var v VecDense 292 v.SetRawVector(m.RawVector()) 293 return &v, trans 294 default: 295 return ut, trans 296 } 297} 298 299// TODO(btracey): Consider adding CopyCol/CopyRow if the behavior seems useful. 300// TODO(btracey): Add in fast paths to Row/Col for the other concrete types 301// (TriDense, etc.) as well as relevant interfaces (RowColer, RawRowViewer, etc.) 302 303// Col copies the elements in the jth column of the matrix into the slice dst. 304// The length of the provided slice must equal the number of rows, unless the 305// slice is nil in which case a new slice is first allocated. 306func Col(dst []float64, j int, a Matrix) []float64 { 307 r, c := a.Dims() 308 if j < 0 || j >= c { 309 panic(ErrColAccess) 310 } 311 if dst == nil { 312 dst = make([]float64, r) 313 } else { 314 if len(dst) != r { 315 panic(ErrColLength) 316 } 317 } 318 aU, aTrans := untranspose(a) 319 if rm, ok := aU.(RawMatrixer); ok { 320 m := rm.RawMatrix() 321 if aTrans { 322 copy(dst, m.Data[j*m.Stride:j*m.Stride+m.Cols]) 323 return dst 324 } 325 blas64.Copy(blas64.Vector{N: r, Inc: m.Stride, Data: m.Data[j:]}, 326 blas64.Vector{N: r, Inc: 1, Data: dst}, 327 ) 328 return dst 329 } 330 for i := 0; i < r; i++ { 331 dst[i] = a.At(i, j) 332 } 333 return dst 334} 335 336// Row copies the elements in the ith row of the matrix into the slice dst. 337// The length of the provided slice must equal the number of columns, unless the 338// slice is nil in which case a new slice is first allocated. 339func Row(dst []float64, i int, a Matrix) []float64 { 340 r, c := a.Dims() 341 if i < 0 || i >= r { 342 panic(ErrColAccess) 343 } 344 if dst == nil { 345 dst = make([]float64, c) 346 } else { 347 if len(dst) != c { 348 panic(ErrRowLength) 349 } 350 } 351 aU, aTrans := untranspose(a) 352 if rm, ok := aU.(RawMatrixer); ok { 353 m := rm.RawMatrix() 354 if aTrans { 355 blas64.Copy(blas64.Vector{N: c, Inc: m.Stride, Data: m.Data[i:]}, 356 blas64.Vector{N: c, Inc: 1, Data: dst}, 357 ) 358 return dst 359 } 360 copy(dst, m.Data[i*m.Stride:i*m.Stride+m.Cols]) 361 return dst 362 } 363 for j := 0; j < c; j++ { 364 dst[j] = a.At(i, j) 365 } 366 return dst 367} 368 369// Cond returns the condition number of the given matrix under the given norm. 370// The condition number must be based on the 1-norm, 2-norm or ∞-norm. 371// Cond will panic with matrix.ErrShape if the matrix has zero size. 372// 373// BUG(btracey): The computation of the 1-norm and ∞-norm for non-square matrices 374// is inaccurate, although is typically the right order of magnitude. See 375// https://github.com/xianyi/OpenBLAS/issues/636. While the value returned will 376// change with the resolution of this bug, the result from Cond will match the 377// condition number used internally. 378func Cond(a Matrix, norm float64) float64 { 379 m, n := a.Dims() 380 if m == 0 || n == 0 { 381 panic(ErrShape) 382 } 383 var lnorm lapack.MatrixNorm 384 switch norm { 385 default: 386 panic("mat: bad norm value") 387 case 1: 388 lnorm = lapack.MaxColumnSum 389 case 2: 390 var svd SVD 391 ok := svd.Factorize(a, SVDNone) 392 if !ok { 393 return math.Inf(1) 394 } 395 return svd.Cond() 396 case math.Inf(1): 397 lnorm = lapack.MaxRowSum 398 } 399 400 if m == n { 401 // Use the LU decomposition to compute the condition number. 402 var lu LU 403 lu.factorize(a, lnorm) 404 return lu.Cond() 405 } 406 if m > n { 407 // Use the QR factorization to compute the condition number. 408 var qr QR 409 qr.factorize(a, lnorm) 410 return qr.Cond() 411 } 412 // Use the LQ factorization to compute the condition number. 413 var lq LQ 414 lq.factorize(a, lnorm) 415 return lq.Cond() 416} 417 418// Det returns the determinant of the matrix a. In many expressions using LogDet 419// will be more numerically stable. 420func Det(a Matrix) float64 { 421 det, sign := LogDet(a) 422 return math.Exp(det) * sign 423} 424 425// Dot returns the sum of the element-wise product of a and b. 426// Dot panics if the matrix sizes are unequal. 427func Dot(a, b Vector) float64 { 428 la := a.Len() 429 lb := b.Len() 430 if la != lb { 431 panic(ErrShape) 432 } 433 if arv, ok := a.(RawVectorer); ok { 434 if brv, ok := b.(RawVectorer); ok { 435 return blas64.Dot(arv.RawVector(), brv.RawVector()) 436 } 437 } 438 var sum float64 439 for i := 0; i < la; i++ { 440 sum += a.At(i, 0) * b.At(i, 0) 441 } 442 return sum 443} 444 445// Equal returns whether the matrices a and b have the same size 446// and are element-wise equal. 447func Equal(a, b Matrix) bool { 448 ar, ac := a.Dims() 449 br, bc := b.Dims() 450 if ar != br || ac != bc { 451 return false 452 } 453 aU, aTrans := untranspose(a) 454 bU, bTrans := untranspose(b) 455 if rma, ok := aU.(RawMatrixer); ok { 456 if rmb, ok := bU.(RawMatrixer); ok { 457 ra := rma.RawMatrix() 458 rb := rmb.RawMatrix() 459 if aTrans == bTrans { 460 for i := 0; i < ra.Rows; i++ { 461 for j := 0; j < ra.Cols; j++ { 462 if ra.Data[i*ra.Stride+j] != rb.Data[i*rb.Stride+j] { 463 return false 464 } 465 } 466 } 467 return true 468 } 469 for i := 0; i < ra.Rows; i++ { 470 for j := 0; j < ra.Cols; j++ { 471 if ra.Data[i*ra.Stride+j] != rb.Data[j*rb.Stride+i] { 472 return false 473 } 474 } 475 } 476 return true 477 } 478 } 479 if rma, ok := aU.(RawSymmetricer); ok { 480 if rmb, ok := bU.(RawSymmetricer); ok { 481 ra := rma.RawSymmetric() 482 rb := rmb.RawSymmetric() 483 // Symmetric matrices are always upper and equal to their transpose. 484 for i := 0; i < ra.N; i++ { 485 for j := i; j < ra.N; j++ { 486 if ra.Data[i*ra.Stride+j] != rb.Data[i*rb.Stride+j] { 487 return false 488 } 489 } 490 } 491 return true 492 } 493 } 494 if ra, ok := aU.(*VecDense); ok { 495 if rb, ok := bU.(*VecDense); ok { 496 // If the raw vectors are the same length they must either both be 497 // transposed or both not transposed (or have length 1). 498 for i := 0; i < ra.mat.N; i++ { 499 if ra.mat.Data[i*ra.mat.Inc] != rb.mat.Data[i*rb.mat.Inc] { 500 return false 501 } 502 } 503 return true 504 } 505 } 506 for i := 0; i < ar; i++ { 507 for j := 0; j < ac; j++ { 508 if a.At(i, j) != b.At(i, j) { 509 return false 510 } 511 } 512 } 513 return true 514} 515 516// EqualApprox returns whether the matrices a and b have the same size and contain all equal 517// elements with tolerance for element-wise equality specified by epsilon. Matrices 518// with non-equal shapes are not equal. 519func EqualApprox(a, b Matrix, epsilon float64) bool { 520 ar, ac := a.Dims() 521 br, bc := b.Dims() 522 if ar != br || ac != bc { 523 return false 524 } 525 aU, aTrans := untranspose(a) 526 bU, bTrans := untranspose(b) 527 if rma, ok := aU.(RawMatrixer); ok { 528 if rmb, ok := bU.(RawMatrixer); ok { 529 ra := rma.RawMatrix() 530 rb := rmb.RawMatrix() 531 if aTrans == bTrans { 532 for i := 0; i < ra.Rows; i++ { 533 for j := 0; j < ra.Cols; j++ { 534 if !scalar.EqualWithinAbsOrRel(ra.Data[i*ra.Stride+j], rb.Data[i*rb.Stride+j], epsilon, epsilon) { 535 return false 536 } 537 } 538 } 539 return true 540 } 541 for i := 0; i < ra.Rows; i++ { 542 for j := 0; j < ra.Cols; j++ { 543 if !scalar.EqualWithinAbsOrRel(ra.Data[i*ra.Stride+j], rb.Data[j*rb.Stride+i], epsilon, epsilon) { 544 return false 545 } 546 } 547 } 548 return true 549 } 550 } 551 if rma, ok := aU.(RawSymmetricer); ok { 552 if rmb, ok := bU.(RawSymmetricer); ok { 553 ra := rma.RawSymmetric() 554 rb := rmb.RawSymmetric() 555 // Symmetric matrices are always upper and equal to their transpose. 556 for i := 0; i < ra.N; i++ { 557 for j := i; j < ra.N; j++ { 558 if !scalar.EqualWithinAbsOrRel(ra.Data[i*ra.Stride+j], rb.Data[i*rb.Stride+j], epsilon, epsilon) { 559 return false 560 } 561 } 562 } 563 return true 564 } 565 } 566 if ra, ok := aU.(*VecDense); ok { 567 if rb, ok := bU.(*VecDense); ok { 568 // If the raw vectors are the same length they must either both be 569 // transposed or both not transposed (or have length 1). 570 for i := 0; i < ra.mat.N; i++ { 571 if !scalar.EqualWithinAbsOrRel(ra.mat.Data[i*ra.mat.Inc], rb.mat.Data[i*rb.mat.Inc], epsilon, epsilon) { 572 return false 573 } 574 } 575 return true 576 } 577 } 578 for i := 0; i < ar; i++ { 579 for j := 0; j < ac; j++ { 580 if !scalar.EqualWithinAbsOrRel(a.At(i, j), b.At(i, j), epsilon, epsilon) { 581 return false 582 } 583 } 584 } 585 return true 586} 587 588// LogDet returns the log of the determinant and the sign of the determinant 589// for the matrix that has been factorized. Numerical stability in product and 590// division expressions is generally improved by working in log space. 591func LogDet(a Matrix) (det float64, sign float64) { 592 // TODO(btracey): Add specialized routines for TriDense, etc. 593 var lu LU 594 lu.Factorize(a) 595 return lu.LogDet() 596} 597 598// Max returns the largest element value of the matrix A. 599// Max will panic with matrix.ErrShape if the matrix has zero size. 600func Max(a Matrix) float64 { 601 r, c := a.Dims() 602 if r == 0 || c == 0 { 603 panic(ErrShape) 604 } 605 // Max(A) = Max(Aᵀ) 606 aU, _ := untranspose(a) 607 switch m := aU.(type) { 608 case RawMatrixer: 609 rm := m.RawMatrix() 610 max := math.Inf(-1) 611 for i := 0; i < rm.Rows; i++ { 612 for _, v := range rm.Data[i*rm.Stride : i*rm.Stride+rm.Cols] { 613 if v > max { 614 max = v 615 } 616 } 617 } 618 return max 619 case RawTriangular: 620 rm := m.RawTriangular() 621 // The max of a triangular is at least 0 unless the size is 1. 622 if rm.N == 1 { 623 return rm.Data[0] 624 } 625 max := 0.0 626 if rm.Uplo == blas.Upper { 627 for i := 0; i < rm.N; i++ { 628 for _, v := range rm.Data[i*rm.Stride+i : i*rm.Stride+rm.N] { 629 if v > max { 630 max = v 631 } 632 } 633 } 634 return max 635 } 636 for i := 0; i < rm.N; i++ { 637 for _, v := range rm.Data[i*rm.Stride : i*rm.Stride+i+1] { 638 if v > max { 639 max = v 640 } 641 } 642 } 643 return max 644 case RawSymmetricer: 645 rm := m.RawSymmetric() 646 if rm.Uplo != blas.Upper { 647 panic(badSymTriangle) 648 } 649 max := math.Inf(-1) 650 for i := 0; i < rm.N; i++ { 651 for _, v := range rm.Data[i*rm.Stride+i : i*rm.Stride+rm.N] { 652 if v > max { 653 max = v 654 } 655 } 656 } 657 return max 658 default: 659 r, c := aU.Dims() 660 max := math.Inf(-1) 661 for i := 0; i < r; i++ { 662 for j := 0; j < c; j++ { 663 v := aU.At(i, j) 664 if v > max { 665 max = v 666 } 667 } 668 } 669 return max 670 } 671} 672 673// Min returns the smallest element value of the matrix A. 674// Min will panic with matrix.ErrShape if the matrix has zero size. 675func Min(a Matrix) float64 { 676 r, c := a.Dims() 677 if r == 0 || c == 0 { 678 panic(ErrShape) 679 } 680 // Min(A) = Min(Aᵀ) 681 aU, _ := untranspose(a) 682 switch m := aU.(type) { 683 case RawMatrixer: 684 rm := m.RawMatrix() 685 min := math.Inf(1) 686 for i := 0; i < rm.Rows; i++ { 687 for _, v := range rm.Data[i*rm.Stride : i*rm.Stride+rm.Cols] { 688 if v < min { 689 min = v 690 } 691 } 692 } 693 return min 694 case RawTriangular: 695 rm := m.RawTriangular() 696 // The min of a triangular is at most 0 unless the size is 1. 697 if rm.N == 1 { 698 return rm.Data[0] 699 } 700 min := 0.0 701 if rm.Uplo == blas.Upper { 702 for i := 0; i < rm.N; i++ { 703 for _, v := range rm.Data[i*rm.Stride+i : i*rm.Stride+rm.N] { 704 if v < min { 705 min = v 706 } 707 } 708 } 709 return min 710 } 711 for i := 0; i < rm.N; i++ { 712 for _, v := range rm.Data[i*rm.Stride : i*rm.Stride+i+1] { 713 if v < min { 714 min = v 715 } 716 } 717 } 718 return min 719 case RawSymmetricer: 720 rm := m.RawSymmetric() 721 if rm.Uplo != blas.Upper { 722 panic(badSymTriangle) 723 } 724 min := math.Inf(1) 725 for i := 0; i < rm.N; i++ { 726 for _, v := range rm.Data[i*rm.Stride+i : i*rm.Stride+rm.N] { 727 if v < min { 728 min = v 729 } 730 } 731 } 732 return min 733 default: 734 r, c := aU.Dims() 735 min := math.Inf(1) 736 for i := 0; i < r; i++ { 737 for j := 0; j < c; j++ { 738 v := aU.At(i, j) 739 if v < min { 740 min = v 741 } 742 } 743 } 744 return min 745 } 746} 747 748// Norm returns the specified norm of the matrix A. Valid norms are: 749// 1 - The maximum absolute column sum 750// 2 - The Frobenius norm, the square root of the sum of the squares of the elements 751// Inf - The maximum absolute row sum 752// 753// Norm will panic with ErrNormOrder if an illegal norm order is specified and 754// with ErrShape if the matrix has zero size. 755func Norm(a Matrix, norm float64) float64 { 756 r, c := a.Dims() 757 if r == 0 || c == 0 { 758 panic(ErrShape) 759 } 760 aU, aTrans := untranspose(a) 761 var work []float64 762 switch rma := aU.(type) { 763 case RawMatrixer: 764 rm := rma.RawMatrix() 765 n := normLapack(norm, aTrans) 766 if n == lapack.MaxColumnSum { 767 work = getFloats(rm.Cols, false) 768 defer putFloats(work) 769 } 770 return lapack64.Lange(n, rm, work) 771 case RawTriangular: 772 rm := rma.RawTriangular() 773 n := normLapack(norm, aTrans) 774 if n == lapack.MaxRowSum || n == lapack.MaxColumnSum { 775 work = getFloats(rm.N, false) 776 defer putFloats(work) 777 } 778 return lapack64.Lantr(n, rm, work) 779 case RawSymmetricer: 780 rm := rma.RawSymmetric() 781 n := normLapack(norm, aTrans) 782 if n == lapack.MaxRowSum || n == lapack.MaxColumnSum { 783 work = getFloats(rm.N, false) 784 defer putFloats(work) 785 } 786 return lapack64.Lansy(n, rm, work) 787 case *VecDense: 788 rv := rma.RawVector() 789 switch norm { 790 default: 791 panic(ErrNormOrder) 792 case 1: 793 if aTrans { 794 imax := blas64.Iamax(rv) 795 return math.Abs(rma.At(imax, 0)) 796 } 797 return blas64.Asum(rv) 798 case 2: 799 return blas64.Nrm2(rv) 800 case math.Inf(1): 801 if aTrans { 802 return blas64.Asum(rv) 803 } 804 imax := blas64.Iamax(rv) 805 return math.Abs(rma.At(imax, 0)) 806 } 807 } 808 switch norm { 809 default: 810 panic(ErrNormOrder) 811 case 1: 812 var max float64 813 for j := 0; j < c; j++ { 814 var sum float64 815 for i := 0; i < r; i++ { 816 sum += math.Abs(a.At(i, j)) 817 } 818 if sum > max { 819 max = sum 820 } 821 } 822 return max 823 case 2: 824 var sum float64 825 for i := 0; i < r; i++ { 826 for j := 0; j < c; j++ { 827 v := a.At(i, j) 828 sum += v * v 829 } 830 } 831 return math.Sqrt(sum) 832 case math.Inf(1): 833 var max float64 834 for i := 0; i < r; i++ { 835 var sum float64 836 for j := 0; j < c; j++ { 837 sum += math.Abs(a.At(i, j)) 838 } 839 if sum > max { 840 max = sum 841 } 842 } 843 return max 844 } 845} 846 847// normLapack converts the float64 norm input in Norm to a lapack.MatrixNorm. 848func normLapack(norm float64, aTrans bool) lapack.MatrixNorm { 849 switch norm { 850 case 1: 851 n := lapack.MaxColumnSum 852 if aTrans { 853 n = lapack.MaxRowSum 854 } 855 return n 856 case 2: 857 return lapack.Frobenius 858 case math.Inf(1): 859 n := lapack.MaxRowSum 860 if aTrans { 861 n = lapack.MaxColumnSum 862 } 863 return n 864 default: 865 panic(ErrNormOrder) 866 } 867} 868 869// Sum returns the sum of the elements of the matrix. 870func Sum(a Matrix) float64 { 871 872 var sum float64 873 aU, _ := untranspose(a) 874 switch rma := aU.(type) { 875 case RawSymmetricer: 876 rm := rma.RawSymmetric() 877 for i := 0; i < rm.N; i++ { 878 // Diagonals count once while off-diagonals count twice. 879 sum += rm.Data[i*rm.Stride+i] 880 var s float64 881 for _, v := range rm.Data[i*rm.Stride+i+1 : i*rm.Stride+rm.N] { 882 s += v 883 } 884 sum += 2 * s 885 } 886 return sum 887 case RawTriangular: 888 rm := rma.RawTriangular() 889 var startIdx, endIdx int 890 for i := 0; i < rm.N; i++ { 891 // Start and end index for this triangle-row. 892 switch rm.Uplo { 893 case blas.Upper: 894 startIdx = i 895 endIdx = rm.N 896 case blas.Lower: 897 startIdx = 0 898 endIdx = i + 1 899 default: 900 panic(badTriangle) 901 } 902 for _, v := range rm.Data[i*rm.Stride+startIdx : i*rm.Stride+endIdx] { 903 sum += v 904 } 905 } 906 return sum 907 case RawMatrixer: 908 rm := rma.RawMatrix() 909 for i := 0; i < rm.Rows; i++ { 910 for _, v := range rm.Data[i*rm.Stride : i*rm.Stride+rm.Cols] { 911 sum += v 912 } 913 } 914 return sum 915 case *VecDense: 916 rm := rma.RawVector() 917 for i := 0; i < rm.N; i++ { 918 sum += rm.Data[i*rm.Inc] 919 } 920 return sum 921 default: 922 r, c := a.Dims() 923 for i := 0; i < r; i++ { 924 for j := 0; j < c; j++ { 925 sum += a.At(i, j) 926 } 927 } 928 return sum 929 } 930} 931 932// A Tracer can compute the trace of the matrix. Trace must panic if the 933// matrix is not square. 934type Tracer interface { 935 Trace() float64 936} 937 938// Trace returns the trace of the matrix. Trace will panic if the 939// matrix is not square. If a is a Tracer, its Trace method will be 940// used to calculate the matrix trace. 941func Trace(a Matrix) float64 { 942 m, _ := untransposeExtract(a) 943 if t, ok := m.(Tracer); ok { 944 return t.Trace() 945 } 946 r, c := a.Dims() 947 if r != c { 948 panic(ErrSquare) 949 } 950 var v float64 951 for i := 0; i < r; i++ { 952 v += a.At(i, i) 953 } 954 return v 955} 956 957func min(a, b int) int { 958 if a < b { 959 return a 960 } 961 return b 962} 963 964func max(a, b int) int { 965 if a > b { 966 return a 967 } 968 return b 969} 970 971// use returns a float64 slice with l elements, using f if it 972// has the necessary capacity, otherwise creating a new slice. 973func use(f []float64, l int) []float64 { 974 if l <= cap(f) { 975 return f[:l] 976 } 977 return make([]float64, l) 978} 979 980// useZeroed returns a float64 slice with l elements, using f if it 981// has the necessary capacity, otherwise creating a new slice. The 982// elements of the returned slice are guaranteed to be zero. 983func useZeroed(f []float64, l int) []float64 { 984 if l <= cap(f) { 985 f = f[:l] 986 zero(f) 987 return f 988 } 989 return make([]float64, l) 990} 991 992// zero zeros the given slice's elements. 993func zero(f []float64) { 994 for i := range f { 995 f[i] = 0 996 } 997} 998 999// useInt returns an int slice with l elements, using i if it 1000// has the necessary capacity, otherwise creating a new slice. 1001func useInt(i []int, l int) []int { 1002 if l <= cap(i) { 1003 return i[:l] 1004 } 1005 return make([]int, l) 1006} 1007