1// Copyright ©2013 The Gonum Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5package mat 6 7import ( 8 "math" 9 10 "gonum.org/v1/gonum/blas" 11 "gonum.org/v1/gonum/blas/blas64" 12 "gonum.org/v1/gonum/lapack" 13 "gonum.org/v1/gonum/lapack/lapack64" 14) 15 16const badQR = "mat: invalid QR factorization" 17 18// QR is a type for creating and using the QR factorization of a matrix. 19type QR struct { 20 qr *Dense 21 tau []float64 22 cond float64 23} 24 25func (qr *QR) updateCond(norm lapack.MatrixNorm) { 26 // Since A = Q*R, and Q is orthogonal, we get for the condition number κ 27 // κ(A) := |A| |A^-1| = |Q*R| |(Q*R)^-1| = |R| |R^-1 * Q^T| 28 // = |R| |R^-1| = κ(R), 29 // where we used that fact that Q^-1 = Q^T. However, this assumes that 30 // the matrix norm is invariant under orthogonal transformations which 31 // is not the case for CondNorm. Hopefully the error is negligible: κ 32 // is only a qualitative measure anyway. 33 n := qr.qr.mat.Cols 34 work := getFloats(3*n, false) 35 iwork := getInts(n, false) 36 r := qr.qr.asTriDense(n, blas.NonUnit, blas.Upper) 37 v := lapack64.Trcon(norm, r.mat, work, iwork) 38 putFloats(work) 39 putInts(iwork) 40 qr.cond = 1 / v 41} 42 43// Factorize computes the QR factorization of an m×n matrix a where m >= n. The QR 44// factorization always exists even if A is singular. 45// 46// The QR decomposition is a factorization of the matrix A such that A = Q * R. 47// The matrix Q is an orthonormal m×m matrix, and R is an m×n upper triangular matrix. 48// Q and R can be extracted using the QTo and RTo methods. 49func (qr *QR) Factorize(a Matrix) { 50 qr.factorize(a, CondNorm) 51} 52 53func (qr *QR) factorize(a Matrix, norm lapack.MatrixNorm) { 54 m, n := a.Dims() 55 if m < n { 56 panic(ErrShape) 57 } 58 k := min(m, n) 59 if qr.qr == nil { 60 qr.qr = &Dense{} 61 } 62 qr.qr.Clone(a) 63 work := []float64{0} 64 qr.tau = make([]float64, k) 65 lapack64.Geqrf(qr.qr.mat, qr.tau, work, -1) 66 67 work = getFloats(int(work[0]), false) 68 lapack64.Geqrf(qr.qr.mat, qr.tau, work, len(work)) 69 putFloats(work) 70 qr.updateCond(norm) 71} 72 73// isValid returns whether the receiver contains a factorization. 74func (qr *QR) isValid() bool { 75 return qr.qr != nil && !qr.qr.IsZero() 76} 77 78// Cond returns the condition number for the factorized matrix. 79// Cond will panic if the receiver does not contain a factorization. 80func (qr *QR) Cond() float64 { 81 if !qr.isValid() { 82 panic(badQR) 83 } 84 return qr.cond 85} 86 87// TODO(btracey): Add in the "Reduced" forms for extracting the n×n orthogonal 88// and upper triangular matrices. 89 90// RTo extracts the m×n upper trapezoidal matrix from a QR decomposition. 91// If dst is nil, a new matrix is allocated. The resulting dst matrix is returned. 92// RTo will panic if the receiver does not contain a factorization. 93func (qr *QR) RTo(dst *Dense) *Dense { 94 if !qr.isValid() { 95 panic(badQR) 96 } 97 98 r, c := qr.qr.Dims() 99 if dst == nil { 100 dst = NewDense(r, c, nil) 101 } else { 102 dst.reuseAs(r, c) 103 } 104 105 // Disguise the QR as an upper triangular 106 t := &TriDense{ 107 mat: blas64.Triangular{ 108 N: c, 109 Stride: qr.qr.mat.Stride, 110 Data: qr.qr.mat.Data, 111 Uplo: blas.Upper, 112 Diag: blas.NonUnit, 113 }, 114 cap: qr.qr.capCols, 115 } 116 dst.Copy(t) 117 118 // Zero below the triangular. 119 for i := r; i < c; i++ { 120 zero(dst.mat.Data[i*dst.mat.Stride : i*dst.mat.Stride+c]) 121 } 122 123 return dst 124} 125 126// QTo extracts the m×m orthonormal matrix Q from a QR decomposition. 127// If dst is nil, a new matrix is allocated. The resulting Q matrix is returned. 128// QTo will panic if the receiver does not contain a factorization. 129func (qr *QR) QTo(dst *Dense) *Dense { 130 if !qr.isValid() { 131 panic(badQR) 132 } 133 134 r, _ := qr.qr.Dims() 135 if dst == nil { 136 dst = NewDense(r, r, nil) 137 } else { 138 dst.reuseAsZeroed(r, r) 139 } 140 141 // Set Q = I. 142 for i := 0; i < r*r; i += r + 1 { 143 dst.mat.Data[i] = 1 144 } 145 146 // Construct Q from the elementary reflectors. 147 work := []float64{0} 148 lapack64.Ormqr(blas.Left, blas.NoTrans, qr.qr.mat, qr.tau, dst.mat, work, -1) 149 work = getFloats(int(work[0]), false) 150 lapack64.Ormqr(blas.Left, blas.NoTrans, qr.qr.mat, qr.tau, dst.mat, work, len(work)) 151 putFloats(work) 152 153 return dst 154} 155 156// SolveTo finds a minimum-norm solution to a system of linear equations defined 157// by the matrices A and b, where A is an m×n matrix represented in its QR factorized 158// form. If A is singular or near-singular a Condition error is returned. 159// See the documentation for Condition for more information. 160// 161// The minimization problem solved depends on the input parameters. 162// If trans == false, find X such that ||A*X - B||_2 is minimized. 163// If trans == true, find the minimum norm solution of A^T * X = B. 164// The solution matrix, X, is stored in place into dst. 165// SolveTo will panic if the receiver does not contain a factorization. 166func (qr *QR) SolveTo(dst *Dense, trans bool, b Matrix) error { 167 if !qr.isValid() { 168 panic(badQR) 169 } 170 171 r, c := qr.qr.Dims() 172 br, bc := b.Dims() 173 174 // The QR solve algorithm stores the result in-place into the right hand side. 175 // The storage for the answer must be large enough to hold both b and x. 176 // However, this method's receiver must be the size of x. Copy b, and then 177 // copy the result into m at the end. 178 if trans { 179 if c != br { 180 panic(ErrShape) 181 } 182 dst.reuseAs(r, bc) 183 } else { 184 if r != br { 185 panic(ErrShape) 186 } 187 dst.reuseAs(c, bc) 188 } 189 // Do not need to worry about overlap between m and b because x has its own 190 // independent storage. 191 w := getWorkspace(max(r, c), bc, false) 192 w.Copy(b) 193 t := qr.qr.asTriDense(qr.qr.mat.Cols, blas.NonUnit, blas.Upper).mat 194 if trans { 195 ok := lapack64.Trtrs(blas.Trans, t, w.mat) 196 if !ok { 197 return Condition(math.Inf(1)) 198 } 199 for i := c; i < r; i++ { 200 zero(w.mat.Data[i*w.mat.Stride : i*w.mat.Stride+bc]) 201 } 202 work := []float64{0} 203 lapack64.Ormqr(blas.Left, blas.NoTrans, qr.qr.mat, qr.tau, w.mat, work, -1) 204 work = getFloats(int(work[0]), false) 205 lapack64.Ormqr(blas.Left, blas.NoTrans, qr.qr.mat, qr.tau, w.mat, work, len(work)) 206 putFloats(work) 207 } else { 208 work := []float64{0} 209 lapack64.Ormqr(blas.Left, blas.Trans, qr.qr.mat, qr.tau, w.mat, work, -1) 210 work = getFloats(int(work[0]), false) 211 lapack64.Ormqr(blas.Left, blas.Trans, qr.qr.mat, qr.tau, w.mat, work, len(work)) 212 putFloats(work) 213 214 ok := lapack64.Trtrs(blas.NoTrans, t, w.mat) 215 if !ok { 216 return Condition(math.Inf(1)) 217 } 218 } 219 // X was set above to be the correct size for the result. 220 dst.Copy(w) 221 putWorkspace(w) 222 if qr.cond > ConditionTolerance { 223 return Condition(qr.cond) 224 } 225 return nil 226} 227 228// SolveVecTo finds a minimum-norm solution to a system of linear equations, 229// Ax = b. 230// See QR.SolveTo for the full documentation. 231// SolveVecTo will panic if the receiver does not contain a factorization. 232func (qr *QR) SolveVecTo(dst *VecDense, trans bool, b Vector) error { 233 if !qr.isValid() { 234 panic(badQR) 235 } 236 237 r, c := qr.qr.Dims() 238 if _, bc := b.Dims(); bc != 1 { 239 panic(ErrShape) 240 } 241 242 // The Solve implementation is non-trivial, so rather than duplicate the code, 243 // instead recast the VecDenses as Dense and call the matrix code. 244 bm := Matrix(b) 245 if rv, ok := b.(RawVectorer); ok { 246 bmat := rv.RawVector() 247 if dst != b { 248 dst.checkOverlap(bmat) 249 } 250 b := VecDense{mat: bmat} 251 bm = b.asDense() 252 } 253 if trans { 254 dst.reuseAs(r) 255 } else { 256 dst.reuseAs(c) 257 } 258 return qr.SolveTo(dst.asDense(), trans, bm) 259 260} 261