1// Copyright ©2013 The Gonum Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5package mat 6 7import ( 8 "math" 9 10 "gonum.org/v1/gonum/blas" 11 "gonum.org/v1/gonum/blas/blas64" 12 "gonum.org/v1/gonum/floats" 13 "gonum.org/v1/gonum/lapack" 14 "gonum.org/v1/gonum/lapack/lapack64" 15) 16 17const ( 18 badSliceLength = "mat: improper slice length" 19 badLU = "mat: invalid LU factorization" 20) 21 22// LU is a type for creating and using the LU factorization of a matrix. 23type LU struct { 24 lu *Dense 25 pivot []int 26 cond float64 27} 28 29// updateCond updates the stored condition number of the matrix. anorm is the 30// norm of the original matrix. If anorm is negative it will be estimated. 31func (lu *LU) updateCond(anorm float64, norm lapack.MatrixNorm) { 32 n := lu.lu.mat.Cols 33 work := getFloats(4*n, false) 34 defer putFloats(work) 35 iwork := getInts(n, false) 36 defer putInts(iwork) 37 if anorm < 0 { 38 // This is an approximation. By the definition of a norm, 39 // |AB| <= |A| |B|. 40 // Since A = L*U, we get for the condition number κ that 41 // κ(A) := |A| |A^-1| = |L*U| |A^-1| <= |L| |U| |A^-1|, 42 // so this will overestimate the condition number somewhat. 43 // The norm of the original factorized matrix cannot be stored 44 // because of update possibilities. 45 u := lu.lu.asTriDense(n, blas.NonUnit, blas.Upper) 46 l := lu.lu.asTriDense(n, blas.Unit, blas.Lower) 47 unorm := lapack64.Lantr(norm, u.mat, work) 48 lnorm := lapack64.Lantr(norm, l.mat, work) 49 anorm = unorm * lnorm 50 } 51 v := lapack64.Gecon(norm, lu.lu.mat, anorm, work, iwork) 52 lu.cond = 1 / v 53} 54 55// Factorize computes the LU factorization of the square matrix a and stores the 56// result. The LU decomposition will complete regardless of the singularity of a. 57// 58// The LU factorization is computed with pivoting, and so really the decomposition 59// is a PLU decomposition where P is a permutation matrix. The individual matrix 60// factors can be extracted from the factorization using the Permutation method 61// on Dense, and the LU LTo and UTo methods. 62func (lu *LU) Factorize(a Matrix) { 63 lu.factorize(a, CondNorm) 64} 65 66func (lu *LU) factorize(a Matrix, norm lapack.MatrixNorm) { 67 r, c := a.Dims() 68 if r != c { 69 panic(ErrSquare) 70 } 71 if lu.lu == nil { 72 lu.lu = NewDense(r, r, nil) 73 } else { 74 lu.lu.Reset() 75 lu.lu.reuseAs(r, r) 76 } 77 lu.lu.Copy(a) 78 if cap(lu.pivot) < r { 79 lu.pivot = make([]int, r) 80 } 81 lu.pivot = lu.pivot[:r] 82 work := getFloats(r, false) 83 anorm := lapack64.Lange(norm, lu.lu.mat, work) 84 putFloats(work) 85 lapack64.Getrf(lu.lu.mat, lu.pivot) 86 lu.updateCond(anorm, norm) 87} 88 89// isValid returns whether the receiver contains a factorization. 90func (lu *LU) isValid() bool { 91 return lu.lu != nil && !lu.lu.IsZero() 92} 93 94// Cond returns the condition number for the factorized matrix. 95// Cond will panic if the receiver does not contain a factorization. 96func (lu *LU) Cond() float64 { 97 if !lu.isValid() { 98 panic(badLU) 99 } 100 return lu.cond 101} 102 103// Reset resets the factorization so that it can be reused as the receiver of a 104// dimensionally restricted operation. 105func (lu *LU) Reset() { 106 if lu.lu != nil { 107 lu.lu.Reset() 108 } 109 lu.pivot = lu.pivot[:0] 110} 111 112func (lu *LU) isZero() bool { 113 return len(lu.pivot) == 0 114} 115 116// Det returns the determinant of the matrix that has been factorized. In many 117// expressions, using LogDet will be more numerically stable. 118// Det will panic if the receiver does not contain a factorization. 119func (lu *LU) Det() float64 { 120 det, sign := lu.LogDet() 121 return math.Exp(det) * sign 122} 123 124// LogDet returns the log of the determinant and the sign of the determinant 125// for the matrix that has been factorized. Numerical stability in product and 126// division expressions is generally improved by working in log space. 127// LogDet will panic if the receiver does not contain a factorization. 128func (lu *LU) LogDet() (det float64, sign float64) { 129 if !lu.isValid() { 130 panic(badLU) 131 } 132 133 _, n := lu.lu.Dims() 134 logDiag := getFloats(n, false) 135 defer putFloats(logDiag) 136 sign = 1.0 137 for i := 0; i < n; i++ { 138 v := lu.lu.at(i, i) 139 if v < 0 { 140 sign *= -1 141 } 142 if lu.pivot[i] != i { 143 sign *= -1 144 } 145 logDiag[i] = math.Log(math.Abs(v)) 146 } 147 return floats.Sum(logDiag), sign 148} 149 150// Pivot returns pivot indices that enable the construction of the permutation 151// matrix P (see Dense.Permutation). If swaps == nil, then new memory will be 152// allocated, otherwise the length of the input must be equal to the size of the 153// factorized matrix. 154// Pivot will panic if the receiver does not contain a factorization. 155func (lu *LU) Pivot(swaps []int) []int { 156 if !lu.isValid() { 157 panic(badLU) 158 } 159 160 _, n := lu.lu.Dims() 161 if swaps == nil { 162 swaps = make([]int, n) 163 } 164 if len(swaps) != n { 165 panic(badSliceLength) 166 } 167 // Perform the inverse of the row swaps in order to find the final 168 // row swap position. 169 for i := range swaps { 170 swaps[i] = i 171 } 172 for i := n - 1; i >= 0; i-- { 173 v := lu.pivot[i] 174 swaps[i], swaps[v] = swaps[v], swaps[i] 175 } 176 return swaps 177} 178 179// RankOne updates an LU factorization as if a rank-one update had been applied to 180// the original matrix A, storing the result into the receiver. That is, if in 181// the original LU decomposition P * L * U = A, in the updated decomposition 182// P * L * U = A + alpha * x * y^T. 183// RankOne will panic if orig does not contain a factorization. 184func (lu *LU) RankOne(orig *LU, alpha float64, x, y Vector) { 185 if !orig.isValid() { 186 panic(badLU) 187 } 188 189 // RankOne uses algorithm a1 on page 28 of "Multiple-Rank Updates to Matrix 190 // Factorizations for Nonlinear Analysis and Circuit Design" by Linzhong Deng. 191 // http://web.stanford.edu/group/SOL/dissertations/Linzhong-Deng-thesis.pdf 192 _, n := orig.lu.Dims() 193 if r, c := x.Dims(); r != n || c != 1 { 194 panic(ErrShape) 195 } 196 if r, c := y.Dims(); r != n || c != 1 { 197 panic(ErrShape) 198 } 199 if orig != lu { 200 if lu.isZero() { 201 if cap(lu.pivot) < n { 202 lu.pivot = make([]int, n) 203 } 204 lu.pivot = lu.pivot[:n] 205 if lu.lu == nil { 206 lu.lu = NewDense(n, n, nil) 207 } else { 208 lu.lu.reuseAs(n, n) 209 } 210 } else if len(lu.pivot) != n { 211 panic(ErrShape) 212 } 213 copy(lu.pivot, orig.pivot) 214 lu.lu.Copy(orig.lu) 215 } 216 217 xs := getFloats(n, false) 218 defer putFloats(xs) 219 ys := getFloats(n, false) 220 defer putFloats(ys) 221 for i := 0; i < n; i++ { 222 xs[i] = x.AtVec(i) 223 ys[i] = y.AtVec(i) 224 } 225 226 // Adjust for the pivoting in the LU factorization 227 for i, v := range lu.pivot { 228 xs[i], xs[v] = xs[v], xs[i] 229 } 230 231 lum := lu.lu.mat 232 omega := alpha 233 for j := 0; j < n; j++ { 234 ujj := lum.Data[j*lum.Stride+j] 235 ys[j] /= ujj 236 theta := 1 + xs[j]*ys[j]*omega 237 beta := omega * ys[j] / theta 238 gamma := omega * xs[j] 239 omega -= beta * gamma 240 lum.Data[j*lum.Stride+j] *= theta 241 for i := j + 1; i < n; i++ { 242 xs[i] -= lum.Data[i*lum.Stride+j] * xs[j] 243 tmp := ys[i] 244 ys[i] -= lum.Data[j*lum.Stride+i] * ys[j] 245 lum.Data[i*lum.Stride+j] += beta * xs[i] 246 lum.Data[j*lum.Stride+i] += gamma * tmp 247 } 248 } 249 lu.updateCond(-1, CondNorm) 250} 251 252// LTo extracts the lower triangular matrix from an LU factorization. 253// If dst is nil, a new matrix is allocated. The resulting L matrix is returned. 254// LTo will panic if the receiver does not contain a factorization. 255func (lu *LU) LTo(dst *TriDense) *TriDense { 256 if !lu.isValid() { 257 panic(badLU) 258 } 259 260 _, n := lu.lu.Dims() 261 if dst == nil { 262 dst = NewTriDense(n, Lower, nil) 263 } else { 264 dst.reuseAs(n, Lower) 265 } 266 // Extract the lower triangular elements. 267 for i := 0; i < n; i++ { 268 for j := 0; j < i; j++ { 269 dst.mat.Data[i*dst.mat.Stride+j] = lu.lu.mat.Data[i*lu.lu.mat.Stride+j] 270 } 271 } 272 // Set ones on the diagonal. 273 for i := 0; i < n; i++ { 274 dst.mat.Data[i*dst.mat.Stride+i] = 1 275 } 276 return dst 277} 278 279// UTo extracts the upper triangular matrix from an LU factorization. 280// If dst is nil, a new matrix is allocated. The resulting U matrix is returned. 281// UTo will panic if the receiver does not contain a factorization. 282func (lu *LU) UTo(dst *TriDense) *TriDense { 283 if !lu.isValid() { 284 panic(badLU) 285 } 286 287 _, n := lu.lu.Dims() 288 if dst == nil { 289 dst = NewTriDense(n, Upper, nil) 290 } else { 291 dst.reuseAs(n, Upper) 292 } 293 // Extract the upper triangular elements. 294 for i := 0; i < n; i++ { 295 for j := i; j < n; j++ { 296 dst.mat.Data[i*dst.mat.Stride+j] = lu.lu.mat.Data[i*lu.lu.mat.Stride+j] 297 } 298 } 299 return dst 300} 301 302// Permutation constructs an r×r permutation matrix with the given row swaps. 303// A permutation matrix has exactly one element equal to one in each row and column 304// and all other elements equal to zero. swaps[i] specifies the row with which 305// i will be swapped, which is equivalent to the non-zero column of row i. 306func (m *Dense) Permutation(r int, swaps []int) { 307 m.reuseAs(r, r) 308 for i := 0; i < r; i++ { 309 zero(m.mat.Data[i*m.mat.Stride : i*m.mat.Stride+r]) 310 v := swaps[i] 311 if v < 0 || v >= r { 312 panic(ErrRowAccess) 313 } 314 m.mat.Data[i*m.mat.Stride+v] = 1 315 } 316} 317 318// SolveTo solves a system of linear equations using the LU decomposition of a matrix. 319// It computes 320// A * X = B if trans == false 321// A^T * X = B if trans == true 322// In both cases, A is represented in LU factorized form, and the matrix X is 323// stored into dst. 324// 325// If A is singular or near-singular a Condition error is returned. See 326// the documentation for Condition for more information. 327// SolveTo will panic if the receiver does not contain a factorization. 328func (lu *LU) SolveTo(dst *Dense, trans bool, b Matrix) error { 329 if !lu.isValid() { 330 panic(badLU) 331 } 332 333 _, n := lu.lu.Dims() 334 br, bc := b.Dims() 335 if br != n { 336 panic(ErrShape) 337 } 338 // TODO(btracey): Should test the condition number instead of testing that 339 // the determinant is exactly zero. 340 if lu.Det() == 0 { 341 return Condition(math.Inf(1)) 342 } 343 344 dst.reuseAs(n, bc) 345 bU, _ := untranspose(b) 346 var restore func() 347 if dst == bU { 348 dst, restore = dst.isolatedWorkspace(bU) 349 defer restore() 350 } else if rm, ok := bU.(RawMatrixer); ok { 351 dst.checkOverlap(rm.RawMatrix()) 352 } 353 354 dst.Copy(b) 355 t := blas.NoTrans 356 if trans { 357 t = blas.Trans 358 } 359 lapack64.Getrs(t, lu.lu.mat, dst.mat, lu.pivot) 360 if lu.cond > ConditionTolerance { 361 return Condition(lu.cond) 362 } 363 return nil 364} 365 366// SolveVecTo solves a system of linear equations using the LU decomposition of a matrix. 367// It computes 368// A * x = b if trans == false 369// A^T * x = b if trans == true 370// In both cases, A is represented in LU factorized form, and the vector x is 371// stored into dst. 372// 373// If A is singular or near-singular a Condition error is returned. See 374// the documentation for Condition for more information. 375// SolveVecTo will panic if the receiver does not contain a factorization. 376func (lu *LU) SolveVecTo(dst *VecDense, trans bool, b Vector) error { 377 if !lu.isValid() { 378 panic(badLU) 379 } 380 381 _, n := lu.lu.Dims() 382 if br, bc := b.Dims(); br != n || bc != 1 { 383 panic(ErrShape) 384 } 385 switch rv := b.(type) { 386 default: 387 dst.reuseAs(n) 388 return lu.SolveTo(dst.asDense(), trans, b) 389 case RawVectorer: 390 if dst != b { 391 dst.checkOverlap(rv.RawVector()) 392 } 393 // TODO(btracey): Should test the condition number instead of testing that 394 // the determinant is exactly zero. 395 if lu.Det() == 0 { 396 return Condition(math.Inf(1)) 397 } 398 399 dst.reuseAs(n) 400 var restore func() 401 if dst == b { 402 dst, restore = dst.isolatedWorkspace(b) 403 defer restore() 404 } 405 dst.CopyVec(b) 406 vMat := blas64.General{ 407 Rows: n, 408 Cols: 1, 409 Stride: dst.mat.Inc, 410 Data: dst.mat.Data, 411 } 412 t := blas.NoTrans 413 if trans { 414 t = blas.Trans 415 } 416 lapack64.Getrs(t, lu.lu.mat, vMat, lu.pivot) 417 if lu.cond > ConditionTolerance { 418 return Condition(lu.cond) 419 } 420 return nil 421 } 422} 423