// Copyright ©2013 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package mat

import (
	"math"

	"gonum.org/v1/gonum/blas"
	"gonum.org/v1/gonum/blas/blas64"
	"gonum.org/v1/gonum/floats"
	"gonum.org/v1/gonum/lapack"
	"gonum.org/v1/gonum/lapack/lapack64"
)

// Matrix is the basic matrix interface type.
type Matrix interface {
	// Dims returns the dimensions of a Matrix.
	Dims() (r, c int)

	// At returns the value of a matrix element at row i, column j.
	// It will panic if i or j are out of bounds for the matrix.
	At(i, j int) float64

	// T returns the transpose of the Matrix. Whether T returns a copy of the
	// underlying data is implementation dependent.
	// This method may be implemented using the Transpose type, which
	// provides an implicit matrix transpose.
	T() Matrix
}

// Compile-time checks that Transpose satisfies both the Matrix and the
// Untransposer interfaces.
var (
	_ Matrix       = Transpose{}
	_ Untransposer = Transpose{}
)

// Transpose is a type for performing an implicit matrix transpose. It implements
// the Matrix interface, returning values from the transpose of the matrix within.
// No element data is copied: At and Dims read from the wrapped Matrix with the
// row and column roles swapped.
type Transpose struct {
	// Matrix is the matrix that is viewed as transposed.
	Matrix Matrix
}

// At returns the value of the element at row i and column j of the transposed
// matrix, that is, row j and column i of the Matrix field.
func (t Transpose) At(i, j int) float64 {
	return t.Matrix.At(j, i)
}

// Dims returns the dimensions of the transposed matrix. The number of rows returned
// is the number of columns in the Matrix field, and the number of columns is
// the number of rows in the Matrix field.
func (t Transpose) Dims() (r, c int) {
	// Assign the wrapped matrix's (rows, cols) into (c, r) to swap them.
	c, r = t.Matrix.Dims()
	return r, c
}

// T performs an implicit transpose by returning the Matrix field.
// Transposing twice therefore yields the original matrix without copying.
func (t Transpose) T() Matrix {
	return t.Matrix
}

// Untranspose returns the Matrix field, undoing the implicit transpose.
func (t Transpose) Untranspose() Matrix {
	return t.Matrix
}

// Untransposer is a type that can undo an implicit transpose.
type Untransposer interface {
	// Note: This interface is needed to unify all of the Transpose types. In
	// the mat methods, we need to test if the Matrix has been implicitly
	// transposed. If this is checked by testing for the specific Transpose type
	// then the behavior will be different if the user uses T() or TTri() for a
	// triangular matrix.

	// Untranspose returns the underlying Matrix stored for the implicit transpose.
	Untranspose() Matrix
}

// UntransposeBander is a type that can undo an implicit band transpose.
type UntransposeBander interface {
	// UntransposeBand returns the underlying Banded stored for the implicit transpose.
	UntransposeBand() Banded
}

// UntransposeTrier is a type that can undo an implicit triangular transpose.
type UntransposeTrier interface {
	// UntransposeTri returns the underlying Triangular stored for the implicit transpose.
	UntransposeTri() Triangular
}

// UntransposeTriBander is a type that can undo an implicit triangular banded
// transpose.
type UntransposeTriBander interface {
	// UntransposeTriBand returns the underlying TriBanded stored for the implicit
	// transpose.
	UntransposeTriBand() TriBanded
}

// Mutable is a matrix interface type that allows elements to be altered.
type Mutable interface {
	// Set alters the matrix element at row i, column j to v.
	// It will panic if i or j are out of bounds for the matrix.
	Set(i, j int, v float64)

	// Matrix supplies the read-only methods Dims, At and T.
	Matrix
}

// A RowViewer can return a Vector reflecting a row that is backed by the matrix
// data. The Vector returned will have length equal to the number of columns.
type RowViewer interface {
	// RowView returns a Vector view of row i.
	RowView(i int) Vector
}

// A RawRowViewer can return a slice of float64 reflecting a row that is backed by the matrix
// data.
type RawRowViewer interface {
	// RawRowView returns the slice of matrix data holding row i.
	RawRowView(i int) []float64
}

// A ColViewer can return a Vector reflecting a column that is backed by the matrix
// data. The Vector returned will have length equal to the number of rows.
type ColViewer interface {
	// ColView returns a Vector view of column j.
	ColView(j int) Vector
}

// A RawColViewer can return a slice of float64 reflecting a column that is backed by the matrix
// data.
type RawColViewer interface {
	// RawColView returns the slice of matrix data holding column j.
	RawColView(j int) []float64
}

// A Cloner can make a copy of a into the receiver, overwriting the previous value of the
// receiver. The clone operation does not make any restriction on shape and will not cause
// shadowing.
type Cloner interface {
	Clone(a Matrix)
}

// A Reseter can reset the matrix so that it can be reused as the receiver of a dimensionally
// restricted operation. This is commonly used when the matrix is being used as a workspace
// or temporary matrix.
//
// If the matrix is a view, using the reset matrix may result in data corruption in elements
// outside the view.
type Reseter interface {
	Reset()
}

// A Copier can make a copy of elements of a into the receiver. The submatrix copied
// starts at row and column 0 and has dimensions equal to the minimum dimensions of
// the two matrices. The number of rows and columns copied is returned.
// Copy will copy from a source that aliases the receiver unless the source is transposed;
// an aliasing transpose copy will panic with the exception for a special case when
// the source data has a unitary increment or stride.
type Copier interface {
	Copy(a Matrix) (r, c int)
}

// A Grower can grow the size of the represented matrix by the given number of rows and columns.
// Growing beyond the size given by the Caps method will result in the allocation of a new
// matrix and copying of the elements. If Grow is called with negative increments it will
If Grow is called with negative increments it will 162// panic with ErrIndexOutOfRange. 163type Grower interface { 164 Caps() (r, c int) 165 Grow(r, c int) Matrix 166} 167 168// A BandWidther represents a banded matrix and can return the left and right half-bandwidths, k1 and 169// k2. 170type BandWidther interface { 171 BandWidth() (k1, k2 int) 172} 173 174// A RawMatrixSetter can set the underlying blas64.General used by the receiver. There is no restriction 175// on the shape of the receiver. Changes to the receiver's elements will be reflected in the blas64.General.Data. 176type RawMatrixSetter interface { 177 SetRawMatrix(a blas64.General) 178} 179 180// A RawMatrixer can return a blas64.General representation of the receiver. Changes to the blas64.General.Data 181// slice will be reflected in the original matrix, changes to the Rows, Cols and Stride fields will not. 182type RawMatrixer interface { 183 RawMatrix() blas64.General 184} 185 186// A RawVectorer can return a blas64.Vector representation of the receiver. Changes to the blas64.Vector.Data 187// slice will be reflected in the original matrix, changes to the Inc field will not. 188type RawVectorer interface { 189 RawVector() blas64.Vector 190} 191 192// A NonZeroDoer can call a function for each non-zero element of the receiver. 193// The parameters of the function are the element indices and its value. 194type NonZeroDoer interface { 195 DoNonZero(func(i, j int, v float64)) 196} 197 198// A RowNonZeroDoer can call a function for each non-zero element of a row of the receiver. 199// The parameters of the function are the element indices and its value. 200type RowNonZeroDoer interface { 201 DoRowNonZero(i int, fn func(i, j int, v float64)) 202} 203 204// A ColNonZeroDoer can call a function for each non-zero element of a column of the receiver. 205// The parameters of the function are the element indices and its value. 
206type ColNonZeroDoer interface { 207 DoColNonZero(j int, fn func(i, j int, v float64)) 208} 209 210// untranspose untransposes a matrix if applicable. If a is an Untransposer, then 211// untranspose returns the underlying matrix and true. If it is not, then it returns 212// the input matrix and false. 213func untranspose(a Matrix) (Matrix, bool) { 214 if ut, ok := a.(Untransposer); ok { 215 return ut.Untranspose(), true 216 } 217 return a, false 218} 219 220// untransposeExtract returns an untransposed matrix in a built-in matrix type. 221// 222// The untransposed matrix is returned unaltered if it is a built-in matrix type. 223// Otherwise, if it implements a Raw method, an appropriate built-in type value 224// is returned holding the raw matrix value of the input. If neither of these 225// is possible, the untransposed matrix is returned. 226func untransposeExtract(a Matrix) (Matrix, bool) { 227 ut, trans := untranspose(a) 228 switch m := ut.(type) { 229 case *DiagDense, *SymBandDense, *TriBandDense, *BandDense, *TriDense, *SymDense, *Dense: 230 return m, trans 231 // TODO(btracey): Add here if we ever have an equivalent of RawDiagDense. 
232 case RawSymBander: 233 rsb := m.RawSymBand() 234 if rsb.Uplo != blas.Upper { 235 return ut, trans 236 } 237 var sb SymBandDense 238 sb.SetRawSymBand(rsb) 239 return &sb, trans 240 case RawTriBander: 241 rtb := m.RawTriBand() 242 if rtb.Diag == blas.Unit { 243 return ut, trans 244 } 245 var tb TriBandDense 246 tb.SetRawTriBand(rtb) 247 return &tb, trans 248 case RawBander: 249 var b BandDense 250 b.SetRawBand(m.RawBand()) 251 return &b, trans 252 case RawTriangular: 253 rt := m.RawTriangular() 254 if rt.Diag == blas.Unit { 255 return ut, trans 256 } 257 var t TriDense 258 t.SetRawTriangular(rt) 259 return &t, trans 260 case RawSymmetricer: 261 rs := m.RawSymmetric() 262 if rs.Uplo != blas.Upper { 263 return ut, trans 264 } 265 var s SymDense 266 s.SetRawSymmetric(rs) 267 return &s, trans 268 case RawMatrixer: 269 var d Dense 270 d.SetRawMatrix(m.RawMatrix()) 271 return &d, trans 272 default: 273 return ut, trans 274 } 275} 276 277// TODO(btracey): Consider adding CopyCol/CopyRow if the behavior seems useful. 278// TODO(btracey): Add in fast paths to Row/Col for the other concrete types 279// (TriDense, etc.) as well as relevant interfaces (RowColer, RawRowViewer, etc.) 280 281// Col copies the elements in the jth column of the matrix into the slice dst. 282// The length of the provided slice must equal the number of rows, unless the 283// slice is nil in which case a new slice is first allocated. 
284func Col(dst []float64, j int, a Matrix) []float64 { 285 r, c := a.Dims() 286 if j < 0 || j >= c { 287 panic(ErrColAccess) 288 } 289 if dst == nil { 290 dst = make([]float64, r) 291 } else { 292 if len(dst) != r { 293 panic(ErrColLength) 294 } 295 } 296 aU, aTrans := untranspose(a) 297 if rm, ok := aU.(RawMatrixer); ok { 298 m := rm.RawMatrix() 299 if aTrans { 300 copy(dst, m.Data[j*m.Stride:j*m.Stride+m.Cols]) 301 return dst 302 } 303 blas64.Copy(blas64.Vector{N: r, Inc: m.Stride, Data: m.Data[j:]}, 304 blas64.Vector{N: r, Inc: 1, Data: dst}, 305 ) 306 return dst 307 } 308 for i := 0; i < r; i++ { 309 dst[i] = a.At(i, j) 310 } 311 return dst 312} 313 314// Row copies the elements in the ith row of the matrix into the slice dst. 315// The length of the provided slice must equal the number of columns, unless the 316// slice is nil in which case a new slice is first allocated. 317func Row(dst []float64, i int, a Matrix) []float64 { 318 r, c := a.Dims() 319 if i < 0 || i >= r { 320 panic(ErrColAccess) 321 } 322 if dst == nil { 323 dst = make([]float64, c) 324 } else { 325 if len(dst) != c { 326 panic(ErrRowLength) 327 } 328 } 329 aU, aTrans := untranspose(a) 330 if rm, ok := aU.(RawMatrixer); ok { 331 m := rm.RawMatrix() 332 if aTrans { 333 blas64.Copy(blas64.Vector{N: c, Inc: m.Stride, Data: m.Data[i:]}, 334 blas64.Vector{N: c, Inc: 1, Data: dst}, 335 ) 336 return dst 337 } 338 copy(dst, m.Data[i*m.Stride:i*m.Stride+m.Cols]) 339 return dst 340 } 341 for j := 0; j < c; j++ { 342 dst[j] = a.At(i, j) 343 } 344 return dst 345} 346 347// Cond returns the condition number of the given matrix under the given norm. 348// The condition number must be based on the 1-norm, 2-norm or ∞-norm. 349// Cond will panic with matrix.ErrShape if the matrix has zero size. 350// 351// BUG(btracey): The computation of the 1-norm and ∞-norm for non-square matrices 352// is innacurate, although is typically the right order of magnitude. 
See 353// https://github.com/xianyi/OpenBLAS/issues/636. While the value returned will 354// change with the resolution of this bug, the result from Cond will match the 355// condition number used internally. 356func Cond(a Matrix, norm float64) float64 { 357 m, n := a.Dims() 358 if m == 0 || n == 0 { 359 panic(ErrShape) 360 } 361 var lnorm lapack.MatrixNorm 362 switch norm { 363 default: 364 panic("mat: bad norm value") 365 case 1: 366 lnorm = lapack.MaxColumnSum 367 case 2: 368 var svd SVD 369 ok := svd.Factorize(a, SVDNone) 370 if !ok { 371 return math.Inf(1) 372 } 373 return svd.Cond() 374 case math.Inf(1): 375 lnorm = lapack.MaxRowSum 376 } 377 378 if m == n { 379 // Use the LU decomposition to compute the condition number. 380 var lu LU 381 lu.factorize(a, lnorm) 382 return lu.Cond() 383 } 384 if m > n { 385 // Use the QR factorization to compute the condition number. 386 var qr QR 387 qr.factorize(a, lnorm) 388 return qr.Cond() 389 } 390 // Use the LQ factorization to compute the condition number. 391 var lq LQ 392 lq.factorize(a, lnorm) 393 return lq.Cond() 394} 395 396// Det returns the determinant of the matrix a. In many expressions using LogDet 397// will be more numerically stable. 398func Det(a Matrix) float64 { 399 det, sign := LogDet(a) 400 return math.Exp(det) * sign 401} 402 403// Dot returns the sum of the element-wise product of a and b. 404// Dot panics if the matrix sizes are unequal. 405func Dot(a, b Vector) float64 { 406 la := a.Len() 407 lb := b.Len() 408 if la != lb { 409 panic(ErrShape) 410 } 411 if arv, ok := a.(RawVectorer); ok { 412 if brv, ok := b.(RawVectorer); ok { 413 return blas64.Dot(arv.RawVector(), brv.RawVector()) 414 } 415 } 416 var sum float64 417 for i := 0; i < la; i++ { 418 sum += a.At(i, 0) * b.At(i, 0) 419 } 420 return sum 421} 422 423// Equal returns whether the matrices a and b have the same size 424// and are element-wise equal. 
425func Equal(a, b Matrix) bool { 426 ar, ac := a.Dims() 427 br, bc := b.Dims() 428 if ar != br || ac != bc { 429 return false 430 } 431 aU, aTrans := untranspose(a) 432 bU, bTrans := untranspose(b) 433 if rma, ok := aU.(RawMatrixer); ok { 434 if rmb, ok := bU.(RawMatrixer); ok { 435 ra := rma.RawMatrix() 436 rb := rmb.RawMatrix() 437 if aTrans == bTrans { 438 for i := 0; i < ra.Rows; i++ { 439 for j := 0; j < ra.Cols; j++ { 440 if ra.Data[i*ra.Stride+j] != rb.Data[i*rb.Stride+j] { 441 return false 442 } 443 } 444 } 445 return true 446 } 447 for i := 0; i < ra.Rows; i++ { 448 for j := 0; j < ra.Cols; j++ { 449 if ra.Data[i*ra.Stride+j] != rb.Data[j*rb.Stride+i] { 450 return false 451 } 452 } 453 } 454 return true 455 } 456 } 457 if rma, ok := aU.(RawSymmetricer); ok { 458 if rmb, ok := bU.(RawSymmetricer); ok { 459 ra := rma.RawSymmetric() 460 rb := rmb.RawSymmetric() 461 // Symmetric matrices are always upper and equal to their transpose. 462 for i := 0; i < ra.N; i++ { 463 for j := i; j < ra.N; j++ { 464 if ra.Data[i*ra.Stride+j] != rb.Data[i*rb.Stride+j] { 465 return false 466 } 467 } 468 } 469 return true 470 } 471 } 472 if ra, ok := aU.(*VecDense); ok { 473 if rb, ok := bU.(*VecDense); ok { 474 // If the raw vectors are the same length they must either both be 475 // transposed or both not transposed (or have length 1). 476 for i := 0; i < ra.mat.N; i++ { 477 if ra.mat.Data[i*ra.mat.Inc] != rb.mat.Data[i*rb.mat.Inc] { 478 return false 479 } 480 } 481 return true 482 } 483 } 484 for i := 0; i < ar; i++ { 485 for j := 0; j < ac; j++ { 486 if a.At(i, j) != b.At(i, j) { 487 return false 488 } 489 } 490 } 491 return true 492} 493 494// EqualApprox returns whether the matrices a and b have the same size and contain all equal 495// elements with tolerance for element-wise equality specified by epsilon. Matrices 496// with non-equal shapes are not equal. 
497func EqualApprox(a, b Matrix, epsilon float64) bool { 498 ar, ac := a.Dims() 499 br, bc := b.Dims() 500 if ar != br || ac != bc { 501 return false 502 } 503 aU, aTrans := untranspose(a) 504 bU, bTrans := untranspose(b) 505 if rma, ok := aU.(RawMatrixer); ok { 506 if rmb, ok := bU.(RawMatrixer); ok { 507 ra := rma.RawMatrix() 508 rb := rmb.RawMatrix() 509 if aTrans == bTrans { 510 for i := 0; i < ra.Rows; i++ { 511 for j := 0; j < ra.Cols; j++ { 512 if !floats.EqualWithinAbsOrRel(ra.Data[i*ra.Stride+j], rb.Data[i*rb.Stride+j], epsilon, epsilon) { 513 return false 514 } 515 } 516 } 517 return true 518 } 519 for i := 0; i < ra.Rows; i++ { 520 for j := 0; j < ra.Cols; j++ { 521 if !floats.EqualWithinAbsOrRel(ra.Data[i*ra.Stride+j], rb.Data[j*rb.Stride+i], epsilon, epsilon) { 522 return false 523 } 524 } 525 } 526 return true 527 } 528 } 529 if rma, ok := aU.(RawSymmetricer); ok { 530 if rmb, ok := bU.(RawSymmetricer); ok { 531 ra := rma.RawSymmetric() 532 rb := rmb.RawSymmetric() 533 // Symmetric matrices are always upper and equal to their transpose. 534 for i := 0; i < ra.N; i++ { 535 for j := i; j < ra.N; j++ { 536 if !floats.EqualWithinAbsOrRel(ra.Data[i*ra.Stride+j], rb.Data[i*rb.Stride+j], epsilon, epsilon) { 537 return false 538 } 539 } 540 } 541 return true 542 } 543 } 544 if ra, ok := aU.(*VecDense); ok { 545 if rb, ok := bU.(*VecDense); ok { 546 // If the raw vectors are the same length they must either both be 547 // transposed or both not transposed (or have length 1). 
548 for i := 0; i < ra.mat.N; i++ { 549 if !floats.EqualWithinAbsOrRel(ra.mat.Data[i*ra.mat.Inc], rb.mat.Data[i*rb.mat.Inc], epsilon, epsilon) { 550 return false 551 } 552 } 553 return true 554 } 555 } 556 for i := 0; i < ar; i++ { 557 for j := 0; j < ac; j++ { 558 if !floats.EqualWithinAbsOrRel(a.At(i, j), b.At(i, j), epsilon, epsilon) { 559 return false 560 } 561 } 562 } 563 return true 564} 565 566// LogDet returns the log of the determinant and the sign of the determinant 567// for the matrix that has been factorized. Numerical stability in product and 568// division expressions is generally improved by working in log space. 569func LogDet(a Matrix) (det float64, sign float64) { 570 // TODO(btracey): Add specialized routines for TriDense, etc. 571 var lu LU 572 lu.Factorize(a) 573 return lu.LogDet() 574} 575 576// Max returns the largest element value of the matrix A. 577// Max will panic with matrix.ErrShape if the matrix has zero size. 578func Max(a Matrix) float64 { 579 r, c := a.Dims() 580 if r == 0 || c == 0 { 581 panic(ErrShape) 582 } 583 // Max(A) = Max(A^T) 584 aU, _ := untranspose(a) 585 switch m := aU.(type) { 586 case RawMatrixer: 587 rm := m.RawMatrix() 588 max := math.Inf(-1) 589 for i := 0; i < rm.Rows; i++ { 590 for _, v := range rm.Data[i*rm.Stride : i*rm.Stride+rm.Cols] { 591 if v > max { 592 max = v 593 } 594 } 595 } 596 return max 597 case RawTriangular: 598 rm := m.RawTriangular() 599 // The max of a triangular is at least 0 unless the size is 1. 
600 if rm.N == 1 { 601 return rm.Data[0] 602 } 603 max := 0.0 604 if rm.Uplo == blas.Upper { 605 for i := 0; i < rm.N; i++ { 606 for _, v := range rm.Data[i*rm.Stride+i : i*rm.Stride+rm.N] { 607 if v > max { 608 max = v 609 } 610 } 611 } 612 return max 613 } 614 for i := 0; i < rm.N; i++ { 615 for _, v := range rm.Data[i*rm.Stride : i*rm.Stride+i+1] { 616 if v > max { 617 max = v 618 } 619 } 620 } 621 return max 622 case RawSymmetricer: 623 rm := m.RawSymmetric() 624 if rm.Uplo != blas.Upper { 625 panic(badSymTriangle) 626 } 627 max := math.Inf(-1) 628 for i := 0; i < rm.N; i++ { 629 for _, v := range rm.Data[i*rm.Stride+i : i*rm.Stride+rm.N] { 630 if v > max { 631 max = v 632 } 633 } 634 } 635 return max 636 default: 637 r, c := aU.Dims() 638 max := math.Inf(-1) 639 for i := 0; i < r; i++ { 640 for j := 0; j < c; j++ { 641 v := aU.At(i, j) 642 if v > max { 643 max = v 644 } 645 } 646 } 647 return max 648 } 649} 650 651// Min returns the smallest element value of the matrix A. 652// Min will panic with matrix.ErrShape if the matrix has zero size. 653func Min(a Matrix) float64 { 654 r, c := a.Dims() 655 if r == 0 || c == 0 { 656 panic(ErrShape) 657 } 658 // Min(A) = Min(A^T) 659 aU, _ := untranspose(a) 660 switch m := aU.(type) { 661 case RawMatrixer: 662 rm := m.RawMatrix() 663 min := math.Inf(1) 664 for i := 0; i < rm.Rows; i++ { 665 for _, v := range rm.Data[i*rm.Stride : i*rm.Stride+rm.Cols] { 666 if v < min { 667 min = v 668 } 669 } 670 } 671 return min 672 case RawTriangular: 673 rm := m.RawTriangular() 674 // The min of a triangular is at most 0 unless the size is 1. 
675 if rm.N == 1 { 676 return rm.Data[0] 677 } 678 min := 0.0 679 if rm.Uplo == blas.Upper { 680 for i := 0; i < rm.N; i++ { 681 for _, v := range rm.Data[i*rm.Stride+i : i*rm.Stride+rm.N] { 682 if v < min { 683 min = v 684 } 685 } 686 } 687 return min 688 } 689 for i := 0; i < rm.N; i++ { 690 for _, v := range rm.Data[i*rm.Stride : i*rm.Stride+i+1] { 691 if v < min { 692 min = v 693 } 694 } 695 } 696 return min 697 case RawSymmetricer: 698 rm := m.RawSymmetric() 699 if rm.Uplo != blas.Upper { 700 panic(badSymTriangle) 701 } 702 min := math.Inf(1) 703 for i := 0; i < rm.N; i++ { 704 for _, v := range rm.Data[i*rm.Stride+i : i*rm.Stride+rm.N] { 705 if v < min { 706 min = v 707 } 708 } 709 } 710 return min 711 default: 712 r, c := aU.Dims() 713 min := math.Inf(1) 714 for i := 0; i < r; i++ { 715 for j := 0; j < c; j++ { 716 v := aU.At(i, j) 717 if v < min { 718 min = v 719 } 720 } 721 } 722 return min 723 } 724} 725 726// Norm returns the specified (induced) norm of the matrix a. See 727// https://en.wikipedia.org/wiki/Matrix_norm for the definition of an induced norm. 728// 729// Valid norms are: 730// 1 - The maximum absolute column sum 731// 2 - Frobenius norm, the square root of the sum of the squares of the elements. 732// Inf - The maximum absolute row sum. 733// Norm will panic with ErrNormOrder if an illegal norm order is specified and 734// with matrix.ErrShape if the matrix has zero size. 
735func Norm(a Matrix, norm float64) float64 { 736 r, c := a.Dims() 737 if r == 0 || c == 0 { 738 panic(ErrShape) 739 } 740 aU, aTrans := untranspose(a) 741 var work []float64 742 switch rma := aU.(type) { 743 case RawMatrixer: 744 rm := rma.RawMatrix() 745 n := normLapack(norm, aTrans) 746 if n == lapack.MaxColumnSum { 747 work = getFloats(rm.Cols, false) 748 defer putFloats(work) 749 } 750 return lapack64.Lange(n, rm, work) 751 case RawTriangular: 752 rm := rma.RawTriangular() 753 n := normLapack(norm, aTrans) 754 if n == lapack.MaxRowSum || n == lapack.MaxColumnSum { 755 work = getFloats(rm.N, false) 756 defer putFloats(work) 757 } 758 return lapack64.Lantr(n, rm, work) 759 case RawSymmetricer: 760 rm := rma.RawSymmetric() 761 n := normLapack(norm, aTrans) 762 if n == lapack.MaxRowSum || n == lapack.MaxColumnSum { 763 work = getFloats(rm.N, false) 764 defer putFloats(work) 765 } 766 return lapack64.Lansy(n, rm, work) 767 case *VecDense: 768 rv := rma.RawVector() 769 switch norm { 770 default: 771 panic("unreachable") 772 case 1: 773 if aTrans { 774 imax := blas64.Iamax(rv) 775 return math.Abs(rma.At(imax, 0)) 776 } 777 return blas64.Asum(rv) 778 case 2: 779 return blas64.Nrm2(rv) 780 case math.Inf(1): 781 if aTrans { 782 return blas64.Asum(rv) 783 } 784 imax := blas64.Iamax(rv) 785 return math.Abs(rma.At(imax, 0)) 786 } 787 } 788 switch norm { 789 default: 790 panic("unreachable") 791 case 1: 792 var max float64 793 for j := 0; j < c; j++ { 794 var sum float64 795 for i := 0; i < r; i++ { 796 sum += math.Abs(a.At(i, j)) 797 } 798 if sum > max { 799 max = sum 800 } 801 } 802 return max 803 case 2: 804 var sum float64 805 for i := 0; i < r; i++ { 806 for j := 0; j < c; j++ { 807 v := a.At(i, j) 808 sum += v * v 809 } 810 } 811 return math.Sqrt(sum) 812 case math.Inf(1): 813 var max float64 814 for i := 0; i < r; i++ { 815 var sum float64 816 for j := 0; j < c; j++ { 817 sum += math.Abs(a.At(i, j)) 818 } 819 if sum > max { 820 max = sum 821 } 822 } 823 return max 
824 } 825} 826 827// normLapack converts the float64 norm input in Norm to a lapack.MatrixNorm. 828func normLapack(norm float64, aTrans bool) lapack.MatrixNorm { 829 switch norm { 830 case 1: 831 n := lapack.MaxColumnSum 832 if aTrans { 833 n = lapack.MaxRowSum 834 } 835 return n 836 case 2: 837 return lapack.Frobenius 838 case math.Inf(1): 839 n := lapack.MaxRowSum 840 if aTrans { 841 n = lapack.MaxColumnSum 842 } 843 return n 844 default: 845 panic(ErrNormOrder) 846 } 847} 848 849// Sum returns the sum of the elements of the matrix. 850func Sum(a Matrix) float64 { 851 // TODO(btracey): Add a fast path for the other supported matrix types. 852 853 r, c := a.Dims() 854 var sum float64 855 aU, _ := untranspose(a) 856 if rma, ok := aU.(RawMatrixer); ok { 857 rm := rma.RawMatrix() 858 for i := 0; i < rm.Rows; i++ { 859 for _, v := range rm.Data[i*rm.Stride : i*rm.Stride+rm.Cols] { 860 sum += v 861 } 862 } 863 return sum 864 } 865 for i := 0; i < r; i++ { 866 for j := 0; j < c; j++ { 867 sum += a.At(i, j) 868 } 869 } 870 return sum 871} 872 873// A Tracer can compute the trace of the matrix. Trace must panic if the 874// matrix is not square. 875type Tracer interface { 876 Trace() float64 877} 878 879// Trace returns the trace of the matrix. Trace will panic if the 880// matrix is not square. 881func Trace(a Matrix) float64 { 882 m, _ := untransposeExtract(a) 883 if t, ok := m.(Tracer); ok { 884 return t.Trace() 885 } 886 r, c := a.Dims() 887 if r != c { 888 panic(ErrSquare) 889 } 890 var v float64 891 for i := 0; i < r; i++ { 892 v += a.At(i, i) 893 } 894 return v 895} 896 897func min(a, b int) int { 898 if a < b { 899 return a 900 } 901 return b 902} 903 904func max(a, b int) int { 905 if a > b { 906 return a 907 } 908 return b 909} 910 911// use returns a float64 slice with l elements, using f if it 912// has the necessary capacity, otherwise creating a new slice. 
func use(f []float64, l int) []float64 {
	if cap(f) < l {
		// Not enough room in f: return a freshly allocated slice.
		return make([]float64, l)
	}
	return f[:l]
}

// useZeroed is like use, but guarantees that every element of the returned
// slice is zero. When f is reused, its first l elements are cleared in place.
func useZeroed(f []float64, l int) []float64 {
	if cap(f) < l {
		// make already yields zeroed storage.
		return make([]float64, l)
	}
	f = f[:l]
	zero(f)
	return f
}

// zero sets every element of f to zero in place.
func zero(f []float64) {
	for i := 0; i < len(f); i++ {
		f[i] = 0
	}
}

// useInt returns an int slice of length l, reusing the backing array of i
// when its capacity is large enough and allocating a new slice otherwise.
// The contents of a reused slice are not cleared.
func useInt(i []int, l int) []int {
	if cap(i) < l {
		return make([]int, l)
	}
	return i[:l]
}