1 ///////////////////////////////////////////////////////////////////////////////
2 //                                                                           //
3 // The Template Matrix/Vector Library for C++ was created by Mike Jarvis     //
4 // Copyright (C) 1998 - 2016                                                 //
5 // All rights reserved                                                       //
6 //                                                                           //
7 // The project is hosted at https://code.google.com/p/tmv-cpp/               //
// where you can find the current version and current documentation.         //
9 //                                                                           //
10 // For concerns or problems with the software, Mike may be contacted at      //
11 // mike_jarvis17 [at] gmail.                                                 //
12 //                                                                           //
13 // This software is licensed under a FreeBSD license.  The file              //
// TMV_LICENSE should have been included with this distribution.             //
// If not, you can get a copy from https://code.google.com/p/tmv-cpp/.       //
16 //                                                                           //
17 // Essentially, you can use this software however you want provided that     //
18 // you include the TMV_LICENSE file in any distribution that uses it.        //
19 //                                                                           //
20 ///////////////////////////////////////////////////////////////////////////////
21 
22 
23 //---------------------------------------------------------------------------
24 //
25 // This file defines the TMV BaseMatrix class.
26 //
27 // This base class defines some of the things that all
28 // matrices need to be able to do, as well as some of the
29 // arithmetic operations (those that return a Vector).
30 // This should be used as the base class for generic
// matrices as well as any special ones (e.g. sparse,
32 // symmetric, etc.)
33 //
34 //
35 
36 #ifndef TMV_BaseMatrix_H
37 #define TMV_BaseMatrix_H
38 
39 #include "tmv/TMV_Base.h"
40 #include "tmv/TMV_BaseVector.h"
41 #include "tmv/TMV_IOStyle.h"
42 
43 namespace tmv {
44 
45     template <typename T>
46     class BaseMatrix;
47 
48     template <typename T>
49     class GenMatrix;
50 
51     template <typename T, int A=0>
52     class ConstMatrixView;
53 
54     template <typename T, int A=0>
55     class MatrixView;
56 
57     template <typename T, int A=0>
58     class Matrix;
59 
60     template <typename T, ptrdiff_t M, ptrdiff_t N, int A=0>
61     class SmallMatrix;
62 
63     template <typename T, ptrdiff_t M, ptrdiff_t N>
64     class SmallMatrixComposite;
65 
66     template <typename T>
67     class Divider;
68 
69     template <typename T>
70     struct AssignableToMatrix
71     {
72         typedef TMV_RealType(T) RT;
73         typedef TMV_ComplexType(T) CT;
74 
75         virtual ptrdiff_t colsize() const = 0;
76         virtual ptrdiff_t rowsize() const = 0;
ncolsAssignableToMatrix77         inline ptrdiff_t ncols() const
78         { return rowsize(); }
nrowsAssignableToMatrix79         inline ptrdiff_t nrows() const
80         { return colsize(); }
isSquareAssignableToMatrix81         inline bool isSquare() const
82         { return colsize() == rowsize(); }
83 
84         virtual void assignToM(MatrixView<RT> m) const = 0;
85         virtual void assignToM(MatrixView<CT> m) const = 0;
86 
~AssignableToMatrixAssignableToMatrix87         virtual inline ~AssignableToMatrix() {}
88     };
89 
90     template <typename T>
91     class BaseMatrix : virtual public AssignableToMatrix<T>
92     {
93     public :
94         typedef TMV_RealType(T) RT;
95 
96         //
97         // Access Functions
98         //
99 
100         using AssignableToMatrix<T>::colsize;
101         using AssignableToMatrix<T>::rowsize;
102 
103         //
104         // Functions of Matrix
105         //
106 
107         virtual T det() const = 0;
108         virtual RT logDet(T* sign=0) const = 0;
109         virtual T trace() const = 0;
110         virtual T sumElements() const = 0;
111         virtual RT sumAbsElements() const = 0;
112         virtual RT sumAbs2Elements() const = 0;
113 
114         virtual RT norm() const  = 0;
115         virtual RT normSq(const RT scale = RT(1)) const = 0;
116         virtual RT normF() const  = 0;
117         virtual RT norm1() const = 0;
118         virtual RT norm2() const  = 0;
119         virtual RT doNorm2() const  = 0;
120         virtual RT normInf() const = 0;
121         virtual RT maxAbsElement() const = 0;
122         virtual RT maxAbs2Element() const = 0;
123 
124         //
125         // I/O: Write
126         //
127 
128         virtual void write(const TMV_Writer& writer) const = 0;
129 
~BaseMatrix()130         virtual inline ~BaseMatrix() {}
131 
132     }; // BaseMatrix
133 
134     template <typename T>
135     class DivHelper : virtual public AssignableToMatrix<T>
136     {
137     public:
138 
139         typedef TMV_RealType(T) RT;
140 
141         //
142         // Constructors
143         //
144 
145         DivHelper();
146         // Cannot do this inline, since need to delete pdiv,
147         // and I only define DivImpl in BaseMatrix.cpp.
148         virtual ~DivHelper();
149 
150         using AssignableToMatrix<T>::colsize;
151         using AssignableToMatrix<T>::rowsize;
152 
det()153         T det() const
154         {
155             TMVAssert(rowsize() == colsize());
156             return doDet();
157         }
158 
logDet(T * sign)159         RT logDet(T* sign) const
160         {
161             TMVAssert(rowsize() == colsize());
162             return doLogDet(sign);
163         }
164 
makeInverse(MatrixView<T> minv)165         void makeInverse(MatrixView<T> minv) const
166         {
167             TMVAssert(minv.colsize() == rowsize());
168             TMVAssert(minv.rowsize() == colsize());
169             doMakeInverse(minv);
170         }
171 
172         template <typename T1>
makeInverse(MatrixView<T1> minv)173         inline void makeInverse(MatrixView<T1> minv) const
174         {
175             TMVAssert(minv.colsize() == rowsize());
176             TMVAssert(minv.rowsize() == colsize());
177             doMakeInverse(minv);
178         }
179 
180         template <typename T1, int A>
makeInverse(Matrix<T1,A> & minv)181         inline void makeInverse(Matrix<T1,A>& minv) const
182         {
183             TMVAssert(minv.colsize() == rowsize());
184             TMVAssert(minv.rowsize() == colsize());
185             doMakeInverse(minv.view());
186         }
187 
makeInverseATA(MatrixView<T> ata)188         inline void makeInverseATA(MatrixView<T> ata) const
189         {
190             TMVAssert(ata.colsize() ==
191                       (rowsize() < colsize() ? rowsize() : colsize()));
192             TMVAssert(ata.rowsize() ==
193                       (rowsize() < colsize() ? rowsize() : colsize()));
194             doMakeInverseATA(ata);
195         }
196 
197         template <int A>
makeInverseATA(Matrix<T,A> & ata)198         inline void makeInverseATA(Matrix<T,A>& ata) const
199         {
200             TMVAssert(ata.colsize() ==
201                       (rowsize() < colsize() ? rowsize() : colsize()));
202             TMVAssert(ata.rowsize() ==
203                       (rowsize() < colsize() ? rowsize() : colsize()));
204             doMakeInverseATA(ata.view());
205         }
206 
isSingular()207         inline bool isSingular() const
208         { return doIsSingular(); }
209 
norm2()210         inline RT norm2() const
211         {
212             TMVAssert(divIsSet() && getDivType() == SV);
213             return doNorm2();
214         }
215 
condition()216         inline RT condition() const
217         {
218             TMVAssert(divIsSet() && getDivType() == SV);
219             return doCondition();
220         }
221 
222         // m^-1 * v -> v
223         template <typename T1>
LDivEq(VectorView<T1> v)224         inline void LDivEq(VectorView<T1> v) const
225         {
226             TMVAssert(colsize() == rowsize());
227             TMVAssert(colsize() == v.size());
228             doLDivEq(v);
229         }
230 
231         template <typename T1>
LDivEq(MatrixView<T1> m)232         inline void LDivEq(MatrixView<T1> m) const
233         {
234             TMVAssert(colsize() == rowsize());
235             TMVAssert(colsize() == m.colsize());
236             doLDivEq(m);
237         }
238 
239         // v * m^-1 -> v
240         template <typename T1>
RDivEq(VectorView<T1> v)241         inline void RDivEq(VectorView<T1> v) const
242         {
243             TMVAssert(colsize() == rowsize());
244             TMVAssert(colsize() == v.size());
245             doRDivEq(v);
246         }
247 
248         template <typename T1>
RDivEq(MatrixView<T1> m)249         inline void RDivEq(MatrixView<T1> m) const
250         {
251             TMVAssert(colsize() == rowsize());
252             TMVAssert(colsize() == m.rowsize());
253             doRDivEq(m);
254         }
255 
256         // m^-1 * v1 -> v0
257         template <typename T1, typename T0>
LDiv(const GenVector<T1> & v1,VectorView<T0> v0)258         inline void LDiv(
259             const GenVector<T1>& v1, VectorView<T0> v0) const
260         {
261             TMVAssert(rowsize() == v0.size());
262             TMVAssert(colsize() == v1.size());
263             doLDiv(v1,v0);
264         }
265 
266         template <typename T1, typename T0>
LDiv(const GenMatrix<T1> & m1,MatrixView<T0> m0)267         inline void LDiv(
268             const GenMatrix<T1>& m1, MatrixView<T0> m0) const
269         {
270             TMVAssert(rowsize() == m0.colsize());
271             TMVAssert(colsize() == m1.colsize());
272             TMVAssert(m1.rowsize() == m0.rowsize());
273             doLDiv(m1,m0);
274         }
275 
276         // v1 * m^-1 -> v0
277         template <typename T1, typename T0>
RDiv(const GenVector<T1> & v1,VectorView<T0> v0)278         inline void RDiv(
279             const GenVector<T1>& v1, VectorView<T0> v0) const
280         {
281             TMVAssert(rowsize() == v1.size());
282             TMVAssert(colsize() == v0.size());
283             doRDiv(v1,v0);
284         }
285 
286         template <typename T1, typename T0>
RDiv(const GenMatrix<T1> & m1,MatrixView<T0> m0)287         inline void RDiv(
288             const GenMatrix<T1>& m1, MatrixView<T0> m0) const
289         {
290             TMVAssert(rowsize() == m1.rowsize());
291             TMVAssert(colsize() == m0.rowsize());
292             TMVAssert(m1.colsize() == m0.colsize());
293             doRDiv(m1,m0);
294         }
295 
296         //
297         // Division Control
298         //
299 
300         void divideUsing(DivType dt) const;
301 
302         void divideInPlace() const;
303         void dontDivideInPlace() const;
304         void saveDiv() const;
305         void dontSaveDiv() const;
306 
307         // setDiv is defined in the derived class.
308         virtual void setDiv() const = 0;
309         void unsetDiv() const;
310         void resetDiv() const;
311 
312         DivType getDivType() const;
313         bool divIsInPlace() const;
314         bool divIsSaved() const;
315         bool divIsSet() const;
316 
317         bool checkDecomp(std::ostream* fout=0) const;
318         bool checkDecomp(const BaseMatrix<T>& m2, std::ostream* fout=0) const;
319 
320     protected :
321 
322         void doneDiv() const;
323         const Divider<T>* getDiv() const;
324         void resetDivType() const;
325 
326         // Two more that need to be defined in the derived class:
327         virtual const BaseMatrix<T>& getMatrix() const = 0;
328 
329         mutable auto_ptr<Divider<T> > divider;
330         mutable DivType divtype;
331 
332     private :
333 
334         DivHelper(const DivHelper<T>&);
335         DivHelper<T>& operator=(const DivHelper<T>&);
336 
337         T doDet() const;
338         RT doLogDet(T* sign) const;
339         template <typename T1>
340         void doMakeInverse(MatrixView<T1> minv) const;
341         void doMakeInverseATA(MatrixView<T> minv) const;
342         bool doIsSingular() const;
343         RT doNorm2() const;
344         RT doCondition() const;
345         template <typename T1>
346         void doLDivEq(VectorView<T1> v) const;
347         template <typename T1>
348         void doLDivEq(MatrixView<T1> m) const;
349         template <typename T1>
350         void doRDivEq(VectorView<T1> v) const;
351         template <typename T1>
352         void doRDivEq(MatrixView<T1> m) const;
353         template <typename T1, typename T0>
354         void doLDiv(
355             const GenVector<T1>& v1, VectorView<T0> v0) const;
356         template <typename T1, typename T0>
357         void doLDiv(
358             const GenMatrix<T1>& m1, MatrixView<T0> m0) const;
359         template <typename T1, typename T0>
360         void doRDiv(
361             const GenVector<T1>& v1, VectorView<T0> v0) const;
362         template <typename T1, typename T0>
363         void doRDiv(
364             const GenMatrix<T1>& m1, MatrixView<T0> m0) const;
365 
366     }; // DivHelper
367 
368     //
369     // Functions of Matrices:
370     //
371 
372     template <typename T>
Det(const BaseMatrix<T> & m)373     inline T Det(const BaseMatrix<T>& m)
374     { return m.det(); }
375 
376     template <typename T>
TMV_RealType(T)377     inline TMV_RealType(T) LogDet(const BaseMatrix<T>& m)
378     { return m.logDet(); }
379 
380     template <typename T>
Trace(const BaseMatrix<T> & m)381     inline T Trace(const BaseMatrix<T>& m)
382     { return m.trace(); }
383 
384     template <typename T>
SumElements(const BaseMatrix<T> & m)385     inline T SumElements(const BaseMatrix<T>& m)
386     { return m.sumElements(); }
387 
388     template <typename T>
TMV_RealType(T)389     inline TMV_RealType(T) SumAbsElements(const BaseMatrix<T>& m)
390     { return m.sumAbsElements(); }
391 
392     template <typename T>
TMV_RealType(T)393     inline TMV_RealType(T) SumAbs2Elements(const BaseMatrix<T>& m)
394     { return m.sumAbs2Elements(); }
395 
396     template <typename T>
TMV_RealType(T)397     inline TMV_RealType(T) Norm(const BaseMatrix<T>& m)
398     { return m.norm(); }
399 
400     template <typename T>
TMV_RealType(T)401     inline TMV_RealType(T) NormSq(const BaseMatrix<T>& m)
402     { return m.normSq(); }
403 
404     template <typename T>
TMV_RealType(T)405     inline TMV_RealType(T) NormF(const BaseMatrix<T>& m)
406     { return m.normF(); }
407 
408     template <typename T>
TMV_RealType(T)409     inline TMV_RealType(T) Norm1(const BaseMatrix<T>& m)
410     { return m.norm1(); }
411 
412     template <typename T>
TMV_RealType(T)413     inline TMV_RealType(T) Norm2(const BaseMatrix<T>& m)
414     { return m.norm2(); }
415 
416     template <typename T>
TMV_RealType(T)417     inline TMV_RealType(T) NormInf(const BaseMatrix<T>& m)
418     { return m.normInf(); }
419 
420     template <typename T>
TMV_RealType(T)421     inline TMV_RealType(T) MaxAbsElement(const BaseMatrix<T>& m)
422     { return m.maxAbsElement(); }
423 
424     template <typename T>
TMV_RealType(T)425     inline TMV_RealType(T) MaxAbs2Element(const BaseMatrix<T>& m)
426     { return m.maxAbs2Element(); }
427 
428 
429     //
430     // I/O
431     //
432 
433     template <typename T>
434     inline std::ostream& operator<<(
435         const TMV_Writer& writer, const BaseMatrix<T>& m)
436     { m.write(writer); return writer.getos(); }
437 
438     template <typename T>
439     inline std::ostream& operator<<(
440         std::ostream& os, const BaseMatrix<T>& m)
441     { return os << IOStyle() << m; }
442 
443 
444     template <typename T, int A>
TMV_Text(const Matrix<T,A> &)445     inline std::string TMV_Text(const Matrix<T,A>& )
446     {
447         return std::string("Matrix<") +
448             TMV_Text(T()) + "," + Attrib<A>::text() + ">";
449     }
450     template <typename T>
TMV_Text(const GenMatrix<T> &)451     inline std::string TMV_Text(const GenMatrix<T>& )
452     {
453         return std::string("GenMatrix<") + TMV_Text(T()) + ">";
454     }
455     template <typename T, int A>
TMV_Text(const ConstMatrixView<T,A> &)456     inline std::string TMV_Text(const ConstMatrixView<T,A>& )
457     {
458         return std::string("ConstMatrixView<") +
459             TMV_Text(T()) + "," + Attrib<A>::text() + ">";
460     }
461     template <typename T, int A>
TMV_Text(const MatrixView<T,A> &)462     inline std::string TMV_Text(const MatrixView<T,A>& )
463     {
464         return std::string("MatrixView<") +
465             TMV_Text(T()) + "," + Attrib<A>::text() + ">";
466     }
467 
468 } // namespace tmv
469 
470 #endif
471