1 /*
2   HMat-OSS (HMatrix library, open source software)
3 
4   Copyright (C) 2014-2015 Airbus Group SAS
5 
6   This program is free software; you can redistribute it and/or
7   modify it under the terms of the GNU General Public License
8   as published by the Free Software Foundation; either version 2
9   of the License, or (at your option) any later version.
10 
11   This program is distributed in the hope that it will be useful,
12   but WITHOUT ANY WARRANTY; without even the implied warranty of
13   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14   GNU General Public License for more details.
15 
16   You should have received a copy of the GNU General Public License
17   along with this program; if not, write to the Free Software
18   Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
19 
20   http://github.com/jeromerobert/hmat-oss
21 */
22 
23 #include "config.h"
24 
25 /*! \file
26   \ingroup HMatrix
27   \brief HMatrix type.
28 */
29 #include <algorithm>
30 #include <list>
31 #include <vector>
32 #include <cstring>
33 
34 #include "h_matrix.hpp"
35 #include "cluster_tree.hpp"
36 #include "admissibility.hpp"
37 #include "data_types.hpp"
38 #include "compression.hpp"
39 #include "recursion.hpp"
40 #include "common/context.hpp"
41 #include "common/my_assert.h"
42 #include "json.hpp"
43 
44 using namespace std;
45 
46 namespace hmat {
47 
// The default values below will be overwritten in default_engine.cpp by HMatSettings values
// coarsening: read by assemble() and assembleSymmetric() in this file; when true,
// each assembled non-leaf block calls coarsen() to try merging its Rk children.
// NOTE(review): the other flags are not read in the visible part of this file;
// their meaning (recompression / validation switches, judging by the names)
// should be confirmed in h_matrix.hpp.
template<typename T> bool HMatrix<T>::coarsening = false;
template<typename T> bool HMatrix<T>::recompress = false;
template<typename T> bool HMatrix<T>::validateNullRowCol = false;
template<typename T> bool HMatrix<T>::validateCompression = false;
template<typename T> bool HMatrix<T>::validationReRun = false;
template<typename T> bool HMatrix<T>::validationDump = false;
template<typename T> double HMatrix<T>::validationErrorThreshold = 0;
56 
~HMatrix()57 template<typename T> HMatrix<T>::~HMatrix() {
58   if (isRkMatrix() && rk_) {
59     delete rk_;
60     rk_ = NULL;
61   }
62   if (full_) {
63     delete full_;
64     full_ = NULL;
65   }
66   if(ownRowsClusterTree_)
67       delete rows_;
68   if(ownColsClusterTree_)
69       delete cols_;
70 }
71 
72 
73 template<typename T>
reorderVector(ScalarArray<T> * v,int * indices,int axis)74 void reorderVector(ScalarArray<T>* v, int* indices, int axis) {
75   DECLARE_CONTEXT;
76   if (!indices) return;
77   const int n = axis == 0 ? v->rows : v->cols;
78   // If permutation is identity, do nothing
79   bool identity = true;
80   for (int i = 0; i < n; i++) {
81     if (indices[i] != i) {
82       identity = false;
83       break;
84     }
85   }
86   if (identity) return;
87 
88   if (axis == 0) {
89     Vector<T> tmp(n);
90     for (int col = 0; col < v->cols; col++) {
91       Vector<T> column(*v, col);
92       for (int i = 0; i < n; i++)
93         tmp[i] = column[indices[i]];
94       tmp.copy(&column);
95     }
96   } else {
97     ScalarArray<T> tmp(1, n);
98     for (int row = 0; row < v->rows; row++) {
99       ScalarArray<T> column(*v, row, 1, 0, n);
100       for (int i = 0; i < n; i++)
101         tmp.get(0, i) = column.get(0, indices[i]);
102       tmp.copy(&column);
103     }
104   }
105 }
106 
107 template<typename T>
restoreVectorOrder(ScalarArray<T> * v,int * indices,int axis)108 void restoreVectorOrder(ScalarArray<T>* v, int* indices, int axis) {
109   DECLARE_CONTEXT;
110   if (!indices) return;
111   const int n = axis == 0 ? v->rows : v->cols;
112   // If permutation is identity, do nothing
113   bool identity = true;
114   for (int i = 0; i < n; i++) {
115     if (indices[i] != i) {
116       identity = false;
117       break;
118     }
119   }
120   if (identity) return;
121 
122   if (axis == 0) {
123     Vector<T> tmp(n);
124     for (int col = 0; col < v->cols; col++) {
125       Vector<T> column(*v, col);
126       for (int i = 0; i < n; i++) {
127         tmp[indices[i]] = column[i];
128       }
129       tmp.copy(&column);
130     }
131   } else {
132     ScalarArray<T> tmp(1, n);
133     for (int row = 0; row < v->rows; row++) {
134       ScalarArray<T> column(*v, row, 1, 0, n);
135       for (int i = 0; i < n; i++)
136         tmp.get(0, indices[i]) = column.get(0, i);
137       tmp.copy(&column);
138     }
139   }
140 }
141 
/**
 * Build the block structure for _rows x _cols, recursively.
 *
 * Unless the block is void, split() is attempted first. If it declines, this
 * block becomes a leaf. The leaf is Rk when forceRk() says so, or when the block
 * is low-rank and not forced full; otherwise it is a full leaf. Leaf values are
 * not computed here (rk(NULL) / full(NULL)); assemble() fills them later.
 *
 * @param _rows row cluster tree of this block
 * @param _cols column cluster tree of this block
 * @param settings global settings, passed on to every child
 * @param _depth depth of this block in the tree
 * @param symFlag symmetry flag forwarded to split()
 * @param admissibilityCondition decides splitting, leaf kind and approximate rank
 */
template<typename T>
HMatrix<T>::HMatrix(const ClusterTree* _rows, const ClusterTree* _cols, const hmat::MatrixSettings * settings,
                    int _depth, SymmetryFlag symFlag, AdmissibilityCondition * admissibilityCondition)
  : Tree<HMatrix<T> >(NULL, _depth), RecursionMatrix<T, HMatrix<T> >(),
    rows_(_rows), cols_(_cols), rk_(NULL),
    rank_(UNINITIALIZED_BLOCK), approximateRank_(UNINITIALIZED_BLOCK),
    isUpper(false), isLower(false),
    isTriUpper(false), isTriLower(false), keepSameRows(true), keepSameCols(true), temporary_(false),
    // NOTE(review): 1e-4 is presumably the default local low-rank epsilon — confirm in h_matrix.hpp
    ownRowsClusterTree_(false), ownColsClusterTree_(false), localSettings(settings, 1e-4)
{
  // NOTE(review): isVoid() presumably means an empty row or column cluster — confirm.
  if (isVoid())
    return;
  const bool lowRank = admissibilityCondition->isLowRank(*rows_, *cols_);
  if (!split(admissibilityCondition, lowRank, symFlag)) {
    // If we cannot split, we are on a leaf
    const bool forceFull = admissibilityCondition->forceFull(*rows_, *cols_);
    const bool forceRk   = admissibilityCondition->forceRk(*rows_, *cols_);
    assert(!(forceFull && forceRk));
    // Choose the leaf kind; the payload itself is created by assemble().
    if (forceRk || (lowRank && !forceFull))
      rk(NULL);
    else
      full(NULL);
    approximateRank_ = admissibilityCondition->getApproximateRank(*(rows_), *(cols_));
  }
  assert(!this->isLeaf() || isAssembled());
}
168 
/**
 * Try to subdivide this block into children.
 *
 * Returns false, leaving the block a leaf, when:
 * - both clusters are leaves,
 * - the admissibility condition asks to stop recursion, or
 * - the block is low-rank and recursion is not forced.
 *
 * Otherwise it creates one child per (row child, column child) pair and returns
 * true. A dimension that splitRowsCols() does not split keeps the parent
 * cluster. 'Inert' pairs are inserted as NULL. For a kLowerSymmetric block, only
 * the children with i >= j are created; only diagonal children inherit symFlag.
 *
 * @param admissibilityCondition decides recursion and per-child inertness
 * @param lowRank whether this block was found low-rank by the caller
 * @param symFlag symmetry of this block
 * @return true if the block was split
 */
template<typename T>
bool HMatrix<T>::split(AdmissibilityCondition * admissibilityCondition, bool lowRank,
                      SymmetryFlag symFlag) {
  assert(rank_ == NONLEAF_BLOCK || rank_ == UNINITIALIZED_BLOCK || (this->isLeaf() && isNull()));
  // We would like to create a block of matrix in one of the following case:
  // - rows_->isLeaf() && cols_->isLeaf() : both rows and cols are leaves.
  // - Block is too small to recurse and compress (for performance)
  // - Block is compressible and in-place compression is possible
  // In the other cases, we subdivide.
  //
  // FIXME: But in practice this does not work yet, so we stop recursion as soon as either rows
  // or cols is a leaf.
  bool stopRecursion = admissibilityCondition->stopRecursion(*rows_, *cols_);
  bool forceRecursion = admissibilityCondition->forceRecursion(*rows_, *cols_, sizeof(T));
  assert(!(forceRecursion && stopRecursion));
  // check we can actually split
  if ((rows_->isLeaf() && cols_->isLeaf()) || stopRecursion || (lowRank && !forceRecursion))
    return false;
  pair<bool, bool> splitRC = admissibilityCondition->splitRowsCols(*rows_, *cols_);
  assert(splitRC.first || splitRC.second);
  keepSameRows = !splitRC.first;
  keepSameCols = !splitRC.second;
  // NOTE(review): isUpper is never set from symFlag here, so only kNotSymmetric and
  // kLowerSymmetric produce children below; confirm no other SymmetryFlag reaches this.
  isLower = (symFlag == kLowerSymmetric ? true : false);
  for (int i = 0; i < nrChildRow(); ++i) {
    // Don't recurse on rows if splitRowsCols() told us not to.
    ClusterTree* rowChild = const_cast<ClusterTree*>((keepSameRows ? rows_ : rows_->getChild(i)));
    for (int j = 0; j < nrChildCol(); ++j) {
      // Don't recurse on cols if splitRowsCols() told us not to.
      ClusterTree* colChild = const_cast<ClusterTree*>((keepSameCols ? cols_ : cols_->getChild(j)));
      // For a lower-symmetric block only the lower triangle (i >= j) is stored.
      if ((symFlag == kNotSymmetric) || (isUpper && (i <= j)) || (isLower && (i >= j))) {
        if (!admissibilityCondition->isInert(*rowChild, *colChild)) {
          // Create child only if not 'inert' (inert = will always be null)
          this->insertChild(i, j,
                            new HMatrix<T>(rowChild, colChild, localSettings.global,
                                           this->depth + 1,
                                           i == j ? symFlag : kNotSymmetric,
                                           admissibilityCondition));
        } else
          // If 'inert', the child is NULL
          this->insertChild(i, j, NULL);
      }
    }
  }
  if(nrChildRow() > 0 && nrChildCol() > 0)
    rank_ = NONLEAF_BLOCK;
  return true;
}
216 
/**
 * Create a bare block: no cluster trees, no children, no payload.
 *
 * Used as the starting point of internalCopy(), which then sets rows_, cols_,
 * temporary_ and localSettings.epsilon_.
 * NOTE(review): the -1.0 passed to localSettings is presumably an "unset"
 * epsilon that callers overwrite — confirm in h_matrix.hpp.
 *
 * @param settings global settings shared with the source matrix
 */
template<typename T>
HMatrix<T>::HMatrix(const hmat::MatrixSettings * settings) :
    Tree<HMatrix<T> >(NULL), RecursionMatrix<T, HMatrix<T> >(), rows_(NULL), cols_(NULL),
    rk_(NULL), rank_(UNINITIALIZED_BLOCK), approximateRank_(UNINITIALIZED_BLOCK),
    isUpper(false), isLower(false), isTriUpper(false), isTriLower(false),
    keepSameRows(true), keepSameCols(true), temporary_(false), ownRowsClusterTree_(false),
    ownColsClusterTree_(false), localSettings(settings, -1.0)
    {}
225 
internalCopy(bool temporary,bool withRowChild,bool withColChild) const226 template<typename T> HMatrix<T> * HMatrix<T>::internalCopy(bool temporary, bool withRowChild, bool withColChild) const {
227     HMatrix<T> * r = new HMatrix<T>(localSettings.global);
228     r->rows_ = rows_;
229     r->cols_ = cols_;
230     r->temporary_ = temporary;
231     r->localSettings.epsilon_ = localSettings.epsilon_;
232     if(withRowChild || withColChild) {
233         // Here, we come from HMatrixHandle<T>::createGemmTemporaryRk()
234         // we want to go 1 level below data (which is an Rk)
235         // so we don't use get(i,j) since data has no children
236         // we dont use this->nrChildRow and this->nrChildCol either, they would return 1
237         // (since 'this' is rows- and cols-admissible, unlike 'r')
238         r->keepSameRows = !withRowChild;
239         r->keepSameCols = !withColChild;
240         for(int i = 0; i < r->nrChildRow(); i++) {
241             for(int j = 0; j < r->nrChildCol(); j++) {
242                 HMatrix<T>* child = new HMatrix<T>(localSettings.global);
243                 child->temporary_ = temporary;
244                 child->rows_ = withRowChild ? rows_->getChild(i) : rows_;
245                 child->cols_ = withColChild ? cols_->getChild(j) : cols_;
246                 child->localSettings.epsilon_ = localSettings.epsilon_;
247                 assert(child->rows_ != NULL);
248                 assert(child->cols_ != NULL);
249                 assert(child->localSettings.epsilon_ > 0);
250                 child->rk(NULL);
251                 r->insertChild(i, j, child);
252             }
253         }
254     }
255     return r;
256 }
257 
internalCopy(const ClusterTree * rows,const ClusterTree * cols) const258 template<typename T> HMatrix<T>* HMatrix<T>::internalCopy(
259         const ClusterTree * rows, const ClusterTree * cols) const {
260     HMatrix<T> * r = new HMatrix<T>(localSettings.global);
261     r->temporary_ = true;
262     r->rows_ = rows;
263     r->cols_ = cols;
264     r->localSettings.epsilon_ = localSettings.epsilon_;
265     return r;
266 }
267 
268 template<typename T>
copyStructure() const269 HMatrix<T>* HMatrix<T>::copyStructure() const {
270   HMatrix<T>* h = internalCopy();
271   h->isUpper = isUpper;
272   h->isLower = isLower;
273   h->isTriUpper = isTriUpper;
274   h->isTriLower = isTriLower;
275   h->keepSameRows = keepSameRows;
276   h->keepSameCols = keepSameCols;
277   h->rank_ = rank_ >= 0 ? 0 : rank_;
278   h->approximateRank_ = approximateRank_;
279   if(!this->isLeaf()){
280     for (int i = 0; i < this->nrChild(); ++i) {
281       if (this->getChild(i)) {
282         h->insertChild(i, this->getChild(i)->copyStructure());
283       }
284       else
285         h->insertChild(i, NULL);
286     }
287   }
288   return h;
289 }
290 
291 template<typename T>
Zero(const HMatrix<T> * o)292 HMatrix<T>* HMatrix<T>::Zero(const HMatrix<T>* o) {
293   // leaves are filled by 0
294   HMatrix<T> *h = o->internalCopy();
295   h->isLower = o->isLower;
296   h->isUpper = o->isUpper;
297   h->isTriUpper = o->isTriUpper;
298   h->isTriLower = o->isTriLower;
299   h->keepSameRows = o->keepSameRows;
300   h->keepSameCols = o->keepSameCols;
301   h->rank_ = o->rank_ >= 0 ? 0 : o->rank_;
302   if (h->rank_==0)
303     h->rk(new RkMatrix<T>(NULL, h->rows(), NULL, h->cols()));
304   h->approximateRank_ = o->approximateRank_;
305   if(!o->isLeaf()){
306     for (int i = 0; i < o->nrChild(); ++i) {
307       if (o->getChild(i)) {
308         h->insertChild(i, HMatrix<T>::Zero(o->getChild(i)));
309         } else
310         h->insertChild(i, NULL);
311     }
312   }
313   return h;
314 }
315 
316 template<typename T>
setClusterTrees(const ClusterTree * rows,const ClusterTree * cols)317 void HMatrix<T>::setClusterTrees(const ClusterTree* rows, const ClusterTree* cols) {
318     rows_ = rows;
319     cols_ = cols;
320     if(isRkMatrix() && rk()) {
321         rk()->rows = &(rows->data);
322         rk()->cols = &(cols->data);
323     } else if(isFullMatrix()) {
324         full()->rows_ = &(rows->data);
325         full()->cols_ = &(cols->data);
326     } else if(!this->isLeaf()) {
327       for (int i = 0; i < nrChildRow(); ++i) {
328         // if rows not admissible, don't recurse on them
329         const ClusterTree* rowChild = (keepSameRows ? rows : rows->me()->getChild(i));
330         for (int j = 0; j < nrChildCol(); ++j) {
331           // if cols not admissible, don't recurse on them
332           const ClusterTree* colChild = (keepSameCols ? cols : cols->me()->getChild(j));
333           if(get(i, j))
334             get(i, j)->setClusterTrees(rowChild, colChild);
335         }
336       }
337     }
338 }
339 
/**
 * Compute the values of this block, recursively, with the assembly function f.
 *
 * On a leaf, f.assemble() returns either a dense matrix (m) or an Rk matrix
 * (assembledRk), never both. The previous payload of that kind is deleted and
 * replaced. On a non-leaf, every existing child is assembled, then
 * assembledRecurse() is called. If the static 'coarsening' flag is set,
 * coarsen() then tries to merge the Rk children into one Rk leaf.
 *
 * @param f assembly function providing block values
 * @param ao allocation observer forwarded to f.assemble()
 */
template<typename T>
void HMatrix<T>::assemble(Assembly<T>& f, const AllocationObserver & ao) {
  if (this->isLeaf()) {
    // If the leaf is admissible, matrix assembly and compression.
    // if not we keep the matrix.
    FullMatrix<T> * m = NULL;
    RkMatrix<T>* assembledRk = NULL;
    f.assemble(localSettings, *rows_, *cols_, isRkMatrix(), m, assembledRk, lowRankEpsilon(), ao);
    HMAT_ASSERT(m == NULL || assembledRk == NULL);
    if(assembledRk) {
        assert(isRkMatrix());
        // Drop any previous Rk payload before installing the new one.
        if(rk_)
            delete rk_;
        rk(assembledRk);
    } else {
        assert(!isRkMatrix());
        // Drop any previous dense payload before installing the new one (m may be NULL).
        if(full_)
            delete full_;
        full(m);
    }
  } else {
    // Inner blocks carry no payload of their own.
    full_ = NULL;
    rk_ = NULL;
    for (int i = 0; i < this->nrChild(); i++) {
      if (this->getChild(i))
        this->getChild(i)->assemble(f, ao);
    }
    assembledRecurse();
    if (coarsening)
      coarsen(RkMatrix<T>::approx.coarseningEpsilon);
  }
}
372 
/**
 * Assemble this (lower-part) block and mirror it, transposed, into 'upper'.
 *
 * @param f assembly function
 * @param upper block receiving the transpose of this block. NULL means 'this',
 *        i.e. a diagonal block. Ignored when onlyLower is true.
 * @param onlyLower if true, assemble only this block. On a diagonal block
 *        (rows == cols), children with j > i are skipped. Nothing is mirrored.
 * @param ao allocation observer forwarded to assemble()
 */
template<typename T>
void HMatrix<T>::assembleSymmetric(Assembly<T>& f,
   HMatrix<T>* upper, bool onlyLower, const AllocationObserver & ao) {
  if (!onlyLower) {
    if (!upper){
      upper = this;
    }
    assert(*this->rows() == *upper->cols());
    assert(*this->cols() == *upper->rows());
  }

  if (this->isLeaf()) {
    // If the leaf is admissible, matrix assembly and compression.
    // if not we keep the matrix.
    this->assemble(f, ao);
    if (isRkMatrix()) {
      if ((!onlyLower) && (upper != this)) {
        // Admissible leaf: a matrix represented by AB^t is transposed by exchanging A and B.
        RkMatrix<T>* newRk = rk()->copy();
        newRk->transpose();
        if(upper->isRkMatrix() && upper->rk() != NULL)
            delete upper->rk();
        upper->rk(newRk);
      }
    } else {
      if ((!onlyLower) && ( upper != this)) {
        // Dense (or empty) leaf: mirror the transposed dense block, or mark upper as an empty full leaf.
        if(isFullMatrix())
            upper->full(full()->copyAndTranspose());
        else
            upper->full(NULL);
      }
    }
  } else {
    if (onlyLower) {
      for (int i = 0; i < nrChildRow(); i++) {
        for (int j = 0; j < nrChildCol(); j++) {
          // On a diagonal block only the lower triangle is assembled.
          if ((*rows() == *cols()) && (j > i)) {
            continue;
          }
          if (get(i,j))
            get(i,j)->assembleSymmetric(f, NULL, true, ao);
        }
      }
    } else {
      if (this == upper) {
        // Diagonal block: child (i, j) with j <= i is mirrored into its own sibling (j, i).
        for (int i = 0; i < nrChildRow(); i++) {
          for (int j = 0; j <= i; j++) {
            HMatrix<T> *child = get(i, j);
            HMatrix<T> *upperChild = get(j, i);
            assert((child != NULL) == (upperChild != NULL));
            if (child)
              child->assembleSymmetric(f, upperChild, false, ao);
          }
        }
      } else {
        // Off-diagonal block: child (i, j) is mirrored into upper's child (j, i).
        for (int i = 0; i < nrChildRow(); i++) {
          for (int j = 0; j < nrChildCol(); j++) {
            HMatrix<T> *child = get(i, j);
            HMatrix<T> *upperChild = upper->get(j, i);
            assert((child != NULL) == (upperChild != NULL));
            if (child)
              child->assembleSymmetric(f, upperChild, false, ao);
          }
        }
        upper->assembledRecurse();
        if (coarsening)
          coarsen(RkMatrix<T>::approx.coarseningEpsilon, upper);
      }
    }
    assembledRecurse();
  }
}
445 
info(hmat_info_t & result)446 template<typename T> void HMatrix<T>::info(hmat_info_t & result) {
447     result.nr_block_clusters++;
448     int r = rows()->size();
449     int c = cols()->size();
450     if(r == 0 || c == 0) {
451         return;
452     } else if(this->isLeaf()) {
453         size_t s = ((size_t)r) * c;
454         result.uncompressed_size += s;
455         if(isRkMatrix()) {
456             size_t mem = rank() * (((size_t)r) + c);
457             result.compressed_size += mem;
458             int dim = result.largest_rk_dim_cols + result.largest_rk_dim_rows;
459             if(rows()->size() + cols()->size() > dim) {
460                 result.largest_rk_dim_cols = c;
461                 result.largest_rk_dim_rows = r;
462             }
463 
464             size_t old_s = ((size_t)result.largest_rk_mem_cols + result.largest_rk_mem_rows) * result.largest_rk_mem_rank;
465             if(mem > old_s) {
466                 result.largest_rk_mem_cols = c;
467                 result.largest_rk_mem_rows = r;
468                 result.largest_rk_mem_rank = rank();
469             }
470             result.rk_count++;
471             result.rk_size += s;
472         } else {
473             result.compressed_size += s;
474             result.full_count ++;
475             result.full_size += s;
476         }
477     } else {
478         for (int i = 0; i < this->nrChild(); i++) {
479             HMatrix<T> *child = this->getChild(i);
480             if (child)
481                 child->info(result);
482         }
483     }
484 }
485 
486 template<typename T>
eval(FullMatrix<T> * result,bool renumber) const487 void HMatrix<T>::eval(FullMatrix<T>* result, bool renumber) const {
488   if (this->isLeaf()) {
489     if (this->isNull()) return;
490     FullMatrix<T> *mat = isRkMatrix() ? rk()->eval() : full();
491     int *rowIndices = rows()->indices() + rows()->offset();
492     int rowCount = rows()->size();
493     int *colIndices = cols()->indices() + cols()->offset();
494     int colCount = cols()->size();
495     if(renumber) {
496       for (int j = 0; j < colCount; j++)
497         for (int i = 0; i < rowCount; i++)
498           result->get(rowIndices[i], colIndices[j]) = mat->get(i, j);
499     } else {
500       for (int j = 0; j < colCount; j++)
501         memcpy(&result->get(rows()->offset(), cols()->offset() + j), &mat->get(0, j), rowCount * sizeof(T));
502     }
503     if (isRkMatrix()) {
504       delete mat;
505     }
506   } else {
507     for (int i = 0; i < this->nrChild(); i++) {
508       if (this->getChild(i)) {
509         this->getChild(i)->eval(result, renumber);
510       }
511     }
512   }
513 }
514 
515 template<typename T>
evalPart(FullMatrix<T> * result,const IndexSet * _rows,const IndexSet * _cols) const516 void HMatrix<T>::evalPart(FullMatrix<T>* result, const IndexSet* _rows,
517                           const IndexSet* _cols) const {
518   if (this->isLeaf()) {
519     if (this->isNull()) return;
520     FullMatrix<T> *mat = isRkMatrix() ? rk()->eval() : full();
521     const int rowOffset = rows()->offset() - _rows->offset();
522     const int rowCount = rows()->size();
523     const int colOffset = cols()->offset() - _cols->offset();
524     const int colCount = cols()->size();
525     for (int j = 0; j < colCount; j++) {
526       memcpy(&result->get(rowOffset, j + colOffset), &mat->get(0, j), rowCount * sizeof(T));
527     }
528     if (isRkMatrix()) {
529       delete mat;
530     }
531   } else {
532     for (int i = 0; i < this->nrChild(); i++) {
533       if (this->getChild(i)) {
534         this->getChild(i)->evalPart(result, _rows, _cols);
535       }
536     }
537   }
538 }
539 
normSqr() const540 template<typename T> double HMatrix<T>::normSqr() const {
541   double result = 0.;
542   if (rows()->size() == 0 || cols()->size() == 0) {
543     return result;
544   }
545   if (this->isLeaf() && isAssembled() && !isNull()) {
546     if (isRkMatrix()) {
547       result = rk()->normSqr();
548     } else {
549       result = full()->normSqr();
550     }
551   } else if(!this->isLeaf()){
552     for (int i = 0; i < this->nrChild(); i++) {
553       const HMatrix<T> *res=this->getChild(i);
554       if (res) {
555         // When computing the norm of symmetric matrices, extra-diagonal blocks count twice
556         double coeff = (isUpper || isLower) && ! (*res->rows() == *res->cols()) ? 2. : 1. ;
557         result += coeff * res->normSqr();
558       }
559     }
560   }
561   return result;
562 }
563 
564 // Return an approximation of the largest eigenvalue via the power method.
565 // If needed, we could also return the corresponding eigenvector.
approximateLargestEigenvalue(int max_iter,double epsilon) const566 template<typename T> T HMatrix<T>::approximateLargestEigenvalue(int max_iter, double epsilon) const {
567   if (max_iter <= 0) return 0.0;
568   if (rows()->size() == 0 || cols()->size() == 0) {
569     return 0.0;
570   }
571   const int nrow = rows()->size();
572   Vector<T>  xv(nrow);
573   Vector<T>  xv1(nrow);
574   Vector<T>* x  = &xv;
575   Vector<T>* x1 = &xv1;
576   T ev = Constants<T>::zero;
577   for (int i = 0; i < nrow; i++)
578     xv[i] = static_cast<T>(rand()/(double)RAND_MAX);
579   double normx = x->norm();
580   if (normx == 0.0)
581     return approximateLargestEigenvalue(max_iter - 1, epsilon);
582   x->scale(static_cast<T>(1.0/normx));
583   int iter = 0;
584   double aev = 0.0;
585   double aev_p = 0.0;
586   do {
587     // old eigenvalue
588     aev_p = aev;
589     // Compute x(k+1) = A x(k)
590     //        ev(k+1) = <x(k+1),x(k)>
591     //         x(k+1) = x(k+1) / ||x(k+1)||
592     gemv('N', Constants<T>::pone, x, Constants<T>::zero, x1);
593     ev = Vector<T>::dot(x,x1);
594     // new abs(ev)
595     aev = std::abs(ev);
596     normx = x1->norm();
597     // If x1 is null, restart so that starting point is different.
598     // Decrease max_iter to prevent infinite recursion.
599     if (normx == 0.0)
600       return approximateLargestEigenvalue(max_iter - 1, epsilon);
601     x1->scale(static_cast<T>(1.0/normx));
602     std::swap(x,x1);
603     iter++;
604   } while(iter < max_iter && std::abs(aev - aev_p) > epsilon * aev);
605   return ev;
606 }
607 
608 
609 template<typename T>
scale(T alpha)610 void HMatrix<T>::scale(T alpha) {
611   if(alpha == Constants<T>::zero) {
612     this->clear();
613   } else if(alpha == Constants<T>::pone) {
614     return;
615   } else if (this->isLeaf()) {
616     if (isNull()) {
617       // nothing to do
618     } else if (isRkMatrix()) {
619       rk()->scale(alpha);
620     } else {
621       assert(isFullMatrix());
622       full()->scale(alpha);
623     }
624   } else {
625     for (int i = 0; i < this->nrChild(); i++) {
626       if (this->getChild(i)) {
627         this->getChild(i)->scale(alpha);
628       }
629     }
630   }
631 }
632 
633 template<typename T>
coarsen(double epsilon,HMatrix<T> * upper,bool force)634 bool HMatrix<T>::coarsen(double epsilon, HMatrix<T>* upper, bool force) {
635   // If all children are Rk leaves, then we try to merge them into a single Rk-leaf.
636   // This is done if the memory of the resulting leaf is less than the sum of the initial
637   // leaves. Note that this operation could be used hierarchically.
638 
639   bool allRkLeaves = true;
640   const RkMatrix<T>* childrenArray[this->nrChild()];
641   size_t childrenElements = 0;
642   for (int i = 0; i < this->nrChild(); i++) {
643     childrenArray[i] = nullptr;
644     HMatrix<T> *child = this->getChild(i);
645     if (!child) continue;
646     if (!child->isRkMatrix()) {
647       allRkLeaves = false;
648       break;
649     } else {
650       childrenArray[i] = child->rk();
651       if(childrenArray[i])
652         childrenElements += (childrenArray[i]->rows->size()
653                            + childrenArray[i]->cols->size()) * childrenArray[i]->rank();
654     }
655   }
656   if (allRkLeaves) {
657     std::vector<T> alpha(this->nrChild(), Constants<T>::pone);
658     RkMatrix<T> * candidate = new RkMatrix<T>(NULL, rows(), NULL, cols());
659     candidate->formattedAddParts(epsilon, &alpha[0], childrenArray, this->nrChild());
660     size_t elements = (((size_t) candidate->rows->size()) + candidate->cols->size()) * candidate->rank();
661     if (force || elements < childrenElements) {
662       // Replace 'this' by the new Rk matrix
663       for (int i = 0; i < this->nrChild(); i++)
664         this->removeChild(i);
665       this->children.clear();
666       rk(candidate);
667       assert(this->isLeaf());
668       assert(isRkMatrix());
669       // If necessary, replace 'upper' by the new Rk matrix transposed (exchange a and b)
670       if (upper) {
671         for (int i = 0; i < this->nrChild(); i++)
672           upper->removeChild(i);
673         upper->children.clear();
674         RkMatrix<T>* newRk = candidate->copy();
675         newRk->transpose();
676         upper->rk(newRk);
677         assert(upper->isLeaf());
678         assert(upper->isRkMatrix());
679       }
680     } else {
681       delete candidate;
682     }
683   }
684 
685   return allRkLeaves;
686 }
687 
getChildForGEMM(char & t,int i,int j) const688 template<typename T> const HMatrix<T> * HMatrix<T>::getChildForGEMM(char & t, int i, int j) const {
689   // At most 1 of these flags must be 'true'
690   assert(isUpper + isLower + isTriUpper + isTriLower >= -1);
691   assert(!this->isLeaf());
692 
693   const HMatrix<T>* res;
694   if(t != 'N')
695     std::swap(i,j);
696   if( (isLower && j > i) ||
697       (isUpper && i > j) ) {
698     res = get(j, i);
699     t = t == 'N' ? 'T' : 'N';
700   } else {
701     res = get(i, j);
702   }
703   return res;
704 }
705 
706 // y <- alpha * op(this) * x + beta * y or y <- alpha * x * op(this) + beta * y.
707 template<typename T>
gemv(char matTrans,T alpha,const ScalarArray<T> * x,T beta,ScalarArray<T> * y,Side side) const708 void HMatrix<T>::gemv(char matTrans, T alpha, const ScalarArray<T>* x, T beta, ScalarArray<T>* y, Side side) const {
709   if (rows()->size() == 0 || cols()->size() == 0) return;
710   // The dimensions of the H-matrix and the 2 ScalarArrays must match exactly
711   if(side == Side::LEFT) {
712     // Y + H * X
713     assert(x->cols == y->cols);
714     assert((matTrans != 'N' ? cols()->size() : rows()->size()) == y->rows);
715     assert((matTrans != 'N' ? rows()->size() : cols()->size()) == x->rows);
716   } else {
717     // Y + X * H
718     assert(x->rows == y->rows);
719     assert((matTrans != 'N' ? cols()->size() : rows()->size()) == x->cols);
720     assert((matTrans != 'N' ? rows()->size() : cols()->size()) == y->cols);
721   }
722   if (beta != Constants<T>::pone) {
723     y->scale(beta);
724   }
725 
726   if (!this->isLeaf()) {
727     for (int i = 0, iend = (matTrans=='N' ? nrChildRow() : nrChildCol()); i < iend; i++)
728       for (int j = 0, jend = (matTrans=='N' ? nrChildCol() : nrChildRow()); j < jend; j++) {
729         char trans = matTrans;
730         // trans(child) = the child (i,j) of matTrans(this)
731         const HMatrix<T>* child = getChildForGEMM(trans, i, j);
732         if (child) {
733 
734           // I get the rows and cols info of 'child'
735           int colsOffset = child->cols()->offset() - cols()->offset();
736           int colsSize   = child->cols()->size();
737           int rowsOffset = child->rows()->offset() - rows()->offset();
738           int rowsSize   = child->rows()->size();
739 
740           // swap if needed to get the info for trans(child)
741           if (trans != 'N') {
742             std::swap(colsOffset, rowsOffset);
743             std::swap(colsSize,   rowsSize);
744           }
745 
746           if (side == Side::LEFT) {
747             // get the rows subset of X aligned with 'trans(child)' cols and Y aligned with 'trans(child)' rows
748             const ScalarArray<T> subX(*x, colsOffset, colsSize, 0, x->cols);
749             ScalarArray<T> subY(*y, rowsOffset, rowsSize, 0, y->cols);
750             child->gemv(trans, alpha, &subX, Constants<T>::pone, &subY, side);
751           } else {
752             // get the columns subset of X aligned with 'trans(child)' rows and Y aligned with 'trans(child)' columns
753             const ScalarArray<T> subX(*x, 0, x->rows, rowsOffset, rowsSize);
754             ScalarArray<T> subY(*y, 0, y->rows, colsOffset, colsSize);
755             child->gemv(trans, alpha, &subX, Constants<T>::pone, &subY, side);
756           }
757         }
758         else continue;
759       }
760 
761   } else {
762     // We are on a leaf of the matrix 'this'
763     if (isFullMatrix()) {
764       if (side == Side::LEFT) {
765         y->gemm(matTrans, 'N', alpha, &full()->data, x, Constants<T>::pone);
766       } else {
767         y->gemm('N', matTrans, alpha, x, &full()->data, Constants<T>::pone);
768       }
769     } else if(!isNull()){
770       rk()->gemv(matTrans, alpha, x, Constants<T>::pone, y, side);
771     }
772   }
773 }
774 
775 template<typename T>
gemv(char matTrans,T alpha,const FullMatrix<T> * x,T beta,FullMatrix<T> * y,Side side) const776 void HMatrix<T>::gemv(char matTrans, T alpha, const FullMatrix<T>* x, T beta, FullMatrix<T>* y, Side side) const {
777   gemv(matTrans, alpha, &x->data, beta, &y->data, side);
778 }
779 
780 /**
781  * @brief List all Rk matrice in the m matrice.
782  * @return true if the matrix contains only rk matrices, fall if it contains
783  * both rk and full matrices
784  */
listAllRk(const HMatrix<T> * m,vector<const RkMatrix<T> * > & result)785 template<typename T> bool listAllRk(const HMatrix<T> * m, vector<const RkMatrix<T>*> & result) {
786     if(m == NULL) {
787         // do nothing
788     } else if(m->isRkMatrix())
789         result.push_back(m->rk());
790     else if(m->isLeaf())
791         return false;
792     else {
793         for(int i = 0; i < m->nrChild(); i++) {
794             if(m->getChild(i) && !listAllRk(m->getChild(i), result))
795                 return false;
796         }
797     }
798     return true;
799 }
800 
801 /**
802  * @brief generic AXPY implementation that dispatch to others or recurse
803  */
axpy(T alpha,const HMatrix<T> * x)804 template <typename T> void HMatrix<T>::axpy(T alpha, const HMatrix<T> *x) {
805   if (x->isLeaf()) {
806     if (x->isNull()) {
807       // nothing to do
808     } else if (x->isFullMatrix())
809       axpy(alpha, x->full());
810     else if (x->isRkMatrix())
811       axpy(alpha, x->rk());
812   } else {
813     HMAT_ASSERT(*rows() == *x->rows());
814     HMAT_ASSERT(*cols() == *x->cols());
815     if (this->isLeaf()) {
816       if (isRkMatrix()) {
817         if (!rk())
818           rk(new RkMatrix<T>(NULL, rows(), NULL, cols()));
819         vector<const RkMatrix<T> *> rkLeaves;
820         if (listAllRk(x, rkLeaves)) {
821           vector<T> alphas(rkLeaves.size(), alpha);
822           rk()->formattedAddParts(lowRankEpsilon(), &alphas[0], &rkLeaves[0], rkLeaves.size());
823         } else {
824           // x has contains both full and Rk matrices, this is not
825           // supported yet.
826           HMAT_ASSERT(false);
827         }
828         rank_ = rk()->rank();
829       } else {
830         if (full() == NULL)
831           full(new FullMatrix<T>(rows(), cols()));
832         FullMatrix<T> xFull(x->rows(), x->cols());
833         x->evalPart(&xFull, x->rows(), x->cols());
834         full()->axpy(alpha, &xFull);
835       }
836     } else
837       for (int i = 0; i < this->nrChild(); i++) {
838         HMatrix<T> *child = this->getChild(i);
839         const HMatrix<T> *bChild = x->isLeaf() ? x : x->getChild(i);
840         if (bChild != NULL) {
841           HMAT_ASSERT(child != NULL); // This may happen but this is not supported yet
842           child->axpy(alpha, bChild);
843         }
844       }
845   }
846 }
847 
848 /** @brief AXPY between 'this' an H matrix and a subset of B with B a RkMatrix */
849 template<typename T>
axpy(T alpha,const RkMatrix<T> * b)850 void HMatrix<T>::axpy(T alpha, const RkMatrix<T>* b) {
851   DECLARE_CONTEXT;
852   // this += alpha * b
853   assert(b);
854   assert(b->rows->intersects(*rows()));
855   assert(b->cols->intersects(*cols()));
856 
857   if (b->rank() == 0 || rows()->size() == 0 || cols()->size() == 0) {
858     return;
859   }
860 
861   // If 'this' is not a leaf, we recurse with the same 'b'
862   if (!this->isLeaf()) {
863     for (int i = 0; i < this->nrChild(); i++) {
864       HMatrix<T> * c = this->getChild(i);
865       if (c) {
866         if(b->rank() < std::min(c->rows()->size(), c->cols()->size()) && b->rank() > 10) {
867           RkMatrix<T> * bc = b->truncatedSubset(c->rows(), c->cols(), c->lowRankEpsilon());
868           c->axpy(alpha, bc);
869           delete bc;
870         }
871         else c->axpy(alpha, b);
872       }
873     }
874   } else {
875     // To add-up a leaf to a RkMatrix, resizing may be necessary.
876     bool needResizing = b->rows->isStrictSuperSet(*rows())
877       || b->cols->isStrictSuperSet(*cols());
878     const RkMatrix<T>* newRk = b;
879     if (needResizing) {
880       newRk = b->subset(rows(), cols());
881     }
882     if (isRkMatrix()) {
883       if(!rk())
884           rk(new RkMatrix<T>(NULL, rows(), NULL, cols()));
885       rk()->axpy(lowRankEpsilon(), alpha, newRk);
886       rank_ = rk()->rank();
887     } else {
888       // In this case, the matrix has small size
889       // then evaluating the Rk-matrix is cheaper
890       FullMatrix<T>* rkMat = newRk->eval();
891       if(isFullMatrix()) {
892         full()->axpy(alpha, rkMat);
893         delete rkMat;
894       } else {
895         rkMat->scale(alpha);
896         full(rkMat);
897       }
898     }
899     if (needResizing) {
900       delete newRk;
901     }
902   }
903 }
904 
905 /** @brief AXPY between 'this' an H matrix and a subset of B with B a FullMatrix */
906 template<typename T>
axpy(T alpha,const FullMatrix<T> * b)907 void HMatrix<T>::axpy(T alpha, const FullMatrix<T>* b) {
908   DECLARE_CONTEXT;
909   // this += alpha * b
910   assert(b->rows_->isSuperSet(*this->rows()) && b->cols_->isSuperSet(*this->cols()));
911 
912   // If 'this' is not a leaf, we recurse with the same 'b'
913   if (!this->isLeaf()) {
914     for (int i = 0; i < this->nrChild(); i++) {
915       HMatrix<T>* child = this->getChild(i);
916       if (child)
917         child->axpy(alpha, b);
918     }
919   } else {
920     const FullMatrix<T>* subMat = b->subset(rows(), cols());
921     if (isRkMatrix()) {
922       if(!rk())
923         rk(new RkMatrix<T>(NULL, rows(), NULL, cols()));
924       rk()->axpy(lowRankEpsilon(), alpha, subMat);
925       rank_ = rk()->rank();
926     } else if(isFullMatrix()){
927        full()->axpy(alpha, subMat);
928     } else {
929        assert(!isAssembled() || full() == NULL);
930        full(subMat->copy());
931        if(alpha != Constants<T>::pone)
932          full()->scale(alpha);
933     }
934     delete subMat;
935   }
936 }
937 
938 template<typename T>
addIdentity(T alpha)939 void HMatrix<T>::addIdentity(T alpha)
940 {
941   if (this->isLeaf()) {
942     if (isNull()) {
943       HMAT_ASSERT(!this->isRkMatrix());
944       full(new FullMatrix<T>(rows(), cols()));
945     }
946     if (isFullMatrix()) {
947       FullMatrix<T> * b = full();
948       assert(b->rows() == b->cols());
949       for (int i = 0; i < b->rows(); i++) {
950           b->get(i, i) += alpha;
951       }
952     } else {
953       HMAT_ASSERT(false);
954     }
955   } else {
956     for (int i = 0; i < nrChildRow(); i++)
957       if(get(i, i) != nullptr)
958         get(i,i)->addIdentity(alpha);
959   }
960 }
961 
962 template<typename T>
addRand(double epsilon)963 void HMatrix<T>::addRand(double epsilon)
964 {
965   if (this->isLeaf()) {
966     if (isFullMatrix()) {
967       full()->addRand(epsilon);
968     } else {
969       rk()->addRand(epsilon);
970     }
971   } else {
972     for (int i = 0; i < nrChildRow(); i++) {
973       for(int j = 0; j < nrChildCol(); j++) {
974 	if(get(i,j)) {
975           get(i,j)->addRand(epsilon);
976         }
977       }
978     }
979   }
980 }
981 
subset(const IndexSet * rows,const IndexSet * cols) const982 template<typename T> HMatrix<T> * HMatrix<T>::subset(
983     const IndexSet * rows, const IndexSet * cols) const
984 {
985     if((this->rows() == rows && this->cols() == cols) ||
986        (*(this->rows()) == *rows && *(this->cols()) == *cols) ||
987        (!rows->isSubset(*(this->rows())) || !cols->isSubset(*(this->cols())))) // TODO cette ligne me parait louche... si rows et cols sont pas bons, on renvoie 'this' sans meme se plaindre ???
988         return const_cast<HMatrix<T>*>(this);
989 
990     // this could be implemented but if you need it you more
991     // likely have something to fix at a higher level.
992     assert(!this->isNull());
993 
994     if(this->isLeaf()) {
995         HMatrix<T> * tmpMatrix = new HMatrix<T>(this->localSettings.global);
996         tmpMatrix->temporary_=true;
997         tmpMatrix->localSettings.epsilon_ = localSettings.epsilon_;
998         ClusterTree * r = rows_->slice(rows->offset(), rows->size());
999         ClusterTree * c = cols_->slice(cols->offset(), cols->size());
1000 
1001         // ensure the cluster tree are properly freed
1002         r->father = r;
1003         c->father = c;
1004 
1005         tmpMatrix->rows_ = r;
1006         tmpMatrix->cols_ = c;
1007         tmpMatrix->ownClusterTrees(true, true);
1008         if(this->isRkMatrix()) {
1009           tmpMatrix->rk(const_cast<RkMatrix<T>*>(rk()->subset(tmpMatrix->rows(), tmpMatrix->cols())));
1010         } else {
1011           tmpMatrix->full(const_cast<FullMatrix<T>*>(full()->subset(tmpMatrix->rows(), tmpMatrix->cols())));
1012         }
1013         return tmpMatrix;
1014     } else {
1015         // 'This' is not a leaf
1016         //TODO not yet implemented but should not happen
1017         HMAT_ASSERT(false);
1018     }
1019 }
1020 
1021 /**
1022  * @brief Ensure that matrices have compatible cluster trees.
1023  * @param row_a If true check the number of row of A is compatible else check columns
1024  * @param row_b If true check the number of row of B is compatible else check columns
1025  * @param in_a The input A matrix whose dimension must be checked
1026  * @param in_b The input B matrix whose dimension must be checked
1027  * @param out_a A subset of the A matrix which have compatible dimension with out_b.
1028  *  out_a is a view on in_a, no data are copied. It can possibly return in_a if matrices
1029  *  are already compatibles.
1030  * @param out_b A subset of the B matrix which have compatible dimension with out_a.
1031  *  out_b is a view on in_b, no data are copied. It can possibly return in_b if matrices
1032  *  are already compatibles.
1033  */
1034 template<typename T> void
makeCompatible(bool row_a,bool row_b,const HMatrix<T> * in_a,const HMatrix<T> * in_b,HMatrix<T> * & out_a,HMatrix<T> * & out_b)1035 makeCompatible(bool row_a, bool row_b,
1036                const HMatrix<T> * in_a, const HMatrix<T> * in_b,
1037                HMatrix<T> * & out_a, HMatrix<T> * & out_b) {
1038 
1039     // suppose that A is bigger than B: in that case A will change, not B
1040     const IndexSet * cdb = row_b ? in_b->rows() : in_b->cols();
1041     if(row_a) // restrict the rows of in_a to cdb
1042         out_a = in_a->subset(cdb, in_a->cols());
1043     else // or the cols
1044         out_a = in_a->subset(in_a->rows(), cdb);
1045 
1046     // if A has changed, B won't change so we bypass this second step
1047     if(out_a == in_a) {
1048         // suppose than B is bigger than A: B will change, not A
1049         const IndexSet * cda = row_a ? in_a->rows() : in_a->cols();
1050         if(row_b)
1051             out_b = in_b->subset(cda, in_b->cols());
1052         else
1053             out_b = in_b->subset(in_b->rows(), cda);
1054     }
1055     else
1056         out_b = const_cast<HMatrix<T> *>(in_b);
1057 }
1058 
1059 /**
1060  * @brief A GEMM implementation which do not require matrices have compatible
1061  * cluster tree.
1062  *
1063  *  We compute the product alpha.f(a).f(b)+c -> c (with c=this)
1064  *  f(a)=transpose(a) if transA='T', f(a)=a if transA='N' (idem for b)
1065  */
uncompatibleGemm(char transA,char transB,T alpha,const HMatrix<T> * a,const HMatrix<T> * b)1066 template<typename T> void HMatrix<T>::uncompatibleGemm(char transA, char transB, T alpha,
1067                                                   const HMatrix<T>* a, const HMatrix<T>* b) {
1068     // Computing a(m,0) * b(0,n) here may give wrong results because of format conversions, exit early
1069     if(isVoid() || a->isVoid())
1070         return;
1071     HMatrix<T> * va = NULL;
1072     HMatrix<T> * vb = NULL;
1073     HMatrix<T> * vc = NULL;;
1074     HMatrix<T> * vva = NULL;
1075     HMatrix<T> * vvb = NULL;
1076     HMatrix<T> * vvc = NULL;
1077 
1078     // Create va & vb = the subsets of a & b that match each other for doing the product f(a).f(b)
1079     // We modify the columns of f(a) and the rows of f(b)
1080     makeCompatible<T>(transA != 'N', transB == 'N', a, b, va, vb);
1081 
1082     if(this->isLeaf() && !this->isRkMatrix() && this->full() == NULL) {
1083 	  // C (this) is a null full block. We cannot get the subset of it and we
1084 	  // don't know yet if we need to allocate it
1085 	  fullHHGemm(this, transA, transB, alpha, va, vb);
1086       if(va != a)
1087         delete va;
1088       if(vb != b)
1089         delete vb;
1090       return;
1091     } else {
1092       // Create vva & vc = the subsets of va & c (=this) that match each other for doing the sum c+f(a).f(b)
1093       // We modify the rows of f(a) and the rows of c
1094       makeCompatible<T>(transA == 'N', true, va, this, vva, vc);
1095 
1096       // Create vvb & vvc = the subsets of vb & vc that match each other for doing the sum c+f(a).f(b)
1097       // We modify the columns of f(b) and the columns of c
1098       makeCompatible<T>(transB != 'N', false, vb, vc, vvb, vvc);
1099     }
1100 
1101     // Delete the intermediate matrices, except if subset() in makecompatible() has returned the original matrix
1102     if(va != vva && va != a)
1103         delete va;
1104     if(vb != vvb && vb != b)
1105         delete vb;
1106     if(vc != vvc && vc != this)
1107         delete vc;
1108 
1109     // writing on a subset of an RkMatrix is not possible without
1110     // modifying the whole matrix
1111     assert(!isRkMatrix() || vvc == this);
1112     // Do the product on the matrices that are now compatible
1113     vvc->leafGemm(transA, transB, alpha, vva, vvb);
1114 
1115     // Delete the temporary matrices
1116     if(vva != a)
1117         delete vva;
1118     if(vvb != b)
1119         delete vvb;
1120     if(vvc != this)
1121         delete vvc;
1122 }
1123 
/**
 * @brief Precompute which block lines of op(a) and op(b) have intersecting index sets.
 *
 * Let nA be the number of block rows of op(a) if axisA == Axis::ROW, or its number
 * of block columns otherwise; nB is defined the same way for op(b) and axisB.
 * @return a new[]-allocated array of nA * nB flags, which the caller must release
 * with delete[]. Entry [idxA * nB + idxB] is non-zero when the row (ROW) or column
 * (COL) index set of block line idxA of op(a) intersects the selected index set of
 * block line idxB of op(b). Lines made only of null children keep the value 0.
 * A leaf operand is used as its own representative child.
 */
template<typename T>
unsigned char * compatibilityGridForGEMM(const HMatrix<T>* a, Axis axisA, char transA, const HMatrix<T>* b, Axis axisB, char transB) {
    // Let us first consider C = A^T * B where A, B and C are top-level matrices:
    //  [ C11 | C12 ]   [ A11^T | A21^T ]   [ B11 | B12 ]
    //  [ ----+---- ] = [ ------+------ ] * [ ----+---- ]
    //  [ C21 | C22 ]   [ A12^T | A22^T ]   [ B21 | B22 ]
    // This multiplication is possible only if columns of A^T and rows of B are the same,
    // rows of A^T and rows of C are the same, and columns of B and columns of C are the same.
    // Matrices are built from the same cluster trees, so we know that blocks are split at the
    // same place, and this function will return:
    //    compatibilityGridForGEMM(A, COL, 'T', B, ROW, 'N') = {1, 0, 0, 1}
    //    compatibilityGridForGEMM(A, ROW, 'T', C, ROW, 'N') = {1, 0, 0, 1}
    //    compatibilityGridForGEMM(B, COL, 'N', C, ROW, 'N') = {1, 0, 0, 1}
    // But blocks could be split in a single direction instead of 2, or could not
    // be split; if A had been split only by rows, we would have
    //  [ C11 | C12 ]                       [ B11 | B12 ]
    //  [ ----+---- ] = [ A11^T | A21^T ] * [ ----+---- ]
    //  [ C21 | C22 ]                       [ B21 | B22 ]
    //    compatibilityGridForGEMM(A, ROW, 'T', C, ROW, 'N') = {1, 1}
    //
    // Situation is much more complicated when considering inner nodes; for instance let us
    // have a look at A11^T * B11 with A11 and B11 being defined just above. Rows of A11^T
    // are equal to rows(C11)+rows(C21), and thus (considering that A11 and C11 are split
    // into 4 nodes)
    //    compatibilityGridForGEMM(A11, ROW, 'T', C11, ROW, 'N') = {1, 1, 0, 0}
    //
    // This function is generic, it works for all cases as long as matrices are built from
    // the same cluster trees.

    // Numbers of block rows/columns of op(a) and op(b)
    int row_a = transA == 'N' ? a->nrChildRow() : a->nrChildCol();
    int col_a = transA == 'N' ? a->nrChildCol() : a->nrChildRow();
    int row_b = transB == 'N' ? b->nrChildRow() : b->nrChildCol();
    int col_b = transB == 'N' ? b->nrChildCol() : b->nrChildRow();
    size_t nr_blocks = (axisA == Axis::ROW ? row_a : col_a) * (axisB == Axis::ROW ? row_b : col_b);
    unsigned char * result = new unsigned char[nr_blocks];
    memset(result, 0, nr_blocks);

    if (axisA == Axis::ROW) {
        for (int iA = 0; iA < row_a; iA++) {
            // All children on a row have the same row cluster tree, get it
            // from the first non null child.  We also consider the case where
            // 'a' is a leaf.
            // getChildForGEMM may update tA, hence it is reset before each call
            const HMatrix<T> *childA = a->isLeaf() ? a : nullptr;
            char tA = transA;
            for (int jA = 0; !childA && jA < col_a; jA++) {
              tA = transA;
              childA = a->getChildForGEMM(tA, iA, jA);
            }
            // This row contains only null children, skip it
            if(!childA)
                continue;
            if (axisB == Axis::ROW) {
                // Compare with the row index set of each block row of op(b),
                // taken from its first non null child
                for (int iB = 0; iB < row_b; iB++) {
                    for (int jB = 0; jB < col_b; jB++) {
                        char tB = transB;
                        const HMatrix<T> *childB = b->isLeaf() ? b : b->getChildForGEMM(tB, iB, jB);
                        if(childB) {
                            result[iA * row_b + iB] = (tA == 'N' ? childA->rows() : childA->cols())->intersects(*(tB == 'N' ? childB->rows() : childB->cols()));
                            break;
                        }
                    }
                }
            } else {
                // Compare with the column index set of each block column of op(b)
                for (int jB = 0; jB < col_b; jB++) {
                    for (int iB = 0; iB < row_b; iB++) {
                        char tB = transB;
                        const HMatrix<T> *childB = b->isLeaf() ? b : b->getChildForGEMM(tB, iB, jB);
                        if(childB) {
                            result[iA * col_b + jB] = (tA == 'N' ? childA->rows() : childA->cols())->intersects(*(tB == 'N' ? childB->cols() : childB->rows()));
                            break;
                        }
                    }
                }
            }
        }
    } else {
        for (int jA = 0; jA < col_a; jA++) {
            // Representative child of block column jA of op(a) (or 'a' itself if it is a leaf)
            const HMatrix<T> *childA = a->isLeaf() ? a : nullptr;
            char tA = transA;
            for (int iA = 0; !childA && iA < row_a; iA++) {
              tA = transA;
              childA = a->getChildForGEMM(tA, iA, jA);
            }
            // This column contains only null children, skip it
            if(!childA)
                continue;
            if (axisB == Axis::ROW) {
                for (int iB = 0; iB < row_b; iB++) {
                    for (int jB = 0; jB < col_b; jB++) {
                        char tB = transB;
                        const HMatrix<T> *childB = b->isLeaf() ? b : b->getChildForGEMM(tB, iB, jB);
                        if(childB) {
                            result[jA * row_b + iB] = (tA == 'N' ? childA->cols() : childA->rows())->intersects(*(tB == 'N' ? childB->rows() : childB->cols()));
                            break;
                        }
                    }
                }
            } else {
                for (int jB = 0; jB < col_b; jB++) {
                    for (int iB = 0; iB < row_b; iB++) {
                        char tB = transB;
                        const HMatrix<T> *childB = b->isLeaf() ? b : b->getChildForGEMM(tB, iB, jB);
                        if(childB) {
                            result[jA * col_b + jB] = (tA == 'N' ? childA->cols() : childA->rows())->intersects(*(tB == 'N' ? childB->cols() : childB->rows()));
                            break;
                        }
                    }
                }
            }
        }
    }
    return result;
}
1237 
1238 template<typename T> void
recursiveGemm(char transA,char transB,T alpha,const HMatrix<T> * a,const HMatrix<T> * b)1239 HMatrix<T>::recursiveGemm(char transA, char transB, T alpha, const HMatrix<T>* a, const HMatrix<T>*b) {
1240     // Computing a(m,0) * b(0,n) here may give wrong results because of format conversions, exit early
1241     if(isVoid() || a->isVoid())
1242         return;
1243 
1244     // None of the matrices is a leaf
1245     if (!this->isLeaf() && !a->isLeaf() && !b->isLeaf()) {
1246         int row_a = transA == 'N' ? a->nrChildRow() : a->nrChildCol();
1247         int col_a = transA == 'N' ? a->nrChildCol() : a->nrChildRow();
1248         int row_b = transB == 'N' ? b->nrChildRow() : b->nrChildCol();
1249         int col_b = transB == 'N' ? b->nrChildCol() : b->nrChildRow();
1250         int row_c = nrChildRow();
1251         int col_c = nrChildCol();
1252 
1253         // There are 6 nested loops, this may be an issue if there are more
1254         // than 2 children in each direction; precompute compatibility between
1255         // blocks to improve performance:
1256         //   + columns of a and rows of b
1257         unsigned char * is_compatible_a_b = compatibilityGridForGEMM(a, Axis::COL, transA, b, Axis::ROW, transB);
1258         //   + rows of a and rows of c
1259         unsigned char * is_compatible_a_c = compatibilityGridForGEMM(a, Axis::ROW, transA, this, Axis::ROW, 'N');
1260         //   + columns of b and columns of c
1261         unsigned char * is_compatible_b_c = compatibilityGridForGEMM(b, Axis::COL, transB, this, Axis::COL, 'N');
1262         //  With these arrays, we can exit early from loops on iA, jB and l
1263         //  when blocks are not compatible, and thus there are only 3 real
1264         //  loops (on i, j, k) and performance penalty should be negligible.
1265         for (int i = 0; i < row_c; i++) {
1266             for (int j = 0; j < col_c; j++) {
1267                 HMatrix<T>* child = get(i, j);
1268                 if (!child) { // symmetric/triangular case or empty block coming from symbolic factorisation of sparse matrices
1269                     continue;
1270                 }
1271 
1272                 for (int iA = 0; iA < row_a; iA++) {
1273                   if (!is_compatible_a_c[iA * row_c + i])
1274                     continue;
1275                   for (int jB = 0; jB < col_b; jB++) {
1276                     if (!is_compatible_b_c[jB * col_c + j])
1277                       continue;
1278                     for (int k = 0; k < col_a; k++) {
1279                       char tA = transA;
1280                       const HMatrix<T> * childA = a->getChildForGEMM(tA, iA, k);
1281                       if(!childA)
1282                         continue;
1283                       for (int l = 0; l < row_b; l++) {
1284                         if (!is_compatible_a_b[k * row_b + l])
1285                           continue;
1286                         char tB = transB;
1287                         const HMatrix<T> * childB = b->getChildForGEMM(tB, l, jB);
1288                         if(childB)
1289                           child->gemm(tA, tB, alpha, childA, childB, Constants<T>::pone);
1290                       }
1291                     }
1292                   }
1293                 }
1294             }
1295         }
1296         delete [] is_compatible_a_b;
1297         delete [] is_compatible_a_c;
1298         delete [] is_compatible_b_c;
1299     } // if (!this->isLeaf() && !a->isLeaf() && !b->isLeaf())
1300     else
1301         uncompatibleGemm(transA, transB, alpha, a, b);
1302 }
1303 
1304 /**
1305  * @brief product of 2 H-matrix to a full block.
1306  *
1307  * Full blocks may be null and it's that case it's not possible to take
1308  * a subset from them. We don't want to allocate them before the recursion
1309  * because the recursion may show that allocation is not needed. This function
1310  * allow to recurse and bypass the uncompatibleGemm method which would create
1311  * invalid subset.
1312  */
fullHHGemm(HMatrix<T> * c,char transA,char transB,T alpha,const HMatrix<T> * a,const HMatrix<T> * b)1313 template<typename T> void fullHHGemm(HMatrix<T> *c, char transA, char transB, T alpha, const HMatrix<T>* a, const HMatrix<T>*b) {
1314   assert(c->isLeaf());
1315   assert(!c->isRkMatrix());
1316   if(!a->isLeaf() && !b->isLeaf()) {
1317     for (int i = 0; i < (transA=='N' ? a->nrChildRow() : a->nrChildCol()) ; i++) {
1318       for (int j = 0; j < (transB=='N' ? b->nrChildCol() : b->nrChildRow()) ; j++) {
1319         const HMatrix<T> *childA, *childB;
1320         for (int k = 0; k < (transA=='N' ? a->nrChildCol() : a->nrChildRow()) ; k++) {
1321           char tA = transA;
1322           char tB = transB;
1323           childA = a->getChildForGEMM(tA, i, k);
1324           childB = b->getChildForGEMM(tB, k, j);
1325           if(childA && childB)
1326             fullHHGemm(c, tA, tB, alpha, childA, childB);
1327         }
1328       }
1329     }
1330   } else if(!a->isRecursivelyNull() && !b->isRecursivelyNull()) {
1331     if(c->full() == NULL)
1332       c->full(new FullMatrix<T>(c->rows(), c->cols()));
1333     c->gemm(transA, transB, alpha, a, b, Constants<T>::pone);
1334   }
1335 }
1336 
/**
 * @brief Product this += alpha * op(a) * op(b) when at least one of this, a, b is a leaf.
 *
 * The three matrices must already have compatible index sets (see
 * uncompatibleGemm). Depending on the leaf types:
 *  - 'this' subdivided: the product of a and b is formed as an Rk or full
 *    matrix and added with axpy();
 *  - 'this' is an Rk leaf: the product is accumulated by RkMatrix::gemmRk;
 *  - 'this' is a full leaf: H * H products, or any product into a full leaf
 *    without data, go through fullHHGemm(); other products are formed
 *    densely and added.
 */
template<typename T> void
HMatrix<T>::leafGemm(char transA, char transB, T alpha, const HMatrix<T>* a, const HMatrix<T>*b) {
    assert((transA == 'N' ? *a->cols() : *a->rows()) == ( transB == 'N' ? *b->rows() : *b->cols())); // for the product A*B
    assert((transA == 'N' ? *a->rows() : *a->cols()) == *this->rows()); // compatibility of A*B + this : Rows
    assert((transB == 'N' ? *b->cols() : *b->rows()) == *this->cols()); // compatibility of A*B + this : columns

    // One of the matrices is a leaf
    assert(this->isLeaf() || a->isLeaf() || b->isLeaf());

    // the resulting matrix is not a leaf.
    if (!this->isLeaf()) {
        // If the resulting matrix is subdivided then at least one of the matrices of the product is a leaf.
        // One matrix is a RkMatrix
        if (a->isRkMatrix() || b->isRkMatrix()) {
            // A null Rk operand makes the product zero
            if ((a->isRkMatrix() && a->isNull())
                    || (b->isRkMatrix() && b->isNull())) {
                return;
            }
            RkMatrix<T>* rkMat = HMatrix<T>::multiplyRkMatrix(lowRankEpsilon(), transA, transB, a, b);
            axpy(alpha, rkMat);
            delete rkMat;
        } else {
            // None of the matrices of the product is a Rk-matrix so one of them is
            // a full matrix so as the result.
            assert(a->isFullMatrix() || b->isFullMatrix());
            // multiplyFullMatrix may return NULL, in which case there is nothing to add
            FullMatrix<T>* fullMat = HMatrix<T>::multiplyFullMatrix(transA, transB, a, b);
            if(fullMat) {
                axpy(alpha, fullMat);
                delete fullMat;
            }
        }
        return;
    }

    if (isRkMatrix()) {
        // The resulting matrix is a RkMatrix leaf.
        // a and b may be leaves or not; the possible cases are
        // (H: subdivided, R: Rk leaf, M: full leaf):
        //  a. R += H * H
        //  b. R += H * R
        //  c. R += R * H
        //  d. R += H * M
        //  e. R += M * H
        //  f. R += M * M

        // Case a gives an Hmatrix which has to be hierarchically converted into a Rkmatrix.
        // All cases are handled by RkMatrix::gemmRk below.
        assert(isRkMatrix());
        assert((transA == 'N' ? *a->cols() : *a->rows()) == (transB == 'N' ? *b->rows() : *b->cols()));
        assert(*rows() == (transA == 'N' ? *a->rows() : *a->cols()));
        assert(*cols() == (transB == 'N' ? *b->cols() : *b->rows()));
        if(rk() == NULL)
            rk(new RkMatrix<T>(NULL, rows(), NULL, cols()));
        rk()->gemmRk(lowRankEpsilon(), transA, transB, alpha, a, b);
        rank_ = rk()->rank();
        return;
    }

    // a, b are H matrices and 'this' is full
    // (or 'this' is a null full leaf, which fullHHGemm allocates only if needed)
    if ( this->isLeaf() && ((!a->isLeaf() && !b->isLeaf()) || isNull()) ) {
      fullHHGemm(this, transA, transB, alpha, a, b);
      return;
    }

    // The resulting matrix is a full matrix
    FullMatrix<T>* fullMat;
    if (a->isRkMatrix() || b->isRkMatrix()) {
        assert(a->isRkMatrix() || b->isRkMatrix());
        // A null Rk operand makes the product zero
        if ((a->isRkMatrix() && a->isNull())
                || (b->isRkMatrix() && b->isNull())) {
            return;
        }
        RkMatrix<T>* rkMat = HMatrix<T>::multiplyRkMatrix(lowRankEpsilon(), transA, transB, a, b);
        fullMat = rkMat->eval();
        delete rkMat;
    } else if(a->isLeaf() && b->isLeaf() && isFullMatrix()){
        // Full * Full accumulated directly into this full leaf
        full()->gemm(transA, transB, alpha, a->full(), b->full(), Constants<T>::pone);
        return;
    } else {
      // if a or b is a leaf, it is Full (since Rk have been treated before)
        fullMat = HMatrix<T>::multiplyFullMatrix(transA, transB, a, b);
    }

    // It's not optimal to consider that the result is a FullMatrix but
    // this is a H*F case and it almost never happens
    if(fullMat) {
      if (isFullMatrix()) {
        full()->axpy(alpha, fullMat);
        delete fullMat;
      } else {
        // 'this' has no full data yet: it takes ownership of the product, then scales it
        full(fullMat);
        fullMat->scale(alpha);
      }
    }
}
1432 
/**
 * @brief GEMM entry point: this = beta * this + alpha * op(a) * op(b).
 *
 * Two fast paths handle an Rk 'this' that shares a panel with an Rk operand:
 * only the other panel is updated, through a gemv on the matching rows.
 * Otherwise 'this' is scaled by beta. The call returns early when a or b is
 * an unassembled or null leaf. In that case an unassembled leaf 'this' gets
 * an empty RkMatrix. In all other cases recursiveGemm() accumulates the
 * product.
 */
template<typename T>
void HMatrix<T>::gemm(char transA, char transB, T alpha, const HMatrix<T>* a, const HMatrix<T>* b, T beta, MainOp) {
  // Computing a(m,0) * b(0,n) here may give wrong results because of format conversions, exit early
  if(isVoid() || a->isVoid())
      return;

  // This and B are Rk matrices with the same panel 'b' -> the gemm is only applied on the panels 'a'
  if(isRkMatrix() && !isNull() && b->isRkMatrix() && !b->isNull() && rk()->b == b->rk()->b) {
    // Ca * CbT = beta * Ca * CbT + alpha * A * Ba * BbT
    // As Cb = Bb we get
    // Ca = beta * Ca + alpha A * Ba with only Ca and Ba scalar arrays
    // We support C and B not compatible (larger) with A so we first slice them
    // NOTE(review): beta is applied only to the rows of Ca selected by cSubset; if C is
    // larger than op(A) and beta != 1, the other rows of C are not scaled — confirm
    // callers only use beta == 1 here when C is larger.
    assert(transB == 'N');
    const IndexSet * r = transA == 'N' ? a->rows() : a->cols();
    const IndexSet * c = transA == 'N' ? a->cols() : a->rows();
    ScalarArray<T> cSubset(rk()->a->rowsSubset( r->offset() -    rows()->offset(), r->size()));
    ScalarArray<T> bSubset(b->rk()->a->rowsSubset( c->offset() - b->rows()->offset(), c->size()));
    a->gemv(transA, alpha, &bSubset, beta, &cSubset);
    return;
  }

  // This and A are Rk matrices with the same panel 'a' -> the gemm is only applied on the panels 'b'
  if(isRkMatrix() && !isNull() && a->isRkMatrix() && !a->isNull() && rk()->a == a->rk()->a) {
    // Ca * CbT = beta * Ca * CbT + alpha * Aa * AbT * B
    // As Ca = Aa we get
    // CbT = beta * CbT + alpha AbT * B with only Cb and Ab scalar arrays
    // we transpose:
    // Cb = beta * Cb + alpha BT * Ab
    // We support C and A not compatible (larger) with B so we first slice them
    // NOTE(review): as above, beta only reaches the rows of Cb selected by cSubset.
    assert(transA == 'N');
    assert(transB != 'C');
    const IndexSet * r = transB == 'N' ? b->rows() : b->cols();
    const IndexSet * c = transB == 'N' ? b->cols() : b->rows();
    ScalarArray<T> cSubset(rk()->b->rowsSubset( c->offset() -    cols()->offset(), c->size()));
    ScalarArray<T> aSubset(a->rk()->b->rowsSubset( r->offset() - a->cols()->offset(), r->size()));
    b->gemv(transB == 'N' ? 'T' : 'N', alpha, &aSubset, beta, &cSubset);
    return;
  }

  this->scale(beta);

  // A null or unassembled leaf operand makes the product zero
  if((a->isLeaf() && (!a->isAssembled() || a->isNull())) ||
     (b->isLeaf() && (!b->isAssembled() || b->isNull()))) {
      // NOTE(review): an unassembled leaf 'this' receives an empty RkMatrix here even
      // if it is not an Rk block — confirm this is intended.
      if(!isAssembled() && this->isLeaf())
          rk(new RkMatrix<T>(NULL, rows(), NULL, cols()));
      return;
  }

  // beta has already been applied by scale() above, so the recursion
  // accumulates with an implicit beta of 1.
  recursiveGemm(transA, transB, alpha, a, b);
}
1485 
1486 template<typename T>
multiplyFullH(char transM,char transH,const FullMatrix<T> * mat,const HMatrix<T> * h)1487 FullMatrix<T>* multiplyFullH(char transM, char transH,
1488                                          const FullMatrix<T>* mat,
1489                                          const HMatrix<T>* h) {
1490   assert(transH != 'C');
1491   FullMatrix<T>* resultT;
1492   if(transM == 'C') {
1493     // R = M* * H = (H^t * conj(M))^t
1494     FullMatrix<T>* matT = mat->copy();
1495     matT->conjugate();
1496     resultT = multiplyHFull(transH == 'N' ? 'T' : 'N',
1497                             'N', h, matT);
1498     delete matT;
1499   } else {
1500     // R = M * H = (H^t * M^t*)^t
1501     resultT = multiplyHFull(transH == 'N' ? 'T' : 'N',
1502                             transM == 'N' ? 'T' : 'N',
1503                             h, mat);
1504   }
1505   if(resultT != NULL)
1506     resultT->transpose();
1507   return resultT;
1508 }
1509 
isRecursivelyNull() const1510 template<typename T> bool HMatrix<T>::isRecursivelyNull() const {
1511   if(this->isLeaf())
1512     return isNull();
1513   else for(int i = 0; i < this->nrChild(); i++) {
1514     if(this->getChild(i) && !this->getChild(i)->isRecursivelyNull())
1515       return false;
1516   }
1517   return true;
1518 }
1519 
1520 template<typename T>
multiplyHFull(char transH,char transM,const HMatrix<T> * h,const FullMatrix<T> * mat)1521 FullMatrix<T>* multiplyHFull(char transH, char transM,
1522                                          const HMatrix<T>* h,
1523                                          const FullMatrix<T>* mat) {
1524   assert((transH == 'N' ? h->cols()->size() : h->rows()->size())
1525            == (transM == 'N' ? mat->rows() : mat->cols()));
1526   if(h->isRecursivelyNull())
1527     return NULL;
1528   FullMatrix<T>* result =
1529     new FullMatrix<T>((transH == 'N' ? h->rows() : h->cols()),
1530                       (transM == 'N' ? mat->cols_ : mat->rows_));
1531   if (transM == 'N') {
1532     h->gemv(transH, Constants<T>::pone, mat, Constants<T>::zero, result);
1533   } else {
1534     FullMatrix<T>* matT = mat->copyAndTranspose();
1535     if (transM == 'C') {
1536       matT->conjugate();
1537     }
1538     h->gemv(transH, Constants<T>::pone, matT, Constants<T>::zero, result);
1539     delete matT;
1540   }
1541   return result;
1542 }
1543 
1544 template<typename T>
multiplyRkMatrix(double epsilon,char transA,char transB,const HMatrix<T> * a,const HMatrix<T> * b)1545 RkMatrix<T>* HMatrix<T>::multiplyRkMatrix(double epsilon, char transA, char transB, const HMatrix<T>* a, const HMatrix<T>* b){
1546   // We know that one of the matrices is a RkMatrix
1547   assert(a->isRkMatrix() || b->isRkMatrix());
1548   RkMatrix<T> *rk = NULL;
1549   // Matrices range compatibility
1550   if((transA == 'N') && (transB == 'N'))
1551     assert(a->cols()->size() == b->rows()->size());
1552   if((transA != 'N') && (transB == 'N'))
1553     assert(a->rows()->size() == b->rows()->size());
1554   if((transA == 'N') && (transB != 'N'))
1555     assert(a->cols()->size() == b->cols()->size());
1556 
1557   // The cases are:
1558   //  - A Rk, B H
1559   //  - A H,  B Rk
1560   //  - A Rk, B Rk
1561   //  - A Rk, B F
1562   //  - A F,  B Rk
1563   if (a->isRkMatrix() && !b->isLeaf()) {
1564     rk = RkMatrix<T>::multiplyRkH(transA, transB, a->rk(), b);
1565     HMAT_ASSERT(rk);
1566   }
1567   else if (!a->isLeaf() && b->isRkMatrix()) {
1568     rk = RkMatrix<T>::multiplyHRk(transA, transB, a, b->rk());
1569     HMAT_ASSERT(rk);
1570   }
1571   else if (a->isRkMatrix() && b->isRkMatrix()) {
1572     rk = RkMatrix<T>::multiplyRkRk(transA, transB, a->rk(), b->rk(), epsilon);
1573     HMAT_ASSERT(rk);
1574   }
1575   else if (a->isRkMatrix() && b->isFullMatrix()) {
1576     rk = RkMatrix<T>::multiplyRkFull(transA, transB, a->rk(), b->full());
1577     HMAT_ASSERT(rk);
1578   }
1579   else if (a->isFullMatrix() && b->isRkMatrix()) {
1580     rk = RkMatrix<T>::multiplyFullRk(transA, transB, a->full(), b->rk());
1581     HMAT_ASSERT(rk);
1582   } else if(a->isNull() || b->isNull()) {
1583     return new RkMatrix<T>(NULL, transA ? a->cols() : a->rows(),
1584                            NULL, transB ? b->rows() : b->cols());
1585   } else {
1586     // None of the above cases, impossible.
1587     HMAT_ASSERT(false);
1588   }
1589   return rk;
1590 }
1591 
1592 template<typename T>
multiplyFullMatrix(char transA,char transB,const HMatrix<T> * a,const HMatrix<T> * b)1593 FullMatrix<T>* HMatrix<T>::multiplyFullMatrix(char transA, char transB,
1594                                               const HMatrix<T>* a,
1595                                               const HMatrix<T>* b) {
1596   // At least one full matrix, and not RkMatrix.
1597   assert(a->isFullMatrix() || b->isFullMatrix());
1598   assert(!(a->isRkMatrix() || b->isRkMatrix()));
1599   FullMatrix<T> *result = NULL;
1600   // The cases are:
1601   //  - A H, B F
1602   //  - A F, B H
1603   //  - A F, B F
1604   if (!a->isLeaf() && b->isFullMatrix()) {
1605     result = multiplyHFull(transA, transB, a, b->full());
1606   } else if (a->isFullMatrix() && !b->isLeaf()) {
1607     result = multiplyFullH(transA, transB, a->full(), b);
1608   } else if (a->isFullMatrix() && b->isFullMatrix()) {
1609     const IndexSet* aRows = (transA == 'N')? a->rows() : a->cols();
1610     const IndexSet* bCols = (transB == 'N')? b->cols() : b->rows();
1611     result = new FullMatrix<T>(aRows, bCols);
1612     result->gemm(transA, transB, Constants<T>::pone, a->full(), b->full(),
1613                  Constants<T>::zero);
1614   } else if(a->isNull() || b->isNull()) {
1615     return NULL;
1616   } else {
1617     // None of above, impossible
1618     HMAT_ASSERT(false);
1619   }
1620   return result;
1621 }
1622 
1623 template<typename T>
multiplyWithDiag(const HMatrix<T> * d,Side side,bool inverse) const1624 void HMatrix<T>::multiplyWithDiag(const HMatrix<T>* d, Side side, bool inverse) const {
1625   assert(*d->rows() == *d->cols());
1626   assert(side == Side::LEFT  || (*cols() == *d->rows()));
1627   assert(side == Side::RIGHT || (*rows() == *d->cols()));
1628 
1629   if (isVoid()) return;
1630 
1631   // The symmetric matrix must be taken into account: lower or upper
1632   if (!this->isLeaf()) {
1633     if (d->isLeaf()) {
1634       for (int i=0 ; i<std::min(nrChildRow(), nrChildCol()) ; i++)
1635         if(get(i,i))
1636           get(i,i)->multiplyWithDiag(d, side, inverse);
1637       for (int i=0 ; i<nrChildRow() ; i++)
1638         for (int j=0 ; j<nrChildCol() ; j++)
1639           if (i!=j && get(i,j)) {
1640             get(i,j)->multiplyWithDiag(d, side, inverse);
1641           }
1642       return;
1643     }
1644 
1645     // First the diagonal, then the rest...
1646     for (int i=0 ; i<std::min(nrChildRow(), nrChildCol()) ; i++)
1647       if(get(i,i))
1648         get(i,i)->multiplyWithDiag(d->get(i,i), side, inverse);
1649     for (int i=0 ; i<nrChildRow() ; i++)
1650       for (int j=0 ; j<nrChildCol() ; j++)
1651         if (i!=j && get(i,j)) {
1652         int k = side == Side::LEFT ? i : j;
1653         get(i,j)->multiplyWithDiag(d->get(k,k), side, inverse);
1654         // TODO couldn't we handle this case with the previous one, using getChildForGEMM(d,i,i) that returns 'd' itself when 'd' is a leaf ?
1655     }
1656   } else if (isRkMatrix() && !isNull()) {
1657     rk()->multiplyWithDiagOrDiagInv(d, inverse, side);
1658   } else if(isFullMatrix()){
1659     if (d->isFullMatrix()) {
1660       full()->multiplyWithDiagOrDiagInv(d->full()->diagonal, inverse, side);
1661     } else {
1662       Vector<T> diag(d->rows()->size());
1663       d->extractDiagonal(diag.ptr());
1664       full()->multiplyWithDiagOrDiagInv(&diag, inverse, side);
1665     }
1666   } else {
1667     // this is a null matrix (either full of Rk) so nothing to do
1668   }
1669 }
1670 
transposeMeta(bool temporaryOnly)1671 template<typename T> void HMatrix<T>::transposeMeta(bool temporaryOnly) {
1672     if(temporaryOnly && !temporary_)
1673       return;
1674     // called by HMatrix<T>::transpose() and HMatrixHandle<T>::transpose()
1675     // if the matrix is symmetric, inverting it(Upper/Lower)
1676     if (isLower || isUpper) {
1677         isLower = !isLower;
1678         isUpper = !isUpper;
1679     }
1680     // if the matrix is triangular, on inverting it (isTriUpper/isTriLower)
1681     if (isTriLower || isTriUpper) {
1682         isTriLower = !isTriLower;
1683         isTriUpper = !isTriUpper;
1684     }
1685     // Warning: nrChildRow() uses keepSameRows and rows_
1686     bool tmp = keepSameCols; // can't use swap on bitfield so manual swap...
1687     keepSameCols = keepSameRows;
1688     keepSameRows = tmp;
1689     swap(rows_, cols_);
1690     RecursionMatrix<T, HMatrix<T> >::transposeMeta(temporaryOnly);
1691 }
1692 
transposeData()1693 template <typename T> void HMatrix<T>::transposeData() {
1694     if (this->isLeaf()) {
1695         if (isRkMatrix() && rk()) {
1696             rk()->transpose();
1697         } else if (isFullMatrix()) {
1698             full()->transpose();
1699         }
1700     } else {
1701         for (int i = 0; i < this->nrChild(); i++)
1702             if (this->getChild(i))
1703                 this->getChild(i)->transposeData();
1704     }
1705 }
1706 
transpose()1707 template<typename T> void HMatrix<T>::transpose() {
1708     transposeData();
1709     transposeMeta();
1710 }
1711 
conjugate()1712 template<> void HMatrix<S_t>::conjugate() {}
conjugate()1713 template<> void HMatrix<D_t>::conjugate() {}
conjugate()1714 template<typename T> void HMatrix<T>::conjugate() {
1715   std::vector<const HMatrix<T> *> stack;
1716   stack.push_back(this);
1717   while(!stack.empty()) {
1718     const HMatrix<T> * m = stack.back();
1719     stack.pop_back();
1720     if(!m->isLeaf()) {
1721       for(int i = 0; i < m->nrChild(); i++) {
1722         if(m->getChild(i) != NULL)
1723           stack.push_back(m->getChild(i));
1724       }
1725     } else if(m->isNull()) {
1726       // nothing to do
1727     } else if(m->isRkMatrix()) {
1728       m->rk()->conjugate();
1729     } else {
1730       m->full()->conjugate();
1731     }
1732   }
1733 }
1734 
1735 template<typename T>
copyAndTranspose(const HMatrix<T> * o)1736 void HMatrix<T>::copyAndTranspose(const HMatrix<T>* o) {
1737   assert(o);
1738   assert(*this->rows() == *o->cols());
1739   assert(*this->cols() == *o->rows());
1740   assert(this->isLeaf() == o->isLeaf());
1741 
1742   if (this->isLeaf()) {
1743     if (o->isRkMatrix()) {
1744       assert(!isFullMatrix());
1745       if (rk()) {
1746         delete rk();
1747       }
1748       RkMatrix<T>* newRk = o->rk()->copy();
1749       newRk->transpose();
1750       rk(newRk);
1751     } else {
1752       if (isFullMatrix()) {
1753         delete full();
1754       }
1755       const FullMatrix<T>* oF = o->full();
1756       if(oF == NULL) {
1757         full(NULL);
1758       } else {
1759         full(oF->copyAndTranspose());
1760         if (oF->diagonal) {
1761           if (!full()->diagonal) {
1762             full()->diagonal = new Vector<T>(oF->rows());
1763             HMAT_ASSERT(full()->diagonal);
1764           }
1765           oF->diagonal->copy(full()->diagonal);
1766         }
1767       }
1768     }
1769   } else {
1770     for (int i=0 ; i<nrChildRow() ; i++)
1771       for (int j=0 ; j<nrChildCol() ; j++)
1772         if (get(i,j) && o->get(j, i))
1773           get(i, j)->copyAndTranspose(o->get(j, i));
1774   }
1775 }
1776 
1777 template<typename T>
truncate()1778 void HMatrix<T>::truncate() {
1779   if (this->isLeaf()) {
1780     if (this->isRkMatrix()) {
1781       if (rk()) {
1782         rk()->truncate(localSettings.epsilon_);
1783         rank_ = rk()->rank();
1784       }
1785     }
1786   } else {
1787     for (int i = 0; i < this->nrChild(); i++) {
1788       HMatrix<T>* child = this->getChild(i);
1789       if (child) {
1790         child->truncate();
1791       }
1792     }
1793   }
1794 }
1795 
1796 template<typename T>
rows() const1797 const ClusterData* HMatrix<T>::rows() const {
1798   return &(rows_->data);
1799 }
1800 
1801 template<typename T>
cols() const1802 const ClusterData* HMatrix<T>::cols() const {
1803   return &(cols_->data);
1804 }
1805 
copy() const1806 template<typename T> HMatrix<T>* HMatrix<T>::copy() const {
1807   HMatrix<T>* M=Zero(this);
1808   M->copy(this);
1809   return M;
1810 }
1811 
// Copy the data of 'o' into 'this'.
// Both H-matrices are expected to already share the same block structure
// (same clusters, same children); flags and leaf payloads are copied.
template<typename T>
void HMatrix<T>::copy(const HMatrix<T>* o) {
  DECLARE_CONTEXT;

  assert(*rows() == *o->rows());
  assert(*cols() == *o->cols());

  // Symmetry/triangularity flags and the approximate rank follow 'o'.
  isLower = o->isLower;
  isUpper = o->isUpper;
  isTriUpper = o->isTriUpper;
  isTriLower = o->isTriLower;
  approximateRank_ = o->approximateRank_;
  if (this->isLeaf()) {
    assert(o->isLeaf());
    // 'this' is assembled and both sides report isNull(): nothing to copy.
    if (isAssembled() && isNull() && o->isNull()) {
      return;
    }
    // When the matrix was not allocated but only the structure
    if (o->isFullMatrix() && isFullMatrix()) {
      // Copy o's full block into the existing full block of 'this'.
      o->full()->copy(full());
    } else if(o->isFullMatrix()) {
      // 'this' has no full block yet: install a fresh copy of o's.
      assert(!isAssembled() || isNull());
      full(o->full()->copy());
    } else if (o->isRkMatrix() && !rk()) {
      // Create an empty Rk container with o's row/col sets; filled below.
      rk(new RkMatrix<T>(NULL, o->rk()->rows, NULL, o->rk()->cols));
    }
    assert((isRkMatrix() == o->isRkMatrix())
           && (isFullMatrix() == o->isFullMatrix()));
    if (o->isRkMatrix()) {
      // Copy the Rk panels and refresh the cached rank.
      rk()->copy(o->rk());
      rank_ = rk()->rank();
    }
  } else {
    // Inner node: recurse child by child; both trees must have the same
    // children present.
    assert(o->rank_==NONLEAF_BLOCK);
    rank_ = o->rank_;
    for (int i = 0; i < o->nrChild(); i++) {
        if (o->getChild(i)) {
          assert(this->getChild(i));
          this->getChild(i)->copy(o->getChild(i));
        } else {
          assert(!this->getChild(i));
      }
    }
  }
}
1859 
1860 template<typename T>
lowRankEpsilon(double epsilon,bool recursive)1861 void HMatrix<T>::lowRankEpsilon(double epsilon, bool recursive) {
1862   localSettings.epsilon_ = epsilon;
1863   if(recursive && !this->isLeaf()) {
1864     for (int i = 0; i < this->nrChild(); i++) {
1865       HMatrix<T>* child = this->getChild(i);
1866       if (child)
1867         child->lowRankEpsilon(epsilon);
1868     }
1869   }
1870 }
1871 
clear()1872 template<typename T> void HMatrix<T>::clear() {
1873   if(!this->isLeaf()) {
1874     for (int i = 0; i < this->nrChild(); i++) {
1875       HMatrix<T>* child = this->getChild(i);
1876       if (child)
1877         child->clear();
1878     }
1879   } else if(isRkMatrix()) {
1880     if(rk())
1881       delete rk();
1882     rk(NULL);
1883   } else if(isFullMatrix()) {
1884     delete full();
1885     full(NULL);
1886   }
1887 }
1888 
template<typename T>
void HMatrix<T>::inverse() {
  DECLARE_CONTEXT;

  // Symmetric (lower-stored) matrices are rejected up front.
  // NOTE(review): assuming HMAT_ASSERT_MSG aborts/throws on failure, the
  // 'if (isLower)' branch below can never run — confirm whether it is dead code.
  HMAT_ASSERT_MSG(!isLower, "HMatrix::inverse not available for symmetric matrices");

  if (this->isLeaf()) {
    // Leaves must be dense here; they are inverted in place.
    assert(isFullMatrix());
    full()->inverse();
  } else {

    //  Matrix inversion:
    //  The idea to inverse M is to consider the extended matrix obtained by putting Identity next to M :
    //
    //  [ M11 | M12 |  I  |  0  ]
    //  [ ----+-----+-----+---- ]
    //  [ M21 | M22 |  0  |  I  ]
    //
    //  We then apply operations on the rows of this matrix (matrix multiplication of an entire row,
    // linear combination of rows)
    // to transform the 'M' part into identity. Doing so, the identity part will at the end contain M-1.
    // We loop on the columns of M.
    // At the end of loop 'k', the 'k' first columns of 'M' are now identity,
    // and the 'k' first columns of Identity have changed (it's no longer identity, it's not yet M-1).
    // The matrix 'this' stores the first 'k' blocks of the identity part of the extended matrix, and the last n-k blocks of the M part
    // At the end, 'this' contains M-1

    if (isLower) {

      // TM[j]: per-step temporaries for block column/row j (freed at the end of each step k).
      vector<HMatrix<T>*> TM(nrChildCol());
    for (int k=0 ; k<nrChildRow() ; k++){
      // Invert M_kk in place
      get(k,k)->inverse();
      // Update line 'k' = left-multiplied by M_kk-1
        for (int j=0 ; j<nrChildCol() ; j++) {

            // Mkj <- Mkk^-1 Mkj we use a temp matrix X because this type of product is not allowed with gemm (beta=0 erases Mkj before using it !)
          if (j<k) { // under the diag we store TMj=Mkj
            TM[j] = get(k,j)->copy();
            get(k,j)->gemm('N', 'N', Constants<T>::pone, get(k,k), TM[j], Constants<T>::zero);
          } else if (j>k) { // above the diag : Mkj = t Mjk, we store TMj=-Mjk.tMkk-1 = -Mjk.Mkk-1 (Mkk is symmetric)
            TM[j] = Zero(get(j,k));
            TM[j]->gemm('N', 'T', Constants<T>::mone, get(j,k), get(k,k), Constants<T>::zero);
          }
        }
      // Update the rest of matrix M
      for (int i=0 ; i<nrChildRow() ; i++)
          // line 'i' -= Mik x line 'k' (which has just been multiplied by Mkk-1)
        for (int j=0 ; j<nrChildCol() ; j++)
            if (i!=k && j!=k && j<=i) {
              // Mij <- Mij - Mik (Mkk^-1 Mkj) (with Mkk-1.Mkj already stored in Mkj and TMj=Mjk.tMkk-1)
              // case k < j < i: Mkj is not stored, use t{TMj} = -Mkk-1Mkj
              if (k<j)
                get(i,j)->gemm('N', 'T', Constants<T>::pone, get(i,k), TM[j], Constants<T>::pone);
              // case j < k < i: all the blocks involved are stored below the diagonal
              else if (k<i)
            get(i,j)->gemm('N', 'N', Constants<T>::mone, get(i,k), get(k,j), Constants<T>::pone);
              // case j < i < k: Mik is not stored, use TM[i] = Mki
              else
                get(i,j)->gemm('T', 'N', Constants<T>::mone, TM[i], get(k,j), Constants<T>::pone);
            }
      // Update column 'k' = right-multiplied by -M_kk-1
      for (int i=0 ; i<nrChildRow() ; i++)
          if (i>k) {
          // Mik <- - Mik Mkk^-1
            get(i,k)->copy(TM[i]);
  }
        // Free this step's temporaries (TM[k] is never set; delete on NULL is a no-op).
        for (int j=0 ; j<nrChildCol() ; j++) {
          delete TM[j];
          TM[j]=NULL;
        }
      }
    } else {
      // Non-symmetric case: delegate to the generic recursive inversion.
      this->recursiveInverseNosym();
  }
  }
}
1966 
1967 template<typename T>
solveLowerTriangularLeft(HMatrix<T> * b,Factorization algo,Diag diag,Uplo uplo,MainOp) const1968 void HMatrix<T>::solveLowerTriangularLeft(HMatrix<T>* b, Factorization algo, Diag diag, Uplo uplo, MainOp) const {
1969   DECLARE_CONTEXT;
1970   if (isVoid()) return;
1971   // At first, the recursion one (simple case)
1972   if (!this->isLeaf() && !b->isLeaf()) {
1973     this->recursiveSolveLowerTriangularLeft(b, algo, diag, uplo);
1974   } else if(!b->isLeaf()) {
1975     // B isn't a leaf, then 'this' is one
1976     assert(this->isLeaf());
1977     // Evaluate B as a full matrix, solve, and restore in the matrix
1978     // TODO: check if it's not too bad
1979     FullMatrix<T> bFull(b->rows(), b->cols());
1980     b->evalPart(&bFull, b->rows(), b->cols());
1981     this->solveLowerTriangularLeft(&bFull, algo, diag, uplo);
1982     b->clear();
1983     b->axpy(Constants<T>::pone, &bFull);
1984   } else if(b->isNull()) {
1985     // nothing to do
1986   } else {
1987     if (b->isFullMatrix()) {
1988       this->solveLowerTriangularLeft(b->full(), algo, diag, uplo);
1989     } else {
1990       assert(b->isRkMatrix());
1991       HMatrix<T> * bSubset = b->subset(uplo == Uplo::LOWER ? this->cols() : this->rows(), b->cols());
1992       this->solveLowerTriangularLeft(bSubset->rk()->a, algo, diag, uplo);
1993       if(bSubset != b)
1994           delete bSubset;
1995     }
1996   }
1997 }
1998 
1999 template<typename T>
solveLowerTriangularLeft(ScalarArray<T> * b,Factorization algo,Diag diag,Uplo uplo) const2000 void HMatrix<T>::solveLowerTriangularLeft(ScalarArray<T>* b, Factorization algo, Diag diag, Uplo uplo) const {
2001   DECLARE_CONTEXT;
2002   assert(*rows() == *cols());
2003   assert(cols()->size() == b->rows);
2004   if (isVoid()) return;
2005   if (this->isLeaf()) {
2006     assert(this->isFullMatrix());
2007     full()->solveLowerTriangularLeft(b, algo, diag, uplo);
2008   } else {
2009     //  Forward substitution:
2010     //  [ L11 |  0  ]   [ X1 ]   [ b1 ]
2011     //  [ ----+---- ] * [----] = [ -- ]
2012     //  [ L21 | L22 ]   [ X2 ]   [ b2 ]
2013     //
2014     //  L11 * X1 = b1 (by recursive forward substitution)
2015     //  L21 * X1 + L22 * X2 = b2 (forward substitution of L22*X2=b2-L21*X1)
2016     //
2017 
2018     int offset(0);
2019     vector<ScalarArray<T> > sub;
2020     for (int i=0 ; i<nrChildRow() ; i++) {
2021       // Create sub[i] = a ScalarArray (without copy of data) for the rows in front of the i-th matrix block
2022       sub.push_back(ScalarArray<T>(*b, offset, get(i, i)->cols()->size(), 0, b->cols));
2023       offset += get(i, i)->cols()->size();
2024       // Update sub[i] with the contribution of the solutions already computed sub[j] j<i
2025       for (int j=0 ; j<i ; j++) {
2026         const HMatrix<T>* u_ji = (uplo == Uplo::LOWER ? get(i, j) : get(j, i));
2027         if (u_ji)
2028           u_ji->gemv(uplo == Uplo::LOWER ? 'N' : 'T', Constants<T>::mone, &sub[j], Constants<T>::pone, &sub[i]);
2029       }
2030       // Solve the i-th diagonal system
2031       get(i, i)->solveLowerTriangularLeft(&sub[i], algo, diag, uplo);
2032     }
2033   }
2034 }
2035 
2036 template<typename T>
solveLowerTriangularLeft(FullMatrix<T> * b,Factorization algo,Diag diag,Uplo uplo) const2037 void HMatrix<T>::solveLowerTriangularLeft(FullMatrix<T>* b, Factorization algo, Diag diag, Uplo uplo) const {
2038   solveLowerTriangularLeft(&b->data, algo, diag, uplo);
2039 }
2040 
2041 template<typename T>
solveUpperTriangularRight(HMatrix<T> * b,Factorization algo,Diag diag,Uplo uplo) const2042 void HMatrix<T>::solveUpperTriangularRight(HMatrix<T>* b, Factorization algo, Diag diag, Uplo uplo) const {
2043   DECLARE_CONTEXT;
2044   if (rows()->size() == 0 || cols()->size() == 0) return;
2045   // The recursion one (simple case)
2046   if (!this->isLeaf() && !b->isLeaf()) {
2047     this->recursiveSolveUpperTriangularRight(b, algo, diag, uplo);
2048   } else if(!b->isLeaf()) {
2049     // B isn't a leaf, then 'this' is one
2050     assert(this->isLeaf());
2051     assert(isFullMatrix());
2052     // Evaluate B, solve by column and restore all in the matrix
2053     // TODO: check if it's not too bad
2054     FullMatrix<T> bFull(b->rows(), b->cols());
2055     b->evalPart(&bFull, b->rows(), b->cols());
2056     this->solveUpperTriangularRight(&bFull, algo, diag, uplo);
2057     b->clear();
2058     b->axpy(Constants<T>::pone, &bFull);
2059   } else if(b->isNull()) {
2060     // nothing to do
2061   } else {
2062     if (b->isFullMatrix()) {
2063       this->solveUpperTriangularRight(b->full(), algo, diag, uplo);
2064     } else {
2065       assert(b->isRkMatrix());
2066       // Xa Xb^t U = Ba Bb^t
2067       //   - Xa = Ba
2068       //   - Xb^t U = Bb^t
2069       // Xb is stored without being transposed, thus we solve
2070       // U^t Xb = Bb instead
2071       HMatrix<T> * tmp = b->subset(b->rows(), uplo == Uplo::LOWER ? this->cols() : this->rows());
2072       this->solveLowerTriangularLeft(tmp->rk()->b, algo, diag, uplo);
2073       if(tmp != b)
2074           delete tmp;
2075     }
2076   }
2077 }
2078 
2079 template<typename T>
solveUpperTriangularRight(ScalarArray<T> * b,Factorization algo,Diag diag,Uplo uplo) const2080 void HMatrix<T>::solveUpperTriangularRight(ScalarArray<T>* b, Factorization algo, Diag diag, Uplo uplo) const {
2081   DECLARE_CONTEXT;
2082   assert(*rows() == *cols());
2083   assert(rows()->size() == b->cols);
2084   if (isVoid()) return;
2085   if (this->isLeaf()) {
2086     assert(this->isFullMatrix());
2087     full()->solveUpperTriangularRight(b, algo, diag, uplo);
2088   } else {
2089     //  Forward substitution:
2090     //                [ U11 | U12 ]
2091     //  [ X1 | X2 ] * [ ----+---- ] = [ b1 | b2 ]
2092     //                [  0  | U22 ]
2093     //
2094     //  X1 * U11 = b1 (by recursive forward substitution)
2095     //  X1 * U12 + X2 * U22 = b2 (forward substitution of X2*U22=b2-X1*U12)
2096     //
2097 
2098     int offset(0);
2099     vector<ScalarArray<T> > sub;
2100     for (int i=0 ; i<nrChildCol() ; i++) {
2101       // Create sub[i] = a ScalarArray (without copy of data) for the columns in front of the i-th matrix block
2102       sub.push_back(ScalarArray<T>(*b, 0, b->rows, offset, get(i, i)->rows()->size()));
2103       offset += get(i, i)->rows()->size();
2104       // Update sub[i] with the contribution of the solutions already computed sub[j]
2105       for (int j=0 ; j<i ; j++) {
2106         const HMatrix<T>* u_ji = (uplo == Uplo::LOWER ? get(i, j) : get(j, i));
2107         if (u_ji)
2108           u_ji->gemv(uplo == Uplo::LOWER ? 'T' : 'N', Constants<T>::mone, &sub[j], Constants<T>::pone, &sub[i], Side::RIGHT);
2109       }
2110       // Solve the i-th diagonal system
2111       get(i, i)->solveUpperTriangularRight(&sub[i], algo, diag, uplo);
2112     }
2113   }
2114 }
2115 
2116 template<typename T>
solveUpperTriangularRight(FullMatrix<T> * b,Factorization algo,Diag diag,Uplo uplo) const2117 void HMatrix<T>::solveUpperTriangularRight(FullMatrix<T>* b, Factorization algo, Diag diag, Uplo uplo) const {
2118   solveUpperTriangularRight(&b->data, algo, diag, uplo);
2119 }
2120 
2121 /* Resolve U.X=B, solution saved in B, with B Hmat
2122    Only called by luDecomposition
2123  */
2124 template<typename T>
solveUpperTriangularLeft(HMatrix<T> * b,Factorization algo,Diag diag,Uplo uplo,MainOp) const2125 void HMatrix<T>::solveUpperTriangularLeft(HMatrix<T>* b, Factorization algo, Diag diag, Uplo uplo, MainOp) const {
2126   DECLARE_CONTEXT;
2127   if (rows()->size() == 0 || cols()->size() == 0) return;
2128   // At first, the recursion one (simple case)
2129   if (!this->isLeaf() && !b->isLeaf()) {
2130     this->recursiveSolveUpperTriangularLeft(b, algo, diag, uplo);
2131   } else if(!b->isLeaf()) {
2132     // B isn't a leaf, then 'this' is one
2133     assert(this->isLeaf());
2134     // Evaluate B, solve by column, and restore in the matrix
2135     // TODO: check if it's not too bad
2136     FullMatrix<T> bFull(b->rows(), b->cols());
2137     b->evalPart(&bFull, b->rows(), b->cols());
2138     this->solveUpperTriangularLeft(&bFull, algo, diag, uplo);
2139     b->clear();
2140     b->axpy(Constants<T>::pone, &bFull);
2141   } else if(b->isNull()) {
2142     // nothing to do
2143   } else {
2144     if (b->isFullMatrix()) {
2145       this->solveUpperTriangularLeft(b->full(), algo, diag, uplo);
2146     } else {
2147       assert(b->isRkMatrix());
2148       HMatrix * bSubset = b->subset(uplo == Uplo::LOWER ? this->rows() : this->cols(), b->cols());
2149       this->solveUpperTriangularLeft(bSubset->rk()->a, algo, diag, uplo);
2150       if(bSubset != b)
2151           delete bSubset;
2152     }
2153   }
2154 }
2155 
2156 template<typename T>
solveUpperTriangularLeft(ScalarArray<T> * b,Factorization algo,Diag diag,Uplo uplo) const2157 void HMatrix<T>::solveUpperTriangularLeft(ScalarArray<T>* b, Factorization algo, Diag diag, Uplo uplo) const {
2158   DECLARE_CONTEXT;
2159   assert(*rows() == *cols());
2160   assert(rows()->size() == b->rows || uplo == Uplo::UPPER);
2161   assert(cols()->size() == b->rows || uplo == Uplo::LOWER);
2162   if (rows()->size() == 0 || cols()->size() == 0) return;
2163   if (this->isLeaf()) {
2164     full()->solveUpperTriangularLeft(b, algo, diag, uplo);
2165   } else {
2166     //  Backward substitution:
2167     //  [ U11 | U12 ]   [ X1 ]   [ b1 ]
2168     //  [ ----+---- ] * [----] = [ -- ]
2169     //  [  0  | U22 ]   [ X2 ]   [ b2 ]
2170     //
2171     //  U22 * X2 = b12(by recursive backward substitution)
2172     //  U11 * X1 + U12 * X2 = b1 (backward substitution of U11*X1=b1-U12*X2)
2173     //
2174 
2175     int offset(0);
2176     vector<ScalarArray<T> > sub;
2177     for (int i=0 ; i<nrChildRow() ; i++) {
2178       // Create sub[i] = a ScalarArray (without copy of data) for the rows in front of the i-th matrix block
2179       sub.push_back(b->rowsSubset(offset, get(i, i)->cols()->size()));
2180       offset += get(i, i)->cols()->size();
2181     }
2182     for (int i=nrChildRow()-1 ; i>=0 ; i--) {
2183       // Solve the i-th diagonal system
2184       get(i, i)->solveUpperTriangularLeft(&sub[i], algo, diag, uplo);
2185       // Update sub[j] j<i with the contribution of the solutions just computed sub[i]
2186       for (int j=0 ; j<i ; j++) {
2187         const HMatrix<T>* u_ji = (uplo == Uplo::LOWER ? get(i, j) : get(j, i));
2188         if (u_ji)
2189           u_ji->gemv(uplo == Uplo::LOWER ? 'T' : 'N', Constants<T>::mone, &sub[i], Constants<T>::pone, &sub[j]);
2190       }
2191     }
2192   }
2193 }
2194 
2195 template<typename T>
solveUpperTriangularLeft(FullMatrix<T> * b,Factorization algo,Diag diag,Uplo uplo) const2196 void HMatrix<T>::solveUpperTriangularLeft(FullMatrix<T>* b, Factorization algo, Diag diag, Uplo uplo) const {
2197   solveUpperTriangularLeft(&b->data, algo, diag, uplo);
2198 }
2199 
lltDecomposition(hmat_progress_t * progress)2200 template<typename T> void HMatrix<T>::lltDecomposition(hmat_progress_t * progress) {
2201 
2202     assertLower(this);
2203     if (isVoid()) {
2204         // nothing to do
2205     } else if(this->isLeaf()) {
2206         full()->lltDecomposition();
2207         if(progress != NULL) {
2208             progress->current= rows()->offset() + rows()->size();
2209             progress->update(progress);
2210         }
2211     } else {
2212         HMAT_ASSERT(isLower);
2213       this->recursiveLltDecomposition(progress);
2214     }
2215     isTriLower = true;
2216     isLower = false;
2217 }
2218 
2219 template<typename T>
luDecomposition(hmat_progress_t * progress)2220 void HMatrix<T>::luDecomposition(hmat_progress_t * progress) {
2221   DECLARE_CONTEXT;
2222 
2223   if (rows()->size() == 0 || cols()->size() == 0) return;
2224   if (this->isLeaf()) {
2225     assert(isFullMatrix());
2226     full()->luDecomposition();
2227     full()->checkNan();
2228     if(progress != NULL) {
2229       progress->current= rows()->offset() + rows()->size();
2230       progress->update(progress);
2231     }
2232   } else {
2233     this->recursiveLuDecomposition(progress);
2234   }
2235 }
2236 
2237 template<typename T>
mdntProduct(const HMatrix<T> * m,const HMatrix<T> * d,const HMatrix<T> * n)2238 void HMatrix<T>::mdntProduct(const HMatrix<T>* m, const HMatrix<T>* d, const HMatrix<T>* n) {
2239   DECLARE_CONTEXT;
2240 
2241   HMatrix<T>* x = m->copy();
2242   x->multiplyWithDiag(d); // x=M.D
2243   this->gemm('N', 'T', Constants<T>::mone, x, n, Constants<T>::pone); // this -= M.D.tN
2244   delete x;
2245 }
2246 
template<typename T>
/* Symmetric update: this <- this - M * D * M^T.
   - this: symmetric, stored as lower triangular (checked by assertLower).
   - d:    block whose diagonal is D, expected to come from an ldlt factorization.
   - m:    rows match this, columns match d.
   Dispatches on the leaf/non-leaf and Rk/Full nature of this and m. */
void HMatrix<T>::mdmtProduct(const HMatrix<T>* m, const HMatrix<T>* d) {
  DECLARE_CONTEXT;
  if (isVoid() || d->isVoid() || m->isVoid()) return;
  // this <- this - M * D * M^T
  //
  // D is stored separately in full matrix of diagonal leaves (see full_matrix.hpp).
  // this is symmetric and stored as lower triangular.
  // Warning: d must be the result of an ldlt factorization
  assertLower(this);
  assert(*d->rows() == *d->cols());       // D is square
  assert(*this->rows() == *this->cols()); // this is square
  assert(*m->cols() == *d->rows());       // Check if we can form the products M*D and D*M^T
  assert(*this->rows() == *m->rows());

  if(!this->isLeaf()) {
    // this is hierarchical.
    if (!m->isLeaf()) {
      // Both hierarchical: recurse block-wise.
      this->recursiveMdmtProduct(m, d);
    } else if (m->isRkMatrix() && !m->isNull()) {
      // m is a non-null Rk leaf: build (M*D) * M^T as an Rk product and subtract it.
      HMatrix<T>* m_copy = m->copy();

      assert(*m->cols() == *d->rows());
      assert(*m_copy->rk()->cols == *d->rows());
      m_copy->multiplyWithDiag(d); // right multiplication by D
      RkMatrix<T>* rkMat = RkMatrix<T>::multiplyRkRk('N', 'T', m_copy->rk(), m->rk(), m->lowRankEpsilon());
      delete m_copy;

      this->axpy(Constants<T>::mone, rkMat);
      delete rkMat;
    } else if(m->isFullMatrix()){
      // m is a full leaf: build (M*D) * M^T as a full matrix and subtract it.
      HMatrix<T>* copy_m = m->copy();
      HMAT_ASSERT(copy_m);
      copy_m->multiplyWithDiag(d); // right multiplication by D

      FullMatrix<T>* fullMat = HMatrix<T>::multiplyFullMatrix('N', 'T', copy_m, m);
      HMAT_ASSERT(fullMat);
      delete copy_m;

      this->axpy(Constants<T>::mone, fullMat);
      delete fullMat;
    } else {
      // m is a null matrix (either Rk or Full) so nothing to do.
    }
  } else {
    // this is a leaf, which must be a full block.
    assert(isFullMatrix());
    if (m->isRkMatrix()) {
      // this : full
      // m    : rk
      // Strategy: compute mdm^T as FullMatrix and then do this<-this - mdm^T

      // 1) copy  m = AB^T : m_copy
      // 2) m_copy <- m_copy * D    (multiplyWithDiag)
      // 3) rkMat <- multiplyRkRk ( m_copy , m^T)
      // 4) fullMat <- evaluation as a FullMatrix of the product rkMat = (A*(D*B)^T) * (A*B^T)^T
      // 5) this <- this - fullMat
      if (!m->isNull()) {
        HMatrix<T>* m_copy = m->copy();
        m_copy->multiplyWithDiag(d);

        RkMatrix<T>* rkMat = RkMatrix<T>::multiplyRkRk('N', 'T', m_copy->rk(), m->rk(), m->lowRankEpsilon());
        FullMatrix<T>* fullMat = rkMat->eval();
        delete m_copy;
        delete rkMat;
        full()->axpy(Constants<T>::mone, fullMat);
        delete fullMat;
      }
    } else if (m->isFullMatrix()) {
      // S <- S - M*D*M^T
      assert(!full()->isTriUpper());
      assert(!full()->isTriLower());
      assert(!m->full()->isTriUpper());
      assert(!m->full()->isTriLower());
      // mTmp = M * D, computed on a local copy of M's full data.
      FullMatrix<T> mTmp(m->rows(), m->cols());
      mTmp.copyMatrixAtOffset(m->full(), 0, 0);
      if (d->isFullMatrix()) {
        // d is a full leaf: use its stored diagonal vector directly.
        mTmp.multiplyWithDiagOrDiagInv(d->full()->diagonal, false, Side::RIGHT);
      } else {
        // d is hierarchical: gather its diagonal first.
        Vector<T> diag(d->cols()->size());
        d->extractDiagonal(diag.ptr());
        mTmp.multiplyWithDiagOrDiagInv(&diag, false, Side::RIGHT);
      }
      full()->gemm('N', 'T', Constants<T>::mone, &mTmp, m->full(), Constants<T>::pone);
    } else if (!m->isLeaf()){
      // m is hierarchical while this is a full leaf: evaluate m densely.
      // mTmp becomes M * D, mTmpCopy keeps M for the right-hand factor M^T.
      FullMatrix<T> mTmp(m->rows(), m->cols());
      m->evalPart(&mTmp, m->rows(), m->cols());
      FullMatrix<T> mTmpCopy(m->rows(), m->cols());
      mTmpCopy.copyMatrixAtOffset(&mTmp, 0, 0);
      if (d->isFullMatrix()) {
        mTmp.multiplyWithDiagOrDiagInv(d->full()->diagonal, false, Side::RIGHT);
      } else {
        Vector<T> diag(d->cols()->size());
        d->extractDiagonal(diag.ptr());
        mTmp.multiplyWithDiagOrDiagInv(&diag, false, Side::RIGHT);
      }
      full()->gemm('N', 'T', Constants<T>::mone, &mTmp, &mTmpCopy, Constants<T>::pone);
    }
    // Remaining case (m is a leaf that is neither Rk nor Full, i.e. no data): nothing to do.
  }
}
2345 
assertLdlt(const HMatrix<T> * me)2346 template<typename T> void assertLdlt(const HMatrix<T> * me) {
2347     // Void block (row & col)
2348     if (me->rows()->size() == 0 && me->cols()->size() == 0) return;
2349 #ifdef DEBUG_LDLT
2350     assert(me->isTriLower);
2351     if (me->isLeaf()) {
2352         assert(me->isFullMatrix());
2353         assert(me->full()->diagonal);
2354     } else {
2355       for (int i=0 ; i<me->nrChildRow() ; i++)
2356         assertLdlt(me->get(i,i));
2357     }
2358 #else
2359     ignore_unused_arg(me);
2360 #endif
2361 }
2362 
assertLower(const HMatrix<T> * me)2363 template<typename T> void assertLower(const HMatrix<T> * me) {
2364 #ifdef DEBUG_LDLT
2365     if (me->isLeaf()) {
2366         return;
2367     } else {
2368         assert(me->isLower);
2369         for (int i=0 ; i<me->nrChildRow() ; i++)
2370           for (int j=0 ; j<me->nrChildCol() ; j++) {
2371             if (i<j) /* NULL above diag */
2372               assert(!me->get(i,j));
2373             if (i==j) /* Lower on diag */
2374               assertLower(me->get(i,i));
2375           }
2376     }
2377 #else
2378     ignore_unused_arg(me);
2379 #endif
2380 }
2381 
assertUpper(const HMatrix<T> * me)2382 template<typename T> void assertUpper(const HMatrix<T> * me) {
2383 #ifdef DEBUG_LDLT
2384     if (me->isLeaf()) {
2385         return;
2386     } else {
2387         assert(me->isUpper);
2388         for (int i=0 ; i<me->nrChildRow() ; i++)
2389           for (int j=0 ; j<me->nrChildCol() ; j++) {
2390             if (i==j) /* Upper on diag */
2391               assertUpper(me->get(i,i));
2392           }
2393     }
2394 #else
2395     ignore_unused_arg(me);
2396 #endif
2397 }
2398 
2399 template<typename T>
ldltDecomposition(hmat_progress_t * progress)2400 void HMatrix<T>::ldltDecomposition(hmat_progress_t * progress) {
2401   DECLARE_CONTEXT;
2402   assertLower(this);
2403 
2404   if (isVoid()) {
2405     // nothing to do
2406   } else if (this->isLeaf()) {
2407     //The basic case of the recursion is necessarily a full matrix leaf
2408     //since the recursion is done with *rows() == *cols().
2409 
2410     assert(isFullMatrix());
2411     full()->ldltDecomposition();
2412     if(progress != NULL) {
2413         progress->current= rows()->offset() + rows()->size();
2414         progress->update(progress);
2415     }
2416     assert(full()->diagonal);
2417   } else {
2418     this->recursiveLdltDecomposition(progress);
2419   }
2420   isTriLower = true;
2421   isLower = false;
2422 }
2423 
2424 template<typename T>
solve(ScalarArray<T> * b) const2425 void HMatrix<T>::solve(ScalarArray<T>* b) const {
2426   DECLARE_CONTEXT;
2427   // Solve (LU) X = b
2428   // First compute L Y = b
2429   this->solveLowerTriangularLeft(b, Factorization::LU, Diag::UNIT, Uplo::LOWER);
2430   // Then compute U X = Y
2431   this->solveUpperTriangularLeft(b, Factorization::LU, Diag::NONUNIT, Uplo::UPPER);
2432 }
2433 
2434 template<typename T>
solve(FullMatrix<T> * b) const2435 void HMatrix<T>::solve(FullMatrix<T>* b) const {
2436   solve(&b->data);
2437 }
2438 
2439 template<typename T>
trsm(char side,char uplo,char trans,char diag,T alpha,HMatrix<T> * B) const2440 void HMatrix<T>::trsm( char side, char uplo, char trans, char diag,
2441 		       T alpha, HMatrix<T>* B ) const {
2442 
2443     bool upper   = (uplo == 'u') || (uplo == 'U');
2444     bool left    = (side == 'l') || (side == 'L');
2445     Diag unit    = (diag == 'u' || diag == 'U') ? Diag::UNIT : Diag::NONUNIT;
2446     bool notrans = (trans == 'n') || (trans == 'N');
2447 
2448     /* Upper case */
2449     if ( upper  ) {
2450 	if ( left ) {
2451 	    if ( notrans ) {
2452 		/* LUN */
2453 		solveUpperTriangularLeft( B, Factorization::LU, unit, Uplo::UPPER );
2454 	    }
2455 	    else {
2456 		/* LUT */
2457 		HMAT_ASSERT_MSG( 0, "ERROR: TRSM LUT case is for now missing !!!" );
2458 	    }
2459 	}
2460 	else {
2461 	    if ( notrans ) {
2462 		/* RUN */
2463 		solveUpperTriangularRight( B, Factorization::LU, unit, Uplo::UPPER );
2464 	    }
2465 	    else {
2466 		/* RUT */
2467 		HMAT_ASSERT_MSG( 0, "ERROR: TRSM RUT case is for now missing !!!" );
2468 	    }
2469 	}
2470     }
2471     else {
2472 	if ( left ) {
2473 	    if ( notrans ) {
2474 		/* LLN */
2475 		solveLowerTriangularLeft( B, Factorization::LU, unit, Uplo::LOWER );
2476 	    }
2477 	    else {
2478 		/* LLT */
2479 		solveUpperTriangularLeft( B, Factorization::LU, unit, Uplo::LOWER );
2480 	    }
2481 	}
2482 	else {
2483 	    if ( notrans ) {
2484 		/* RLN */
2485 		HMAT_ASSERT_MSG( 0, "ERROR: TRSM RLN case is for now missing !!!" );
2486 	    }
2487 	    else {
2488 		/* RLT */
2489 		solveUpperTriangularRight( B, Factorization::LU, unit, Uplo::LOWER );
2490 	    }
2491 	}
2492     }
2493 }
2494 
2495 template<typename T>
trsm(char side,char uplo,char trans,char diag,T alpha,ScalarArray<T> * B) const2496 void HMatrix<T>::trsm( char side, char uplo, char trans, char diag,
2497 		       T alpha, ScalarArray<T>* B ) const {
2498 
2499     bool upper   = (uplo == 'u') || (uplo == 'U');
2500     bool left    = (side == 'l') || (side == 'L');
2501     Diag unit    = (diag == 'u' || diag == 'U') ? Diag::UNIT : Diag::NONUNIT;
2502     bool notrans = (trans == 'n') || (trans == 'N');
2503 
2504     /* Upper case */
2505     if ( upper  ) {
2506 	if ( left ) {
2507 	    if ( notrans ) {
2508 		/* LUN */
2509 		solveUpperTriangularLeft( B, Factorization::LU, unit, Uplo::UPPER );
2510 	    }
2511 	    else {
2512 		/* LUT */
2513 		HMAT_ASSERT_MSG( 0, "ERROR: TRSM LUT case is for now missing !!!" );
2514 	    }
2515 	}
2516 	else {
2517 	    if ( notrans ) {
2518 		/* RUN */
2519 		solveUpperTriangularRight( B, Factorization::LU, unit, Uplo::UPPER );
2520 	    }
2521 	    else {
2522 		/* RUT */
2523 		HMAT_ASSERT_MSG( 0, "ERROR: TRSM RUT case is for now missing !!!" );
2524 	    }
2525 	}
2526     }
2527     else {
2528 	if ( left ) {
2529 	    if ( notrans ) {
2530 		/* LLN */
2531 		solveLowerTriangularLeft( B, Factorization::LU, unit, Uplo::LOWER );
2532 	    }
2533 	    else {
2534 		/* LLT */
2535 		solveUpperTriangularLeft( B, Factorization::LU, unit, Uplo::LOWER );
2536 	    }
2537 	}
2538 	else {
2539 	    if ( notrans ) {
2540 		/* RLN */
2541 		HMAT_ASSERT_MSG( 0, "ERROR: TRSM RLN case is for now missing !!!" );
2542 	    }
2543 	    else {
2544 		/* RLT */
2545 		solveUpperTriangularRight( B, Factorization::LU, unit, Uplo::LOWER );
2546 	    }
2547 	}
2548     }
2549 }
2550 
2551 template<typename T>
extractDiagonal(T * diag) const2552 void HMatrix<T>::extractDiagonal(T* diag) const {
2553   DECLARE_CONTEXT;
2554   if (rows()->size() == 0 || cols()->size() == 0) return;
2555   if(this->isLeaf()) {
2556     assert(isFullMatrix());
2557     if(full()->diagonal) {
2558       // LDLt
2559       memcpy(diag, full()->diagonal->const_ptr(), full()->rows() * sizeof(T));
2560     } else {
2561       // LLt
2562       for (int i = 0; i < full()->rows(); ++i)
2563         diag[i] = full()->get(i,i);
2564     }
2565   } else {
2566     for (int i=0 ; i<nrChildRow() ; i++) {
2567       get(i,i)->extractDiagonal(diag);
2568       diag += get(i,i)->rows()->size();
2569     }
2570   }
2571 }
2572 
2573 /* Solve M.X=B with M hmat LU factorized*/
solve(HMatrix<T> * b,Factorization algo) const2574 template<typename T> void HMatrix<T>::solve(
2575         HMatrix<T>* b,
2576         Factorization algo) const {
2577     DECLARE_CONTEXT;
2578     switch(algo) {
2579     case Factorization::LU:
2580         /* Solve LX=B, result in B */
2581         this->solveLowerTriangularLeft(b, algo, Diag::UNIT, Uplo::LOWER);
2582         /* Solve UX=B, result in B */
2583         this->solveUpperTriangularLeft(b, algo, Diag::NONUNIT, Uplo::UPPER);
2584         break;
2585     case Factorization::LDLT:
2586         /* Solve LX=B, result in B */
2587         this->solveLowerTriangularLeft(b, algo, Diag::UNIT, Uplo::LOWER);
2588         /* Solve DX=B, result in B */
2589         b->multiplyWithDiag(this, Side::LEFT, true);
2590         /* Solve L^tX=B, result in B */
2591         this->solveUpperTriangularLeft(b, algo, Diag::UNIT, Uplo::LOWER);
2592         break;
2593     case Factorization::LLT:
2594         /* Solve LX=B, result in B */
2595         this->solveLowerTriangularLeft(b, algo, Diag::NONUNIT, Uplo::LOWER);
2596         /* Solve L^tX=B, result in B */
2597         this->solveUpperTriangularLeft(b, algo, Diag::NONUNIT, Uplo::UPPER);
2598         break;
2599     default:
2600         HMAT_ASSERT(false);
2601     }
2602 }
2603 
solveDiagonal(ScalarArray<T> * b) const2604 template<typename T> void HMatrix<T>::solveDiagonal(ScalarArray<T>* b) const {
2605     // Solve D*X = B and store result into B
2606     // Diagonal extraction
2607     if (rows()->size() == 0 || cols()->size() == 0) return;
2608     if(isFullMatrix() && full()->diagonal) {
2609       // LDLt
2610       b->multiplyWithDiagOrDiagInv(full()->diagonal, true, Side::LEFT); // multiply to the left by the inverse
2611     } else {
2612       // LLt
2613       Vector<T>* diag = new Vector<T>(cols()->size());
2614       extractDiagonal(diag->ptr());
2615       b->multiplyWithDiagOrDiagInv(diag, true, Side::LEFT); // multiply to the left by the inverse
2616       delete diag;
2617     }
2618 }
2619 
solveDiagonal(FullMatrix<T> * b) const2620 template<typename T> void HMatrix<T>::solveDiagonal(FullMatrix<T>* b) const {
2621   solveDiagonal(&b->data);
2622 }
2623 
2624 template<typename T>
solveLdlt(ScalarArray<T> * b) const2625 void HMatrix<T>::solveLdlt(ScalarArray<T>* b) const {
2626   DECLARE_CONTEXT;
2627   assertLdlt(this);
2628   // L*D*L^T * X = B
2629   // B <- solution of L * Y = B : Y = D*L^T * X
2630   this->solveLowerTriangularLeft(b, Factorization::LDLT, Diag::UNIT, Uplo::LOWER);
2631 
2632   // B <- D^{-1} Y : solution of D*Y = B : Y = L^T * X
2633   this->solveDiagonal(b);
2634 
2635   // B <- solution of L^T X = B :  the solution X we are looking for is stored in B
2636   this->solveUpperTriangularLeft(b, Factorization::LDLT, Diag::UNIT, Uplo::LOWER);
2637 }
2638 
2639 template<typename T>
solveLdlt(FullMatrix<T> * b) const2640 void HMatrix<T>::solveLdlt(FullMatrix<T>* b) const {
2641   solveLdlt(&b->data);
2642 }
2643 
2644 template<typename T>
solveLlt(ScalarArray<T> * b) const2645 void HMatrix<T>::solveLlt(ScalarArray<T>* b) const {
2646   DECLARE_CONTEXT;
2647   // L*L^T * X = B
2648   // B <- solution of L * Y = B : Y = L^T * X
2649   this->solveLowerTriangularLeft(b, Factorization::LLT, Diag::NONUNIT, Uplo::LOWER);
2650 
2651   // B <- solution of L^T X = B :  the solution X we are looking for is stored in B
2652   this->solveUpperTriangularLeft(b, Factorization::LLT, Diag::NONUNIT, Uplo::LOWER);
2653 }
2654 
2655 template<typename T>
solveLlt(FullMatrix<T> * b) const2656 void HMatrix<T>::solveLlt(FullMatrix<T>* b) const {
2657   solveLlt(&b->data);
2658 }
2659 
2660 template<typename T>
checkStructure() const2661 void HMatrix<T>::checkStructure() const {
2662 #if 0
2663   if (this->isLeaf()) {
2664     return;
2665   }
2666   for (int i = 0; i < this->nrChild(); i++) {
2667       HMatrix<T>* child = this->getChild(i);
2668       if (child) {
2669         assert(child->rows()->isSubset(*(this->rows())) && child->cols()->isSubset(*(this->cols())));
2670         child->checkStructure();
2671     }
2672   }
2673 #endif
2674 }
2675 
2676 template<typename T>
checkNan() const2677 void HMatrix<T>::checkNan() const {
2678 #if 0
2679   if (this->isLeaf()) {
2680     if (isFullMatrix()) {
2681       full()->checkNan();
2682     }
2683     if (isRkMatrix()) {
2684       rk()->checkNan();
2685     }
2686   } else {
2687     for (int i = 0; i < this->nrChild(); i++) {
2688         if (this->getChild(i)) {
2689           this->getChild(i)->checkNan();
2690       }
2691     }
2692   }
2693 #endif
2694 }
2695 
setTriLower(bool value)2696 template<typename T> void HMatrix<T>::setTriLower(bool value)
2697 {
2698     isTriLower = value;
2699     if(!this->isLeaf())
2700     {
2701       for (int i = 0; i < nrChildRow(); i++)
2702         get(i, i)->setTriLower(value);
2703     }
2704 }
2705 
setLower(bool value)2706 template<typename T> void HMatrix<T>::setLower(bool value)
2707 {
2708     isLower = value;
2709     if(!this->isLeaf())
2710     {
2711       for (int i = 0; i < nrChildRow(); i++)
2712         get(i, i)->setLower(value);
2713     }
2714 }
2715 
rk(const ScalarArray<T> * a,const ScalarArray<T> * b)2716 template<typename T>  void HMatrix<T>::rk(const ScalarArray<T> * a, const ScalarArray<T> * b) {
2717     if(!isAssembled())
2718         rk(NULL);
2719     assert(isRkMatrix());
2720     if(a == NULL && isNull())
2721         return;
2722     delete rk_;
2723     rk(new RkMatrix<T>(a == NULL ? NULL : a->copy(), rows(),
2724                        b == NULL ? NULL : b->copy(), cols()));
2725 }
2726 
listAllLeaves(std::deque<const HMatrix<T> * > & out) const2727 template<typename T> void HMatrix<T>::listAllLeaves(std::deque<const HMatrix<T> *> & out) const {
2728   std::vector<const HMatrix<T> *> stack;
2729   stack.push_back(this);
2730   while(!stack.empty()) {
2731     const HMatrix<T> * m = stack.back();
2732     stack.pop_back();
2733     if(m->isLeaf()) {
2734       out.push_back(m);
2735     } else {
2736       for(int i = 0; i < m->nrChild(); i++) {
2737         if(m->getChild(i) != nullptr)
2738           stack.push_back(m->getChild(i));
2739       }
2740     }
2741   }
2742 }
2743 
2744 // No way to avoid copy/past of the const version
listAllLeaves(std::deque<HMatrix<T> * > & out)2745 template<typename T> void HMatrix<T>::listAllLeaves(std::deque<HMatrix<T> *> & out) {
2746   std::vector<HMatrix<T> *> stack;
2747   stack.push_back(this);
2748   while(!stack.empty()) {
2749     HMatrix<T> * m = stack.back();
2750     stack.pop_back();
2751     if(m->isLeaf()) {
2752       out.push_back(m);
2753     } else {
2754       for(int i = 0; i < m->nrChild(); i++) {
2755         if(m->getChild(i) != nullptr)
2756           stack.push_back(m->getChild(i));
2757       }
2758     }
2759   }
2760 }
2761 
toString() const2762 template<typename T> std::string HMatrix<T>::toString() const {
2763     std::deque<const HMatrix<T> *> leaves;
2764     this->listAllLeaves(leaves);
2765     int nbAssembled = 0;
2766     int nbNullFull = 0;
2767     int nbNullRk = 0;
2768     double diagNorm = 0;
2769     for(unsigned int i = 0; i < leaves.size(); i++) {
2770         const HMatrix<T> * l = leaves[i];
2771         if(l->isAssembled()) {
2772             nbAssembled++;
2773             if(l->isNull()) {
2774                 if(l->isRkMatrix())
2775                     nbNullRk++;
2776                 else
2777                     nbNullFull++;
2778             }
2779             else if(l->isFullMatrix() && l->full()->diagonal) {
2780                 diagNorm += l->full()->diagonal->normSqr();
2781             }
2782         }
2783     }
2784     diagNorm = sqrt(diagNorm);
2785     std::stringstream sstm;
2786     sstm << "HMatrix(rows=[" << rows()->offset() << ", " << rows()->size() <<
2787             "], cols=[" << cols()->offset() << ", " << cols()->size() <<
2788             "], pointer=" << (void*)this << ", leaves=" << leaves.size() <<
2789             ", assembled=" << isAssembled() << ", assembledLeaves=" << nbAssembled <<
2790             ", nullFull=" << nbNullFull << ", nullRk=" << nbNullRk <<
2791             ", rank=" << rank_ << ", diagNorm=" << diagNorm << ")";
2792     return sstm.str();
2793 }
2794 
2795 template<typename T>
unmarshall(const MatrixSettings * settings,int rank,int approxRank,char bitfield,double epsilon)2796 HMatrix<T> * HMatrix<T>::unmarshall(const MatrixSettings * settings, int rank, int approxRank, char bitfield, double epsilon) {
2797     HMatrix<T> * m = new HMatrix<T>(settings);
2798     m->rank_ = rank;
2799     m->isUpper = (bitfield & 1 << 0 ? true : false);
2800     m->isLower = (bitfield & 1 << 1 ? true : false);
2801     m->isTriUpper = (bitfield & 1 << 2 ? true : false);
2802     m->isTriLower = (bitfield & 1 << 3 ? true : false);
2803     m->keepSameRows = (bitfield & 1 << 4 ? true : false);
2804     m->keepSameCols = (bitfield & 1 << 5 ? true : false);
2805     m->approximateRank_ = approxRank;
2806     m->lowRankEpsilon(epsilon, false);
2807     return m;
2808 }
2809 
2810 /** Create a temporary block from a list of children */
2811 template<typename T>
HMatrix(const ClusterTree * rows,const ClusterTree * cols,std::vector<HMatrix * > & _children)2812 HMatrix<T>::HMatrix(const ClusterTree * rows, const ClusterTree * cols,
2813                     std::vector<HMatrix*> & _children):
2814     Tree<HMatrix<T> >(NULL, 0), rows_(rows), cols_(cols),
2815     rk_(NULL), rank_(UNINITIALIZED_BLOCK),
2816     approximateRank_(UNINITIALIZED_BLOCK), isUpper(false), isLower(false),
2817     keepSameRows(false), keepSameCols(false), temporary_(true), ownRowsClusterTree_(false),
2818     ownColsClusterTree_(false), localSettings(_children[0]->localSettings.global, -1.0) {
2819     this->children = _children;
2820 }
2821 
rank(int rank)2822 template<typename T> void HMatrix<T>::rank(int rank) {
2823     HMAT_ASSERT_MSG(rank_ >= 0, "HMatrix::rank can only be used on Rk blocks");
2824     HMAT_ASSERT_MSG(!rk() || rk()->a == NULL || rk()->rank() == rank,
2825         "HMatrix::rank can only be used on evicted blocks");
2826     rank_ = rank;
2827 }
2828 
2829 
temporary(bool b)2830 template<typename T> void HMatrix<T>::temporary(bool b) {
2831   temporary_ = b;
2832   for (int i=0; i<this->nrChild(); i++) {
2833     if (this->getChild(i))
2834       this->getChild(i)->temporary(b);
2835   }
2836 }
2837 
// Explicit template instantiations for the four scalar types S_t, D_t, C_t and
// Z_t (presumably float, double, complex float and complex double; see
// data_types.hpp), so the definitions above are emitted in this translation unit.
template class HMatrix<S_t>;
template class HMatrix<D_t>;
template class HMatrix<C_t>;
template class HMatrix<Z_t>;

template void reorderVector(ScalarArray<S_t>* v, int* indices, int axis);
template void reorderVector(ScalarArray<D_t>* v, int* indices, int axis);
template void reorderVector(ScalarArray<C_t>* v, int* indices, int axis);
template void reorderVector(ScalarArray<Z_t>* v, int* indices, int axis);

template void restoreVectorOrder(ScalarArray<S_t>* v, int* indices, int axis);
template void restoreVectorOrder(ScalarArray<D_t>* v, int* indices, int axis);
template void restoreVectorOrder(ScalarArray<C_t>* v, int* indices, int axis);
template void restoreVectorOrder(ScalarArray<Z_t>* v, int* indices, int axis);

template unsigned char * compatibilityGridForGEMM(const HMatrix<S_t>* a, Axis axisA, char transA, const HMatrix<S_t>* b, Axis axisB, char transB);
template unsigned char * compatibilityGridForGEMM(const HMatrix<D_t>* a, Axis axisA, char transA, const HMatrix<D_t>* b, Axis axisB, char transB);
template unsigned char * compatibilityGridForGEMM(const HMatrix<C_t>* a, Axis axisA, char transA, const HMatrix<C_t>* b, Axis axisB, char transB);
template unsigned char * compatibilityGridForGEMM(const HMatrix<Z_t>* a, Axis axisA, char transA, const HMatrix<Z_t>* b, Axis axisB, char transB);

}  // end namespace hmat

// The RecursionMatrix template definitions are textually included here so
// that they can be explicitly instantiated for HMatrix just below.
#include "recursion.cpp"

namespace hmat {

  // Explicit template instantiation
  template class RecursionMatrix<S_t, HMatrix<S_t> >;
  template class RecursionMatrix<C_t, HMatrix<C_t> >;
  template class RecursionMatrix<D_t, HMatrix<D_t> >;
  template class RecursionMatrix<Z_t, HMatrix<Z_t> >;

}  // end namespace hmat
2872