///////////////////////////////////////////////////////////////////////////////
//                                                                           //
// The Template Matrix/Vector Library for C++ was created by Mike Jarvis     //
// Copyright (C) 1998 - 2016                                                 //
// All rights reserved                                                       //
//                                                                           //
// The project is hosted at https://code.google.com/p/tmv-cpp/               //
// where you can find the current version and current documentation.         //
//                                                                           //
// For concerns or problems with the software, Mike may be contacted at      //
// mike_jarvis17 [at] gmail.                                                 //
//                                                                           //
// This software is licensed under a FreeBSD license. The file               //
// TMV_LICENSE should have been included with this distribution.             //
// If not, you can get a copy from https://code.google.com/p/tmv-cpp/.       //
//                                                                           //
// Essentially, you can use this software however you want provided that     //
// you include the TMV_LICENSE file in any distribution that uses it.        //
//                                                                           //
///////////////////////////////////////////////////////////////////////////////


//---------------------------------------------------------------------------
//
// This file defines the TMV BaseMatrix class.
//
// This base class defines some of the things that all
// matrices need to be able to do, as well as some of the
// arithmetic operations (those that return a Vector).
// This should be used as the base class for generic
// matrices as well as any special ones (e.g. sparse,
// symmetric, etc.)
33 // 34 // 35 36 #ifndef TMV_BaseMatrix_H 37 #define TMV_BaseMatrix_H 38 39 #include "tmv/TMV_Base.h" 40 #include "tmv/TMV_BaseVector.h" 41 #include "tmv/TMV_IOStyle.h" 42 43 namespace tmv { 44 45 template <typename T> 46 class BaseMatrix; 47 48 template <typename T> 49 class GenMatrix; 50 51 template <typename T, int A=0> 52 class ConstMatrixView; 53 54 template <typename T, int A=0> 55 class MatrixView; 56 57 template <typename T, int A=0> 58 class Matrix; 59 60 template <typename T, ptrdiff_t M, ptrdiff_t N, int A=0> 61 class SmallMatrix; 62 63 template <typename T, ptrdiff_t M, ptrdiff_t N> 64 class SmallMatrixComposite; 65 66 template <typename T> 67 class Divider; 68 69 template <typename T> 70 struct AssignableToMatrix 71 { 72 typedef TMV_RealType(T) RT; 73 typedef TMV_ComplexType(T) CT; 74 75 virtual ptrdiff_t colsize() const = 0; 76 virtual ptrdiff_t rowsize() const = 0; ncolsAssignableToMatrix77 inline ptrdiff_t ncols() const 78 { return rowsize(); } nrowsAssignableToMatrix79 inline ptrdiff_t nrows() const 80 { return colsize(); } isSquareAssignableToMatrix81 inline bool isSquare() const 82 { return colsize() == rowsize(); } 83 84 virtual void assignToM(MatrixView<RT> m) const = 0; 85 virtual void assignToM(MatrixView<CT> m) const = 0; 86 ~AssignableToMatrixAssignableToMatrix87 virtual inline ~AssignableToMatrix() {} 88 }; 89 90 template <typename T> 91 class BaseMatrix : virtual public AssignableToMatrix<T> 92 { 93 public : 94 typedef TMV_RealType(T) RT; 95 96 // 97 // Access Functions 98 // 99 100 using AssignableToMatrix<T>::colsize; 101 using AssignableToMatrix<T>::rowsize; 102 103 // 104 // Functions of Matrix 105 // 106 107 virtual T det() const = 0; 108 virtual RT logDet(T* sign=0) const = 0; 109 virtual T trace() const = 0; 110 virtual T sumElements() const = 0; 111 virtual RT sumAbsElements() const = 0; 112 virtual RT sumAbs2Elements() const = 0; 113 114 virtual RT norm() const = 0; 115 virtual RT normSq(const RT scale = RT(1)) const = 0; 116 
virtual RT normF() const = 0; 117 virtual RT norm1() const = 0; 118 virtual RT norm2() const = 0; 119 virtual RT doNorm2() const = 0; 120 virtual RT normInf() const = 0; 121 virtual RT maxAbsElement() const = 0; 122 virtual RT maxAbs2Element() const = 0; 123 124 // 125 // I/O: Write 126 // 127 128 virtual void write(const TMV_Writer& writer) const = 0; 129 ~BaseMatrix()130 virtual inline ~BaseMatrix() {} 131 132 }; // BaseMatrix 133 134 template <typename T> 135 class DivHelper : virtual public AssignableToMatrix<T> 136 { 137 public: 138 139 typedef TMV_RealType(T) RT; 140 141 // 142 // Constructors 143 // 144 145 DivHelper(); 146 // Cannot do this inline, since need to delete pdiv, 147 // and I only define DivImpl in BaseMatrix.cpp. 148 virtual ~DivHelper(); 149 150 using AssignableToMatrix<T>::colsize; 151 using AssignableToMatrix<T>::rowsize; 152 det()153 T det() const 154 { 155 TMVAssert(rowsize() == colsize()); 156 return doDet(); 157 } 158 logDet(T * sign)159 RT logDet(T* sign) const 160 { 161 TMVAssert(rowsize() == colsize()); 162 return doLogDet(sign); 163 } 164 makeInverse(MatrixView<T> minv)165 void makeInverse(MatrixView<T> minv) const 166 { 167 TMVAssert(minv.colsize() == rowsize()); 168 TMVAssert(minv.rowsize() == colsize()); 169 doMakeInverse(minv); 170 } 171 172 template <typename T1> makeInverse(MatrixView<T1> minv)173 inline void makeInverse(MatrixView<T1> minv) const 174 { 175 TMVAssert(minv.colsize() == rowsize()); 176 TMVAssert(minv.rowsize() == colsize()); 177 doMakeInverse(minv); 178 } 179 180 template <typename T1, int A> makeInverse(Matrix<T1,A> & minv)181 inline void makeInverse(Matrix<T1,A>& minv) const 182 { 183 TMVAssert(minv.colsize() == rowsize()); 184 TMVAssert(minv.rowsize() == colsize()); 185 doMakeInverse(minv.view()); 186 } 187 makeInverseATA(MatrixView<T> ata)188 inline void makeInverseATA(MatrixView<T> ata) const 189 { 190 TMVAssert(ata.colsize() == 191 (rowsize() < colsize() ? 
rowsize() : colsize())); 192 TMVAssert(ata.rowsize() == 193 (rowsize() < colsize() ? rowsize() : colsize())); 194 doMakeInverseATA(ata); 195 } 196 197 template <int A> makeInverseATA(Matrix<T,A> & ata)198 inline void makeInverseATA(Matrix<T,A>& ata) const 199 { 200 TMVAssert(ata.colsize() == 201 (rowsize() < colsize() ? rowsize() : colsize())); 202 TMVAssert(ata.rowsize() == 203 (rowsize() < colsize() ? rowsize() : colsize())); 204 doMakeInverseATA(ata.view()); 205 } 206 isSingular()207 inline bool isSingular() const 208 { return doIsSingular(); } 209 norm2()210 inline RT norm2() const 211 { 212 TMVAssert(divIsSet() && getDivType() == SV); 213 return doNorm2(); 214 } 215 condition()216 inline RT condition() const 217 { 218 TMVAssert(divIsSet() && getDivType() == SV); 219 return doCondition(); 220 } 221 222 // m^-1 * v -> v 223 template <typename T1> LDivEq(VectorView<T1> v)224 inline void LDivEq(VectorView<T1> v) const 225 { 226 TMVAssert(colsize() == rowsize()); 227 TMVAssert(colsize() == v.size()); 228 doLDivEq(v); 229 } 230 231 template <typename T1> LDivEq(MatrixView<T1> m)232 inline void LDivEq(MatrixView<T1> m) const 233 { 234 TMVAssert(colsize() == rowsize()); 235 TMVAssert(colsize() == m.colsize()); 236 doLDivEq(m); 237 } 238 239 // v * m^-1 -> v 240 template <typename T1> RDivEq(VectorView<T1> v)241 inline void RDivEq(VectorView<T1> v) const 242 { 243 TMVAssert(colsize() == rowsize()); 244 TMVAssert(colsize() == v.size()); 245 doRDivEq(v); 246 } 247 248 template <typename T1> RDivEq(MatrixView<T1> m)249 inline void RDivEq(MatrixView<T1> m) const 250 { 251 TMVAssert(colsize() == rowsize()); 252 TMVAssert(colsize() == m.rowsize()); 253 doRDivEq(m); 254 } 255 256 // m^-1 * v1 -> v0 257 template <typename T1, typename T0> LDiv(const GenVector<T1> & v1,VectorView<T0> v0)258 inline void LDiv( 259 const GenVector<T1>& v1, VectorView<T0> v0) const 260 { 261 TMVAssert(rowsize() == v0.size()); 262 TMVAssert(colsize() == v1.size()); 263 doLDiv(v1,v0); 264 } 265 266 
template <typename T1, typename T0> LDiv(const GenMatrix<T1> & m1,MatrixView<T0> m0)267 inline void LDiv( 268 const GenMatrix<T1>& m1, MatrixView<T0> m0) const 269 { 270 TMVAssert(rowsize() == m0.colsize()); 271 TMVAssert(colsize() == m1.colsize()); 272 TMVAssert(m1.rowsize() == m0.rowsize()); 273 doLDiv(m1,m0); 274 } 275 276 // v1 * m^-1 -> v0 277 template <typename T1, typename T0> RDiv(const GenVector<T1> & v1,VectorView<T0> v0)278 inline void RDiv( 279 const GenVector<T1>& v1, VectorView<T0> v0) const 280 { 281 TMVAssert(rowsize() == v1.size()); 282 TMVAssert(colsize() == v0.size()); 283 doRDiv(v1,v0); 284 } 285 286 template <typename T1, typename T0> RDiv(const GenMatrix<T1> & m1,MatrixView<T0> m0)287 inline void RDiv( 288 const GenMatrix<T1>& m1, MatrixView<T0> m0) const 289 { 290 TMVAssert(rowsize() == m1.rowsize()); 291 TMVAssert(colsize() == m0.rowsize()); 292 TMVAssert(m1.colsize() == m0.colsize()); 293 doRDiv(m1,m0); 294 } 295 296 // 297 // Division Control 298 // 299 300 void divideUsing(DivType dt) const; 301 302 void divideInPlace() const; 303 void dontDivideInPlace() const; 304 void saveDiv() const; 305 void dontSaveDiv() const; 306 307 // setDiv is defined in the derived class. 
308 virtual void setDiv() const = 0; 309 void unsetDiv() const; 310 void resetDiv() const; 311 312 DivType getDivType() const; 313 bool divIsInPlace() const; 314 bool divIsSaved() const; 315 bool divIsSet() const; 316 317 bool checkDecomp(std::ostream* fout=0) const; 318 bool checkDecomp(const BaseMatrix<T>& m2, std::ostream* fout=0) const; 319 320 protected : 321 322 void doneDiv() const; 323 const Divider<T>* getDiv() const; 324 void resetDivType() const; 325 326 // Two more that need to be defined in the derived class: 327 virtual const BaseMatrix<T>& getMatrix() const = 0; 328 329 mutable auto_ptr<Divider<T> > divider; 330 mutable DivType divtype; 331 332 private : 333 334 DivHelper(const DivHelper<T>&); 335 DivHelper<T>& operator=(const DivHelper<T>&); 336 337 T doDet() const; 338 RT doLogDet(T* sign) const; 339 template <typename T1> 340 void doMakeInverse(MatrixView<T1> minv) const; 341 void doMakeInverseATA(MatrixView<T> minv) const; 342 bool doIsSingular() const; 343 RT doNorm2() const; 344 RT doCondition() const; 345 template <typename T1> 346 void doLDivEq(VectorView<T1> v) const; 347 template <typename T1> 348 void doLDivEq(MatrixView<T1> m) const; 349 template <typename T1> 350 void doRDivEq(VectorView<T1> v) const; 351 template <typename T1> 352 void doRDivEq(MatrixView<T1> m) const; 353 template <typename T1, typename T0> 354 void doLDiv( 355 const GenVector<T1>& v1, VectorView<T0> v0) const; 356 template <typename T1, typename T0> 357 void doLDiv( 358 const GenMatrix<T1>& m1, MatrixView<T0> m0) const; 359 template <typename T1, typename T0> 360 void doRDiv( 361 const GenVector<T1>& v1, VectorView<T0> v0) const; 362 template <typename T1, typename T0> 363 void doRDiv( 364 const GenMatrix<T1>& m1, MatrixView<T0> m0) const; 365 366 }; // DivHelper 367 368 // 369 // Functions of Matrices: 370 // 371 372 template <typename T> Det(const BaseMatrix<T> & m)373 inline T Det(const BaseMatrix<T>& m) 374 { return m.det(); } 375 376 template <typename T> 
TMV_RealType(T)377 inline TMV_RealType(T) LogDet(const BaseMatrix<T>& m) 378 { return m.logDet(); } 379 380 template <typename T> Trace(const BaseMatrix<T> & m)381 inline T Trace(const BaseMatrix<T>& m) 382 { return m.trace(); } 383 384 template <typename T> SumElements(const BaseMatrix<T> & m)385 inline T SumElements(const BaseMatrix<T>& m) 386 { return m.sumElements(); } 387 388 template <typename T> TMV_RealType(T)389 inline TMV_RealType(T) SumAbsElements(const BaseMatrix<T>& m) 390 { return m.sumAbsElements(); } 391 392 template <typename T> TMV_RealType(T)393 inline TMV_RealType(T) SumAbs2Elements(const BaseMatrix<T>& m) 394 { return m.sumAbs2Elements(); } 395 396 template <typename T> TMV_RealType(T)397 inline TMV_RealType(T) Norm(const BaseMatrix<T>& m) 398 { return m.norm(); } 399 400 template <typename T> TMV_RealType(T)401 inline TMV_RealType(T) NormSq(const BaseMatrix<T>& m) 402 { return m.normSq(); } 403 404 template <typename T> TMV_RealType(T)405 inline TMV_RealType(T) NormF(const BaseMatrix<T>& m) 406 { return m.normF(); } 407 408 template <typename T> TMV_RealType(T)409 inline TMV_RealType(T) Norm1(const BaseMatrix<T>& m) 410 { return m.norm1(); } 411 412 template <typename T> TMV_RealType(T)413 inline TMV_RealType(T) Norm2(const BaseMatrix<T>& m) 414 { return m.norm2(); } 415 416 template <typename T> TMV_RealType(T)417 inline TMV_RealType(T) NormInf(const BaseMatrix<T>& m) 418 { return m.normInf(); } 419 420 template <typename T> TMV_RealType(T)421 inline TMV_RealType(T) MaxAbsElement(const BaseMatrix<T>& m) 422 { return m.maxAbsElement(); } 423 424 template <typename T> TMV_RealType(T)425 inline TMV_RealType(T) MaxAbs2Element(const BaseMatrix<T>& m) 426 { return m.maxAbs2Element(); } 427 428 429 // 430 // I/O 431 // 432 433 template <typename T> 434 inline std::ostream& operator<<( 435 const TMV_Writer& writer, const BaseMatrix<T>& m) 436 { m.write(writer); return writer.getos(); } 437 438 template <typename T> 439 inline std::ostream& 
operator<<( 440 std::ostream& os, const BaseMatrix<T>& m) 441 { return os << IOStyle() << m; } 442 443 444 template <typename T, int A> TMV_Text(const Matrix<T,A> &)445 inline std::string TMV_Text(const Matrix<T,A>& ) 446 { 447 return std::string("Matrix<") + 448 TMV_Text(T()) + "," + Attrib<A>::text() + ">"; 449 } 450 template <typename T> TMV_Text(const GenMatrix<T> &)451 inline std::string TMV_Text(const GenMatrix<T>& ) 452 { 453 return std::string("GenMatrix<") + TMV_Text(T()) + ">"; 454 } 455 template <typename T, int A> TMV_Text(const ConstMatrixView<T,A> &)456 inline std::string TMV_Text(const ConstMatrixView<T,A>& ) 457 { 458 return std::string("ConstMatrixView<") + 459 TMV_Text(T()) + "," + Attrib<A>::text() + ">"; 460 } 461 template <typename T, int A> TMV_Text(const MatrixView<T,A> &)462 inline std::string TMV_Text(const MatrixView<T,A>& ) 463 { 464 return std::string("MatrixView<") + 465 TMV_Text(T()) + "," + Attrib<A>::text() + ">"; 466 } 467 468 } // namespace tmv 469 470 #endif 471