1 /*
2 HMat-OSS (HMatrix library, open source software)
3
4 Copyright (C) 2014-2015 Airbus Group SAS
5
6 This program is free software; you can redistribute it and/or
7 modify it under the terms of the GNU General Public License
8 as published by the Free Software Foundation; either version 2
9 of the License, or (at your option) any later version.
10
11 This program is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with this program; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
19
20 http://github.com/jeromerobert/hmat-oss
21 */
22
23 #include "config.h"
24
25 /*! \file
26 \ingroup HMatrix
27 \brief HMatrix type.
28 */
29 #include <algorithm>
30 #include <list>
31 #include <vector>
32 #include <cstring>
33
34 #include "h_matrix.hpp"
35 #include "cluster_tree.hpp"
36 #include "admissibility.hpp"
37 #include "data_types.hpp"
38 #include "compression.hpp"
39 #include "recursion.hpp"
40 #include "common/context.hpp"
41 #include "common/my_assert.h"
42 #include "json.hpp"
43
44 using namespace std;
45
46 namespace hmat {
47
// Definitions of the static option members of HMatrix<T> (one set per scalar type T).
// The default values below will be overwritten in default_engine.cpp by HMatSettings values
// 'coarsening' is read by assemble() and assembleSymmetric() to decide whether coarsen() is called.
template<typename T> bool HMatrix<T>::coarsening = false;
template<typename T> bool HMatrix<T>::recompress = false;
template<typename T> bool HMatrix<T>::validateNullRowCol = false;
template<typename T> bool HMatrix<T>::validateCompression = false;
template<typename T> bool HMatrix<T>::validationReRun = false;
template<typename T> bool HMatrix<T>::validationDump = false;
template<typename T> double HMatrix<T>::validationErrorThreshold = 0;
56
~HMatrix()57 template<typename T> HMatrix<T>::~HMatrix() {
58 if (isRkMatrix() && rk_) {
59 delete rk_;
60 rk_ = NULL;
61 }
62 if (full_) {
63 delete full_;
64 full_ = NULL;
65 }
66 if(ownRowsClusterTree_)
67 delete rows_;
68 if(ownColsClusterTree_)
69 delete cols_;
70 }
71
72
73 template<typename T>
reorderVector(ScalarArray<T> * v,int * indices,int axis)74 void reorderVector(ScalarArray<T>* v, int* indices, int axis) {
75 DECLARE_CONTEXT;
76 if (!indices) return;
77 const int n = axis == 0 ? v->rows : v->cols;
78 // If permutation is identity, do nothing
79 bool identity = true;
80 for (int i = 0; i < n; i++) {
81 if (indices[i] != i) {
82 identity = false;
83 break;
84 }
85 }
86 if (identity) return;
87
88 if (axis == 0) {
89 Vector<T> tmp(n);
90 for (int col = 0; col < v->cols; col++) {
91 Vector<T> column(*v, col);
92 for (int i = 0; i < n; i++)
93 tmp[i] = column[indices[i]];
94 tmp.copy(&column);
95 }
96 } else {
97 ScalarArray<T> tmp(1, n);
98 for (int row = 0; row < v->rows; row++) {
99 ScalarArray<T> column(*v, row, 1, 0, n);
100 for (int i = 0; i < n; i++)
101 tmp.get(0, i) = column.get(0, indices[i]);
102 tmp.copy(&column);
103 }
104 }
105 }
106
107 template<typename T>
restoreVectorOrder(ScalarArray<T> * v,int * indices,int axis)108 void restoreVectorOrder(ScalarArray<T>* v, int* indices, int axis) {
109 DECLARE_CONTEXT;
110 if (!indices) return;
111 const int n = axis == 0 ? v->rows : v->cols;
112 // If permutation is identity, do nothing
113 bool identity = true;
114 for (int i = 0; i < n; i++) {
115 if (indices[i] != i) {
116 identity = false;
117 break;
118 }
119 }
120 if (identity) return;
121
122 if (axis == 0) {
123 Vector<T> tmp(n);
124 for (int col = 0; col < v->cols; col++) {
125 Vector<T> column(*v, col);
126 for (int i = 0; i < n; i++) {
127 tmp[indices[i]] = column[i];
128 }
129 tmp.copy(&column);
130 }
131 } else {
132 ScalarArray<T> tmp(1, n);
133 for (int row = 0; row < v->rows; row++) {
134 ScalarArray<T> column(*v, row, 1, 0, n);
135 for (int i = 0; i < n; i++)
136 tmp.get(0, indices[i]) = column.get(0, i);
137 tmp.copy(&column);
138 }
139 }
140 }
141
142 template<typename T>
HMatrix(const ClusterTree * _rows,const ClusterTree * _cols,const hmat::MatrixSettings * settings,int _depth,SymmetryFlag symFlag,AdmissibilityCondition * admissibilityCondition)143 HMatrix<T>::HMatrix(const ClusterTree* _rows, const ClusterTree* _cols, const hmat::MatrixSettings * settings,
144 int _depth, SymmetryFlag symFlag, AdmissibilityCondition * admissibilityCondition)
145 : Tree<HMatrix<T> >(NULL, _depth), RecursionMatrix<T, HMatrix<T> >(),
146 rows_(_rows), cols_(_cols), rk_(NULL),
147 rank_(UNINITIALIZED_BLOCK), approximateRank_(UNINITIALIZED_BLOCK),
148 isUpper(false), isLower(false),
149 isTriUpper(false), isTriLower(false), keepSameRows(true), keepSameCols(true), temporary_(false),
150 ownRowsClusterTree_(false), ownColsClusterTree_(false), localSettings(settings, 1e-4)
151 {
152 if (isVoid())
153 return;
154 const bool lowRank = admissibilityCondition->isLowRank(*rows_, *cols_);
155 if (!split(admissibilityCondition, lowRank, symFlag)) {
156 // If we cannot split, we are on a leaf
157 const bool forceFull = admissibilityCondition->forceFull(*rows_, *cols_);
158 const bool forceRk = admissibilityCondition->forceRk(*rows_, *cols_);
159 assert(!(forceFull && forceRk));
160 if (forceRk || (lowRank && !forceFull))
161 rk(NULL);
162 else
163 full(NULL);
164 approximateRank_ = admissibilityCondition->getApproximateRank(*(rows_), *(cols_));
165 }
166 assert(!this->isLeaf() || isAssembled());
167 }
168
169 template<typename T>
split(AdmissibilityCondition * admissibilityCondition,bool lowRank,SymmetryFlag symFlag)170 bool HMatrix<T>::split(AdmissibilityCondition * admissibilityCondition, bool lowRank,
171 SymmetryFlag symFlag) {
172 assert(rank_ == NONLEAF_BLOCK || rank_ == UNINITIALIZED_BLOCK || (this->isLeaf() && isNull()));
173 // We would like to create a block of matrix in one of the following case:
174 // - rows_->isLeaf() && cols_->isLeaf() : both rows and cols are leaves.
175 // - Block is too small to recurse and compress (for performance)
176 // - Block is compressible and in-place compression is possible
177 // In the other cases, we subdivide.
178 //
179 // FIXME: But in practice this does not work yet, so we stop recursion as soon as either rows
180 // or cols is a leaf.
181 bool stopRecursion = admissibilityCondition->stopRecursion(*rows_, *cols_);
182 bool forceRecursion = admissibilityCondition->forceRecursion(*rows_, *cols_, sizeof(T));
183 assert(!(forceRecursion && stopRecursion));
184 // check we can actually split
185 if ((rows_->isLeaf() && cols_->isLeaf()) || stopRecursion || (lowRank && !forceRecursion))
186 return false;
187 pair<bool, bool> splitRC = admissibilityCondition->splitRowsCols(*rows_, *cols_);
188 assert(splitRC.first || splitRC.second);
189 keepSameRows = !splitRC.first;
190 keepSameCols = !splitRC.second;
191 isLower = (symFlag == kLowerSymmetric ? true : false);
192 for (int i = 0; i < nrChildRow(); ++i) {
193 // Don't recurse on rows if splitRowsCols() told us not to.
194 ClusterTree* rowChild = const_cast<ClusterTree*>((keepSameRows ? rows_ : rows_->getChild(i)));
195 for (int j = 0; j < nrChildCol(); ++j) {
196 // Don't recurse on cols if splitRowsCols() told us not to.
197 ClusterTree* colChild = const_cast<ClusterTree*>((keepSameCols ? cols_ : cols_->getChild(j)));
198 if ((symFlag == kNotSymmetric) || (isUpper && (i <= j)) || (isLower && (i >= j))) {
199 if (!admissibilityCondition->isInert(*rowChild, *colChild)) {
200 // Create child only if not 'inert' (inert = will always be null)
201 this->insertChild(i, j,
202 new HMatrix<T>(rowChild, colChild, localSettings.global,
203 this->depth + 1,
204 i == j ? symFlag : kNotSymmetric,
205 admissibilityCondition));
206 } else
207 // If 'inert', the child is NULL
208 this->insertChild(i, j, NULL);
209 }
210 }
211 }
212 if(nrChildRow() > 0 && nrChildCol() > 0)
213 rank_ = NONLEAF_BLOCK;
214 return true;
215 }
216
/**
 * @brief Build an empty node that is not attached to any cluster tree.
 *
 * Rows and columns are NULL, both rank fields are UNINITIALIZED_BLOCK, every
 * symmetry/triangular flag is false, no cluster tree is owned, and
 * localSettings is built from (settings, -1.0). The internalCopy() overloads
 * in this file use it and then fill in the fields they need.
 *
 * @param settings settings object forwarded to localSettings
 */
template<typename T>
HMatrix<T>::HMatrix(const hmat::MatrixSettings * settings) :
    Tree<HMatrix<T> >(NULL), RecursionMatrix<T, HMatrix<T> >(), rows_(NULL), cols_(NULL),
    rk_(NULL), rank_(UNINITIALIZED_BLOCK), approximateRank_(UNINITIALIZED_BLOCK),
    isUpper(false), isLower(false), isTriUpper(false), isTriLower(false),
    keepSameRows(true), keepSameCols(true), temporary_(false), ownRowsClusterTree_(false),
    ownColsClusterTree_(false), localSettings(settings, -1.0)
    {}
225
internalCopy(bool temporary,bool withRowChild,bool withColChild) const226 template<typename T> HMatrix<T> * HMatrix<T>::internalCopy(bool temporary, bool withRowChild, bool withColChild) const {
227 HMatrix<T> * r = new HMatrix<T>(localSettings.global);
228 r->rows_ = rows_;
229 r->cols_ = cols_;
230 r->temporary_ = temporary;
231 r->localSettings.epsilon_ = localSettings.epsilon_;
232 if(withRowChild || withColChild) {
233 // Here, we come from HMatrixHandle<T>::createGemmTemporaryRk()
234 // we want to go 1 level below data (which is an Rk)
235 // so we don't use get(i,j) since data has no children
236 // we dont use this->nrChildRow and this->nrChildCol either, they would return 1
237 // (since 'this' is rows- and cols-admissible, unlike 'r')
238 r->keepSameRows = !withRowChild;
239 r->keepSameCols = !withColChild;
240 for(int i = 0; i < r->nrChildRow(); i++) {
241 for(int j = 0; j < r->nrChildCol(); j++) {
242 HMatrix<T>* child = new HMatrix<T>(localSettings.global);
243 child->temporary_ = temporary;
244 child->rows_ = withRowChild ? rows_->getChild(i) : rows_;
245 child->cols_ = withColChild ? cols_->getChild(j) : cols_;
246 child->localSettings.epsilon_ = localSettings.epsilon_;
247 assert(child->rows_ != NULL);
248 assert(child->cols_ != NULL);
249 assert(child->localSettings.epsilon_ > 0);
250 child->rk(NULL);
251 r->insertChild(i, j, child);
252 }
253 }
254 }
255 return r;
256 }
257
internalCopy(const ClusterTree * rows,const ClusterTree * cols) const258 template<typename T> HMatrix<T>* HMatrix<T>::internalCopy(
259 const ClusterTree * rows, const ClusterTree * cols) const {
260 HMatrix<T> * r = new HMatrix<T>(localSettings.global);
261 r->temporary_ = true;
262 r->rows_ = rows;
263 r->cols_ = cols;
264 r->localSettings.epsilon_ = localSettings.epsilon_;
265 return r;
266 }
267
268 template<typename T>
copyStructure() const269 HMatrix<T>* HMatrix<T>::copyStructure() const {
270 HMatrix<T>* h = internalCopy();
271 h->isUpper = isUpper;
272 h->isLower = isLower;
273 h->isTriUpper = isTriUpper;
274 h->isTriLower = isTriLower;
275 h->keepSameRows = keepSameRows;
276 h->keepSameCols = keepSameCols;
277 h->rank_ = rank_ >= 0 ? 0 : rank_;
278 h->approximateRank_ = approximateRank_;
279 if(!this->isLeaf()){
280 for (int i = 0; i < this->nrChild(); ++i) {
281 if (this->getChild(i)) {
282 h->insertChild(i, this->getChild(i)->copyStructure());
283 }
284 else
285 h->insertChild(i, NULL);
286 }
287 }
288 return h;
289 }
290
291 template<typename T>
Zero(const HMatrix<T> * o)292 HMatrix<T>* HMatrix<T>::Zero(const HMatrix<T>* o) {
293 // leaves are filled by 0
294 HMatrix<T> *h = o->internalCopy();
295 h->isLower = o->isLower;
296 h->isUpper = o->isUpper;
297 h->isTriUpper = o->isTriUpper;
298 h->isTriLower = o->isTriLower;
299 h->keepSameRows = o->keepSameRows;
300 h->keepSameCols = o->keepSameCols;
301 h->rank_ = o->rank_ >= 0 ? 0 : o->rank_;
302 if (h->rank_==0)
303 h->rk(new RkMatrix<T>(NULL, h->rows(), NULL, h->cols()));
304 h->approximateRank_ = o->approximateRank_;
305 if(!o->isLeaf()){
306 for (int i = 0; i < o->nrChild(); ++i) {
307 if (o->getChild(i)) {
308 h->insertChild(i, HMatrix<T>::Zero(o->getChild(i)));
309 } else
310 h->insertChild(i, NULL);
311 }
312 }
313 return h;
314 }
315
316 template<typename T>
setClusterTrees(const ClusterTree * rows,const ClusterTree * cols)317 void HMatrix<T>::setClusterTrees(const ClusterTree* rows, const ClusterTree* cols) {
318 rows_ = rows;
319 cols_ = cols;
320 if(isRkMatrix() && rk()) {
321 rk()->rows = &(rows->data);
322 rk()->cols = &(cols->data);
323 } else if(isFullMatrix()) {
324 full()->rows_ = &(rows->data);
325 full()->cols_ = &(cols->data);
326 } else if(!this->isLeaf()) {
327 for (int i = 0; i < nrChildRow(); ++i) {
328 // if rows not admissible, don't recurse on them
329 const ClusterTree* rowChild = (keepSameRows ? rows : rows->me()->getChild(i));
330 for (int j = 0; j < nrChildCol(); ++j) {
331 // if cols not admissible, don't recurse on them
332 const ClusterTree* colChild = (keepSameCols ? cols : cols->me()->getChild(j));
333 if(get(i, j))
334 get(i, j)->setClusterTrees(rowChild, colChild);
335 }
336 }
337 }
338 }
339
340 template<typename T>
assemble(Assembly<T> & f,const AllocationObserver & ao)341 void HMatrix<T>::assemble(Assembly<T>& f, const AllocationObserver & ao) {
342 if (this->isLeaf()) {
343 // If the leaf is admissible, matrix assembly and compression.
344 // if not we keep the matrix.
345 FullMatrix<T> * m = NULL;
346 RkMatrix<T>* assembledRk = NULL;
347 f.assemble(localSettings, *rows_, *cols_, isRkMatrix(), m, assembledRk, lowRankEpsilon(), ao);
348 HMAT_ASSERT(m == NULL || assembledRk == NULL);
349 if(assembledRk) {
350 assert(isRkMatrix());
351 if(rk_)
352 delete rk_;
353 rk(assembledRk);
354 } else {
355 assert(!isRkMatrix());
356 if(full_)
357 delete full_;
358 full(m);
359 }
360 } else {
361 full_ = NULL;
362 rk_ = NULL;
363 for (int i = 0; i < this->nrChild(); i++) {
364 if (this->getChild(i))
365 this->getChild(i)->assemble(f, ao);
366 }
367 assembledRecurse();
368 if (coarsening)
369 coarsen(RkMatrix<T>::approx.coarseningEpsilon);
370 }
371 }
372
373 template<typename T>
assembleSymmetric(Assembly<T> & f,HMatrix<T> * upper,bool onlyLower,const AllocationObserver & ao)374 void HMatrix<T>::assembleSymmetric(Assembly<T>& f,
375 HMatrix<T>* upper, bool onlyLower, const AllocationObserver & ao) {
376 if (!onlyLower) {
377 if (!upper){
378 upper = this;
379 }
380 assert(*this->rows() == *upper->cols());
381 assert(*this->cols() == *upper->rows());
382 }
383
384 if (this->isLeaf()) {
385 // If the leaf is admissible, matrix assembly and compression.
386 // if not we keep the matrix.
387 this->assemble(f, ao);
388 if (isRkMatrix()) {
389 if ((!onlyLower) && (upper != this)) {
390 // Admissible leaf: a matrix represented by AB^t is transposed by exchanging A and B.
391 RkMatrix<T>* newRk = rk()->copy();
392 newRk->transpose();
393 if(upper->isRkMatrix() && upper->rk() != NULL)
394 delete upper->rk();
395 upper->rk(newRk);
396 }
397 } else {
398 if ((!onlyLower) && ( upper != this)) {
399 if(isFullMatrix())
400 upper->full(full()->copyAndTranspose());
401 else
402 upper->full(NULL);
403 }
404 }
405 } else {
406 if (onlyLower) {
407 for (int i = 0; i < nrChildRow(); i++) {
408 for (int j = 0; j < nrChildCol(); j++) {
409 if ((*rows() == *cols()) && (j > i)) {
410 continue;
411 }
412 if (get(i,j))
413 get(i,j)->assembleSymmetric(f, NULL, true, ao);
414 }
415 }
416 } else {
417 if (this == upper) {
418 for (int i = 0; i < nrChildRow(); i++) {
419 for (int j = 0; j <= i; j++) {
420 HMatrix<T> *child = get(i, j);
421 HMatrix<T> *upperChild = get(j, i);
422 assert((child != NULL) == (upperChild != NULL));
423 if (child)
424 child->assembleSymmetric(f, upperChild, false, ao);
425 }
426 }
427 } else {
428 for (int i = 0; i < nrChildRow(); i++) {
429 for (int j = 0; j < nrChildCol(); j++) {
430 HMatrix<T> *child = get(i, j);
431 HMatrix<T> *upperChild = upper->get(j, i);
432 assert((child != NULL) == (upperChild != NULL));
433 if (child)
434 child->assembleSymmetric(f, upperChild, false, ao);
435 }
436 }
437 upper->assembledRecurse();
438 if (coarsening)
439 coarsen(RkMatrix<T>::approx.coarseningEpsilon, upper);
440 }
441 }
442 assembledRecurse();
443 }
444 }
445
info(hmat_info_t & result)446 template<typename T> void HMatrix<T>::info(hmat_info_t & result) {
447 result.nr_block_clusters++;
448 int r = rows()->size();
449 int c = cols()->size();
450 if(r == 0 || c == 0) {
451 return;
452 } else if(this->isLeaf()) {
453 size_t s = ((size_t)r) * c;
454 result.uncompressed_size += s;
455 if(isRkMatrix()) {
456 size_t mem = rank() * (((size_t)r) + c);
457 result.compressed_size += mem;
458 int dim = result.largest_rk_dim_cols + result.largest_rk_dim_rows;
459 if(rows()->size() + cols()->size() > dim) {
460 result.largest_rk_dim_cols = c;
461 result.largest_rk_dim_rows = r;
462 }
463
464 size_t old_s = ((size_t)result.largest_rk_mem_cols + result.largest_rk_mem_rows) * result.largest_rk_mem_rank;
465 if(mem > old_s) {
466 result.largest_rk_mem_cols = c;
467 result.largest_rk_mem_rows = r;
468 result.largest_rk_mem_rank = rank();
469 }
470 result.rk_count++;
471 result.rk_size += s;
472 } else {
473 result.compressed_size += s;
474 result.full_count ++;
475 result.full_size += s;
476 }
477 } else {
478 for (int i = 0; i < this->nrChild(); i++) {
479 HMatrix<T> *child = this->getChild(i);
480 if (child)
481 child->info(result);
482 }
483 }
484 }
485
486 template<typename T>
eval(FullMatrix<T> * result,bool renumber) const487 void HMatrix<T>::eval(FullMatrix<T>* result, bool renumber) const {
488 if (this->isLeaf()) {
489 if (this->isNull()) return;
490 FullMatrix<T> *mat = isRkMatrix() ? rk()->eval() : full();
491 int *rowIndices = rows()->indices() + rows()->offset();
492 int rowCount = rows()->size();
493 int *colIndices = cols()->indices() + cols()->offset();
494 int colCount = cols()->size();
495 if(renumber) {
496 for (int j = 0; j < colCount; j++)
497 for (int i = 0; i < rowCount; i++)
498 result->get(rowIndices[i], colIndices[j]) = mat->get(i, j);
499 } else {
500 for (int j = 0; j < colCount; j++)
501 memcpy(&result->get(rows()->offset(), cols()->offset() + j), &mat->get(0, j), rowCount * sizeof(T));
502 }
503 if (isRkMatrix()) {
504 delete mat;
505 }
506 } else {
507 for (int i = 0; i < this->nrChild(); i++) {
508 if (this->getChild(i)) {
509 this->getChild(i)->eval(result, renumber);
510 }
511 }
512 }
513 }
514
515 template<typename T>
evalPart(FullMatrix<T> * result,const IndexSet * _rows,const IndexSet * _cols) const516 void HMatrix<T>::evalPart(FullMatrix<T>* result, const IndexSet* _rows,
517 const IndexSet* _cols) const {
518 if (this->isLeaf()) {
519 if (this->isNull()) return;
520 FullMatrix<T> *mat = isRkMatrix() ? rk()->eval() : full();
521 const int rowOffset = rows()->offset() - _rows->offset();
522 const int rowCount = rows()->size();
523 const int colOffset = cols()->offset() - _cols->offset();
524 const int colCount = cols()->size();
525 for (int j = 0; j < colCount; j++) {
526 memcpy(&result->get(rowOffset, j + colOffset), &mat->get(0, j), rowCount * sizeof(T));
527 }
528 if (isRkMatrix()) {
529 delete mat;
530 }
531 } else {
532 for (int i = 0; i < this->nrChild(); i++) {
533 if (this->getChild(i)) {
534 this->getChild(i)->evalPart(result, _rows, _cols);
535 }
536 }
537 }
538 }
539
normSqr() const540 template<typename T> double HMatrix<T>::normSqr() const {
541 double result = 0.;
542 if (rows()->size() == 0 || cols()->size() == 0) {
543 return result;
544 }
545 if (this->isLeaf() && isAssembled() && !isNull()) {
546 if (isRkMatrix()) {
547 result = rk()->normSqr();
548 } else {
549 result = full()->normSqr();
550 }
551 } else if(!this->isLeaf()){
552 for (int i = 0; i < this->nrChild(); i++) {
553 const HMatrix<T> *res=this->getChild(i);
554 if (res) {
555 // When computing the norm of symmetric matrices, extra-diagonal blocks count twice
556 double coeff = (isUpper || isLower) && ! (*res->rows() == *res->cols()) ? 2. : 1. ;
557 result += coeff * res->normSqr();
558 }
559 }
560 }
561 return result;
562 }
563
564 // Return an approximation of the largest eigenvalue via the power method.
565 // If needed, we could also return the corresponding eigenvector.
approximateLargestEigenvalue(int max_iter,double epsilon) const566 template<typename T> T HMatrix<T>::approximateLargestEigenvalue(int max_iter, double epsilon) const {
567 if (max_iter <= 0) return 0.0;
568 if (rows()->size() == 0 || cols()->size() == 0) {
569 return 0.0;
570 }
571 const int nrow = rows()->size();
572 Vector<T> xv(nrow);
573 Vector<T> xv1(nrow);
574 Vector<T>* x = &xv;
575 Vector<T>* x1 = &xv1;
576 T ev = Constants<T>::zero;
577 for (int i = 0; i < nrow; i++)
578 xv[i] = static_cast<T>(rand()/(double)RAND_MAX);
579 double normx = x->norm();
580 if (normx == 0.0)
581 return approximateLargestEigenvalue(max_iter - 1, epsilon);
582 x->scale(static_cast<T>(1.0/normx));
583 int iter = 0;
584 double aev = 0.0;
585 double aev_p = 0.0;
586 do {
587 // old eigenvalue
588 aev_p = aev;
589 // Compute x(k+1) = A x(k)
590 // ev(k+1) = <x(k+1),x(k)>
591 // x(k+1) = x(k+1) / ||x(k+1)||
592 gemv('N', Constants<T>::pone, x, Constants<T>::zero, x1);
593 ev = Vector<T>::dot(x,x1);
594 // new abs(ev)
595 aev = std::abs(ev);
596 normx = x1->norm();
597 // If x1 is null, restart so that starting point is different.
598 // Decrease max_iter to prevent infinite recursion.
599 if (normx == 0.0)
600 return approximateLargestEigenvalue(max_iter - 1, epsilon);
601 x1->scale(static_cast<T>(1.0/normx));
602 std::swap(x,x1);
603 iter++;
604 } while(iter < max_iter && std::abs(aev - aev_p) > epsilon * aev);
605 return ev;
606 }
607
608
609 template<typename T>
scale(T alpha)610 void HMatrix<T>::scale(T alpha) {
611 if(alpha == Constants<T>::zero) {
612 this->clear();
613 } else if(alpha == Constants<T>::pone) {
614 return;
615 } else if (this->isLeaf()) {
616 if (isNull()) {
617 // nothing to do
618 } else if (isRkMatrix()) {
619 rk()->scale(alpha);
620 } else {
621 assert(isFullMatrix());
622 full()->scale(alpha);
623 }
624 } else {
625 for (int i = 0; i < this->nrChild(); i++) {
626 if (this->getChild(i)) {
627 this->getChild(i)->scale(alpha);
628 }
629 }
630 }
631 }
632
633 template<typename T>
coarsen(double epsilon,HMatrix<T> * upper,bool force)634 bool HMatrix<T>::coarsen(double epsilon, HMatrix<T>* upper, bool force) {
635 // If all children are Rk leaves, then we try to merge them into a single Rk-leaf.
636 // This is done if the memory of the resulting leaf is less than the sum of the initial
637 // leaves. Note that this operation could be used hierarchically.
638
639 bool allRkLeaves = true;
640 const RkMatrix<T>* childrenArray[this->nrChild()];
641 size_t childrenElements = 0;
642 for (int i = 0; i < this->nrChild(); i++) {
643 childrenArray[i] = nullptr;
644 HMatrix<T> *child = this->getChild(i);
645 if (!child) continue;
646 if (!child->isRkMatrix()) {
647 allRkLeaves = false;
648 break;
649 } else {
650 childrenArray[i] = child->rk();
651 if(childrenArray[i])
652 childrenElements += (childrenArray[i]->rows->size()
653 + childrenArray[i]->cols->size()) * childrenArray[i]->rank();
654 }
655 }
656 if (allRkLeaves) {
657 std::vector<T> alpha(this->nrChild(), Constants<T>::pone);
658 RkMatrix<T> * candidate = new RkMatrix<T>(NULL, rows(), NULL, cols());
659 candidate->formattedAddParts(epsilon, &alpha[0], childrenArray, this->nrChild());
660 size_t elements = (((size_t) candidate->rows->size()) + candidate->cols->size()) * candidate->rank();
661 if (force || elements < childrenElements) {
662 // Replace 'this' by the new Rk matrix
663 for (int i = 0; i < this->nrChild(); i++)
664 this->removeChild(i);
665 this->children.clear();
666 rk(candidate);
667 assert(this->isLeaf());
668 assert(isRkMatrix());
669 // If necessary, replace 'upper' by the new Rk matrix transposed (exchange a and b)
670 if (upper) {
671 for (int i = 0; i < this->nrChild(); i++)
672 upper->removeChild(i);
673 upper->children.clear();
674 RkMatrix<T>* newRk = candidate->copy();
675 newRk->transpose();
676 upper->rk(newRk);
677 assert(upper->isLeaf());
678 assert(upper->isRkMatrix());
679 }
680 } else {
681 delete candidate;
682 }
683 }
684
685 return allRkLeaves;
686 }
687
getChildForGEMM(char & t,int i,int j) const688 template<typename T> const HMatrix<T> * HMatrix<T>::getChildForGEMM(char & t, int i, int j) const {
689 // At most 1 of these flags must be 'true'
690 assert(isUpper + isLower + isTriUpper + isTriLower >= -1);
691 assert(!this->isLeaf());
692
693 const HMatrix<T>* res;
694 if(t != 'N')
695 std::swap(i,j);
696 if( (isLower && j > i) ||
697 (isUpper && i > j) ) {
698 res = get(j, i);
699 t = t == 'N' ? 'T' : 'N';
700 } else {
701 res = get(i, j);
702 }
703 return res;
704 }
705
706 // y <- alpha * op(this) * x + beta * y or y <- alpha * x * op(this) + beta * y.
707 template<typename T>
gemv(char matTrans,T alpha,const ScalarArray<T> * x,T beta,ScalarArray<T> * y,Side side) const708 void HMatrix<T>::gemv(char matTrans, T alpha, const ScalarArray<T>* x, T beta, ScalarArray<T>* y, Side side) const {
709 if (rows()->size() == 0 || cols()->size() == 0) return;
710 // The dimensions of the H-matrix and the 2 ScalarArrays must match exactly
711 if(side == Side::LEFT) {
712 // Y + H * X
713 assert(x->cols == y->cols);
714 assert((matTrans != 'N' ? cols()->size() : rows()->size()) == y->rows);
715 assert((matTrans != 'N' ? rows()->size() : cols()->size()) == x->rows);
716 } else {
717 // Y + X * H
718 assert(x->rows == y->rows);
719 assert((matTrans != 'N' ? cols()->size() : rows()->size()) == x->cols);
720 assert((matTrans != 'N' ? rows()->size() : cols()->size()) == y->cols);
721 }
722 if (beta != Constants<T>::pone) {
723 y->scale(beta);
724 }
725
726 if (!this->isLeaf()) {
727 for (int i = 0, iend = (matTrans=='N' ? nrChildRow() : nrChildCol()); i < iend; i++)
728 for (int j = 0, jend = (matTrans=='N' ? nrChildCol() : nrChildRow()); j < jend; j++) {
729 char trans = matTrans;
730 // trans(child) = the child (i,j) of matTrans(this)
731 const HMatrix<T>* child = getChildForGEMM(trans, i, j);
732 if (child) {
733
734 // I get the rows and cols info of 'child'
735 int colsOffset = child->cols()->offset() - cols()->offset();
736 int colsSize = child->cols()->size();
737 int rowsOffset = child->rows()->offset() - rows()->offset();
738 int rowsSize = child->rows()->size();
739
740 // swap if needed to get the info for trans(child)
741 if (trans != 'N') {
742 std::swap(colsOffset, rowsOffset);
743 std::swap(colsSize, rowsSize);
744 }
745
746 if (side == Side::LEFT) {
747 // get the rows subset of X aligned with 'trans(child)' cols and Y aligned with 'trans(child)' rows
748 const ScalarArray<T> subX(*x, colsOffset, colsSize, 0, x->cols);
749 ScalarArray<T> subY(*y, rowsOffset, rowsSize, 0, y->cols);
750 child->gemv(trans, alpha, &subX, Constants<T>::pone, &subY, side);
751 } else {
752 // get the columns subset of X aligned with 'trans(child)' rows and Y aligned with 'trans(child)' columns
753 const ScalarArray<T> subX(*x, 0, x->rows, rowsOffset, rowsSize);
754 ScalarArray<T> subY(*y, 0, y->rows, colsOffset, colsSize);
755 child->gemv(trans, alpha, &subX, Constants<T>::pone, &subY, side);
756 }
757 }
758 else continue;
759 }
760
761 } else {
762 // We are on a leaf of the matrix 'this'
763 if (isFullMatrix()) {
764 if (side == Side::LEFT) {
765 y->gemm(matTrans, 'N', alpha, &full()->data, x, Constants<T>::pone);
766 } else {
767 y->gemm('N', matTrans, alpha, x, &full()->data, Constants<T>::pone);
768 }
769 } else if(!isNull()){
770 rk()->gemv(matTrans, alpha, x, Constants<T>::pone, y, side);
771 }
772 }
773 }
774
775 template<typename T>
gemv(char matTrans,T alpha,const FullMatrix<T> * x,T beta,FullMatrix<T> * y,Side side) const776 void HMatrix<T>::gemv(char matTrans, T alpha, const FullMatrix<T>* x, T beta, FullMatrix<T>* y, Side side) const {
777 gemv(matTrans, alpha, &x->data, beta, &y->data, side);
778 }
779
780 /**
781 * @brief List all Rk matrice in the m matrice.
782 * @return true if the matrix contains only rk matrices, fall if it contains
783 * both rk and full matrices
784 */
listAllRk(const HMatrix<T> * m,vector<const RkMatrix<T> * > & result)785 template<typename T> bool listAllRk(const HMatrix<T> * m, vector<const RkMatrix<T>*> & result) {
786 if(m == NULL) {
787 // do nothing
788 } else if(m->isRkMatrix())
789 result.push_back(m->rk());
790 else if(m->isLeaf())
791 return false;
792 else {
793 for(int i = 0; i < m->nrChild(); i++) {
794 if(m->getChild(i) && !listAllRk(m->getChild(i), result))
795 return false;
796 }
797 }
798 return true;
799 }
800
801 /**
802 * @brief generic AXPY implementation that dispatch to others or recurse
803 */
axpy(T alpha,const HMatrix<T> * x)804 template <typename T> void HMatrix<T>::axpy(T alpha, const HMatrix<T> *x) {
805 if (x->isLeaf()) {
806 if (x->isNull()) {
807 // nothing to do
808 } else if (x->isFullMatrix())
809 axpy(alpha, x->full());
810 else if (x->isRkMatrix())
811 axpy(alpha, x->rk());
812 } else {
813 HMAT_ASSERT(*rows() == *x->rows());
814 HMAT_ASSERT(*cols() == *x->cols());
815 if (this->isLeaf()) {
816 if (isRkMatrix()) {
817 if (!rk())
818 rk(new RkMatrix<T>(NULL, rows(), NULL, cols()));
819 vector<const RkMatrix<T> *> rkLeaves;
820 if (listAllRk(x, rkLeaves)) {
821 vector<T> alphas(rkLeaves.size(), alpha);
822 rk()->formattedAddParts(lowRankEpsilon(), &alphas[0], &rkLeaves[0], rkLeaves.size());
823 } else {
824 // x has contains both full and Rk matrices, this is not
825 // supported yet.
826 HMAT_ASSERT(false);
827 }
828 rank_ = rk()->rank();
829 } else {
830 if (full() == NULL)
831 full(new FullMatrix<T>(rows(), cols()));
832 FullMatrix<T> xFull(x->rows(), x->cols());
833 x->evalPart(&xFull, x->rows(), x->cols());
834 full()->axpy(alpha, &xFull);
835 }
836 } else
837 for (int i = 0; i < this->nrChild(); i++) {
838 HMatrix<T> *child = this->getChild(i);
839 const HMatrix<T> *bChild = x->isLeaf() ? x : x->getChild(i);
840 if (bChild != NULL) {
841 HMAT_ASSERT(child != NULL); // This may happen but this is not supported yet
842 child->axpy(alpha, bChild);
843 }
844 }
845 }
846 }
847
848 /** @brief AXPY between 'this' an H matrix and a subset of B with B a RkMatrix */
849 template<typename T>
axpy(T alpha,const RkMatrix<T> * b)850 void HMatrix<T>::axpy(T alpha, const RkMatrix<T>* b) {
851 DECLARE_CONTEXT;
852 // this += alpha * b
853 assert(b);
854 assert(b->rows->intersects(*rows()));
855 assert(b->cols->intersects(*cols()));
856
857 if (b->rank() == 0 || rows()->size() == 0 || cols()->size() == 0) {
858 return;
859 }
860
861 // If 'this' is not a leaf, we recurse with the same 'b'
862 if (!this->isLeaf()) {
863 for (int i = 0; i < this->nrChild(); i++) {
864 HMatrix<T> * c = this->getChild(i);
865 if (c) {
866 if(b->rank() < std::min(c->rows()->size(), c->cols()->size()) && b->rank() > 10) {
867 RkMatrix<T> * bc = b->truncatedSubset(c->rows(), c->cols(), c->lowRankEpsilon());
868 c->axpy(alpha, bc);
869 delete bc;
870 }
871 else c->axpy(alpha, b);
872 }
873 }
874 } else {
875 // To add-up a leaf to a RkMatrix, resizing may be necessary.
876 bool needResizing = b->rows->isStrictSuperSet(*rows())
877 || b->cols->isStrictSuperSet(*cols());
878 const RkMatrix<T>* newRk = b;
879 if (needResizing) {
880 newRk = b->subset(rows(), cols());
881 }
882 if (isRkMatrix()) {
883 if(!rk())
884 rk(new RkMatrix<T>(NULL, rows(), NULL, cols()));
885 rk()->axpy(lowRankEpsilon(), alpha, newRk);
886 rank_ = rk()->rank();
887 } else {
888 // In this case, the matrix has small size
889 // then evaluating the Rk-matrix is cheaper
890 FullMatrix<T>* rkMat = newRk->eval();
891 if(isFullMatrix()) {
892 full()->axpy(alpha, rkMat);
893 delete rkMat;
894 } else {
895 rkMat->scale(alpha);
896 full(rkMat);
897 }
898 }
899 if (needResizing) {
900 delete newRk;
901 }
902 }
903 }
904
905 /** @brief AXPY between 'this' an H matrix and a subset of B with B a FullMatrix */
906 template<typename T>
axpy(T alpha,const FullMatrix<T> * b)907 void HMatrix<T>::axpy(T alpha, const FullMatrix<T>* b) {
908 DECLARE_CONTEXT;
909 // this += alpha * b
910 assert(b->rows_->isSuperSet(*this->rows()) && b->cols_->isSuperSet(*this->cols()));
911
912 // If 'this' is not a leaf, we recurse with the same 'b'
913 if (!this->isLeaf()) {
914 for (int i = 0; i < this->nrChild(); i++) {
915 HMatrix<T>* child = this->getChild(i);
916 if (child)
917 child->axpy(alpha, b);
918 }
919 } else {
920 const FullMatrix<T>* subMat = b->subset(rows(), cols());
921 if (isRkMatrix()) {
922 if(!rk())
923 rk(new RkMatrix<T>(NULL, rows(), NULL, cols()));
924 rk()->axpy(lowRankEpsilon(), alpha, subMat);
925 rank_ = rk()->rank();
926 } else if(isFullMatrix()){
927 full()->axpy(alpha, subMat);
928 } else {
929 assert(!isAssembled() || full() == NULL);
930 full(subMat->copy());
931 if(alpha != Constants<T>::pone)
932 full()->scale(alpha);
933 }
934 delete subMat;
935 }
936 }
937
938 template<typename T>
addIdentity(T alpha)939 void HMatrix<T>::addIdentity(T alpha)
940 {
941 if (this->isLeaf()) {
942 if (isNull()) {
943 HMAT_ASSERT(!this->isRkMatrix());
944 full(new FullMatrix<T>(rows(), cols()));
945 }
946 if (isFullMatrix()) {
947 FullMatrix<T> * b = full();
948 assert(b->rows() == b->cols());
949 for (int i = 0; i < b->rows(); i++) {
950 b->get(i, i) += alpha;
951 }
952 } else {
953 HMAT_ASSERT(false);
954 }
955 } else {
956 for (int i = 0; i < nrChildRow(); i++)
957 if(get(i, i) != nullptr)
958 get(i,i)->addIdentity(alpha);
959 }
960 }
961
962 template<typename T>
addRand(double epsilon)963 void HMatrix<T>::addRand(double epsilon)
964 {
965 if (this->isLeaf()) {
966 if (isFullMatrix()) {
967 full()->addRand(epsilon);
968 } else {
969 rk()->addRand(epsilon);
970 }
971 } else {
972 for (int i = 0; i < nrChildRow(); i++) {
973 for(int j = 0; j < nrChildCol(); j++) {
974 if(get(i,j)) {
975 get(i,j)->addRand(epsilon);
976 }
977 }
978 }
979 }
980 }
981
subset(const IndexSet * rows,const IndexSet * cols) const982 template<typename T> HMatrix<T> * HMatrix<T>::subset(
983 const IndexSet * rows, const IndexSet * cols) const
984 {
985 if((this->rows() == rows && this->cols() == cols) ||
986 (*(this->rows()) == *rows && *(this->cols()) == *cols) ||
987 (!rows->isSubset(*(this->rows())) || !cols->isSubset(*(this->cols())))) // TODO cette ligne me parait louche... si rows et cols sont pas bons, on renvoie 'this' sans meme se plaindre ???
988 return const_cast<HMatrix<T>*>(this);
989
990 // this could be implemented but if you need it you more
991 // likely have something to fix at a higher level.
992 assert(!this->isNull());
993
994 if(this->isLeaf()) {
995 HMatrix<T> * tmpMatrix = new HMatrix<T>(this->localSettings.global);
996 tmpMatrix->temporary_=true;
997 tmpMatrix->localSettings.epsilon_ = localSettings.epsilon_;
998 ClusterTree * r = rows_->slice(rows->offset(), rows->size());
999 ClusterTree * c = cols_->slice(cols->offset(), cols->size());
1000
1001 // ensure the cluster tree are properly freed
1002 r->father = r;
1003 c->father = c;
1004
1005 tmpMatrix->rows_ = r;
1006 tmpMatrix->cols_ = c;
1007 tmpMatrix->ownClusterTrees(true, true);
1008 if(this->isRkMatrix()) {
1009 tmpMatrix->rk(const_cast<RkMatrix<T>*>(rk()->subset(tmpMatrix->rows(), tmpMatrix->cols())));
1010 } else {
1011 tmpMatrix->full(const_cast<FullMatrix<T>*>(full()->subset(tmpMatrix->rows(), tmpMatrix->cols())));
1012 }
1013 return tmpMatrix;
1014 } else {
1015 // 'This' is not a leaf
1016 //TODO not yet implemented but should not happen
1017 HMAT_ASSERT(false);
1018 }
1019 }
1020
1021 /**
1022 * @brief Ensure that matrices have compatible cluster trees.
1023 * @param row_a If true check the number of row of A is compatible else check columns
1024 * @param row_b If true check the number of row of B is compatible else check columns
1025 * @param in_a The input A matrix whose dimension must be checked
1026 * @param in_b The input B matrix whose dimension must be checked
1027 * @param out_a A subset of the A matrix which have compatible dimension with out_b.
1028 * out_a is a view on in_a, no data are copied. It can possibly return in_a if matrices
1029 * are already compatibles.
1030 * @param out_b A subset of the B matrix which have compatible dimension with out_a.
1031 * out_b is a view on in_b, no data are copied. It can possibly return in_b if matrices
1032 * are already compatibles.
1033 */
1034 template<typename T> void
makeCompatible(bool row_a,bool row_b,const HMatrix<T> * in_a,const HMatrix<T> * in_b,HMatrix<T> * & out_a,HMatrix<T> * & out_b)1035 makeCompatible(bool row_a, bool row_b,
1036 const HMatrix<T> * in_a, const HMatrix<T> * in_b,
1037 HMatrix<T> * & out_a, HMatrix<T> * & out_b) {
1038
1039 // suppose that A is bigger than B: in that case A will change, not B
1040 const IndexSet * cdb = row_b ? in_b->rows() : in_b->cols();
1041 if(row_a) // restrict the rows of in_a to cdb
1042 out_a = in_a->subset(cdb, in_a->cols());
1043 else // or the cols
1044 out_a = in_a->subset(in_a->rows(), cdb);
1045
1046 // if A has changed, B won't change so we bypass this second step
1047 if(out_a == in_a) {
1048 // suppose than B is bigger than A: B will change, not A
1049 const IndexSet * cda = row_a ? in_a->rows() : in_a->cols();
1050 if(row_b)
1051 out_b = in_b->subset(cda, in_b->cols());
1052 else
1053 out_b = in_b->subset(in_b->rows(), cda);
1054 }
1055 else
1056 out_b = const_cast<HMatrix<T> *>(in_b);
1057 }
1058
/**
 * @brief A GEMM implementation which does not require matrices to have
 * compatible cluster trees.
 *
 * We compute the product alpha.f(a).f(b)+c -> c (with c=this)
 * f(a)=transpose(a) if transA='T', f(a)=a if transA='N' (idem for b)
 *
 * The operands are first restricted with makeCompatible() so that the three
 * matrices match each other, then the product is delegated to leafGemm().
 * When 'this' is a leaf that is neither Rk nor an allocated full block,
 * fullHHGemm() is used instead. Every temporary view created on the way is
 * deleted before returning.
 *
 * @param transA 'N' to use a as is, anything else to use it transposed
 * @param transB same as transA, for b
 * @param alpha  scaling factor of the product
 * @param a      left operand
 * @param b      right operand
 */
template<typename T> void HMatrix<T>::uncompatibleGemm(char transA, char transB, T alpha,
                                  const HMatrix<T>* a, const HMatrix<T>* b) {
  // Computing a(m,0) * b(0,n) here may give wrong results because of format conversions, exit early
  if(isVoid() || a->isVoid())
    return;
  // va/vb/vc: results of the first restriction, vva/vvb/vvc: final operands.
  HMatrix<T> * va = NULL;
  HMatrix<T> * vb = NULL;
  HMatrix<T> * vc = NULL;;
  HMatrix<T> * vva = NULL;
  HMatrix<T> * vvb = NULL;
  HMatrix<T> * vvc = NULL;

  // Create va & vb = the subsets of a & b that match each other for doing the product f(a).f(b)
  // We modify the columns of f(a) and the rows of f(b)
  makeCompatible<T>(transA != 'N', transB == 'N', a, b, va, vb);

  if(this->isLeaf() && !this->isRkMatrix() && this->full() == NULL) {
    // C (this) is a null full block. We cannot get the subset of it and we
    // don't know yet if we need to allocate it
    fullHHGemm(this, transA, transB, alpha, va, vb);
    // va/vb are temporaries only when makeCompatible() did not hand back the inputs
    if(va != a)
      delete va;
    if(vb != b)
      delete vb;
    return;
  } else {
    // Create vva & vc = the subsets of va & c (=this) that match each other for doing the sum c+f(a).f(b)
    // We modify the rows of f(a) and the rows of c
    makeCompatible<T>(transA == 'N', true, va, this, vva, vc);

    // Create vvb & vvc = the subsets of vb & vc that match each other for doing the sum c+f(a).f(b)
    // We modify the columns of f(b) and the columns of c
    makeCompatible<T>(transB != 'N', false, vb, vc, vvb, vvc);
  }

  // Delete the intermediate matrices, except if subset() in makecompatible() has returned the original matrix
  if(va != vva && va != a)
    delete va;
  if(vb != vvb && vb != b)
    delete vb;
  if(vc != vvc && vc != this)
    delete vc;

  // writing on a subset of an RkMatrix is not possible without
  // modifying the whole matrix
  assert(!isRkMatrix() || vvc == this);
  // Do the product on the matrices that are now compatible
  vvc->leafGemm(transA, transB, alpha, vva, vvb);

  // Delete the temporary matrices
  if(vva != a)
    delete vva;
  if(vvb != b)
    delete vvb;
  if(vvc != this)
    delete vvc;
}
1123
/**
 * @brief Precompute which child blocks of f(a) and f(b) overlap along the
 * requested axes (f(x) = x if its trans flag is 'N', transposed otherwise).
 *
 * The result is a flat array indexed as
 *   result[(index of f(a) along axisA) * (number of blocks of f(b) along axisB)
 *          + (index of f(b) along axisB)]
 * Each entry holds the value of intersects() between the corresponding index
 * sets. Entries for a row/column containing only null children are left at 0.
 * A leaf 'a' or 'b' is used itself in place of its children.
 *
 * @param a      first matrix
 * @param axisA  Axis::ROW to look at the rows of f(a), otherwise its columns
 * @param transA 'N' to use a as is, anything else to use it transposed
 * @param b      second matrix
 * @param axisB  Axis::ROW to look at the rows of f(b), otherwise its columns
 * @param transB same as transA, for b
 * @return an array allocated with new[]; the caller must release it with delete[]
 */
template<typename T>
unsigned char * compatibilityGridForGEMM(const HMatrix<T>* a, Axis axisA, char transA, const HMatrix<T>* b, Axis axisB, char transB) {
  // Let us first consider C = A^T * B where A, B and C are top-level matrices:
  //  [ C11 | C12 ]   [ A11^T | A21^T ]   [ B11 | B12 ]
  //  [ ----+---- ] = [ ------+------ ] * [ ----+---- ]
  //  [ C21 | C22 ]   [ A12^T | A22^T ]   [ B21 | B22 ]
  // This multiplication is possible only if columns of A^T and rows of B are the same,
  // rows of A^T and rows of C are the same, and columns of B and columns of C are the same.
  // Matrices are built from the same cluster trees, so we know that blocks are split at the
  // same place, and this function will return:
  //    compatibilityGridForGEMM(A, COL, 'T', B, ROW, 'N') = {1, 0, 0, 1}
  //    compatibilityGridForGEMM(A, ROW, 'T', C, ROW, 'N') = {1, 0, 0, 1}
  //    compatibilityGridForGEMM(B, COL, 'N', C, ROW, 'N') = {1, 0, 0, 1}
  // But blocks could be split in a single direction instead of 2, or could not
  // be split; if A had been split only by rows, we would have
  //  [ C11 | C12 ]                       [ B11 | B12 ]
  //  [ ----+---- ] = [ A11^T | A21^T ] * [ ----+---- ]
  //  [ C21 | C22 ]                       [ B21 | B22 ]
  //    compatibilityGridForGEMM(A, ROW, 'T', C, ROW, 'N') = {1, 1}
  //
  // Situation is much more complicated when considering inner nodes; for instance let us
  // have a look at A11^T * B11 with A11 and B11 being defined just above.  Rows of A11^T
  // are equal to rows(C11)+rows(C21), and thus (considering that A11 and C11 are split
  // into 4 nodes)
  //    compatibilityGridForGEMM(A11, ROW, 'T', C11, ROW, 'N') = {1, 1, 0, 0}
  //
  // This function is generic, it works for all cases as long as matrices are built from
  // the same cluster trees.

  // Block counts of f(a) and f(b), i.e. after the optional transposition
  int row_a = transA == 'N' ? a->nrChildRow() : a->nrChildCol();
  int col_a = transA == 'N' ? a->nrChildCol() : a->nrChildRow();
  int row_b = transB == 'N' ? b->nrChildRow() : b->nrChildCol();
  int col_b = transB == 'N' ? b->nrChildCol() : b->nrChildRow();
  size_t nr_blocks = (axisA == Axis::ROW ? row_a : col_a) * (axisB == Axis::ROW ? row_b : col_b);
  unsigned char * result = new unsigned char[nr_blocks];
  // Everything is incompatible until proven otherwise
  memset(result, 0, nr_blocks);

  if (axisA == Axis::ROW) {
    for (int iA = 0; iA < row_a; iA++) {
      // All children on a row have the same row cluster tree, get it
      // from the first non null child. We also consider the case where
      // 'a' is a leaf.
      const HMatrix<T> *childA = a->isLeaf() ? a : nullptr;
      char tA = transA;
      for (int jA = 0; !childA && jA < col_a; jA++) {
        // getChildForGEMM() receives tA by reference; restart from transA for each attempt
        tA = transA;
        childA = a->getChildForGEMM(tA, iA, jA);
      }
      // This row contains only null children, skip it
      if(!childA)
        continue;
      if (axisB == Axis::ROW) {
        // Rows of f(a) against rows of f(b)
        for (int iB = 0; iB < row_b; iB++) {
          for (int jB = 0; jB < col_b; jB++) {
            char tB = transB;
            const HMatrix<T> *childB = b->isLeaf() ? b : b->getChildForGEMM(tB, iB, jB);
            if(childB) {
              result[iA * row_b + iB] = (tA == 'N' ? childA->rows() : childA->cols())->intersects(*(tB == 'N' ? childB->rows() : childB->cols()));
              // The first non null child of the row is enough
              break;
            }
          }
        }
      } else {
        // Rows of f(a) against columns of f(b)
        for (int jB = 0; jB < col_b; jB++) {
          for (int iB = 0; iB < row_b; iB++) {
            char tB = transB;
            const HMatrix<T> *childB = b->isLeaf() ? b : b->getChildForGEMM(tB, iB, jB);
            if(childB) {
              result[iA * col_b + jB] = (tA == 'N' ? childA->rows() : childA->cols())->intersects(*(tB == 'N' ? childB->cols() : childB->rows()));
              // The first non null child of the column is enough
              break;
            }
          }
        }
      }
    }
  } else {
    for (int jA = 0; jA < col_a; jA++) {
      // Same as above, but looking for the first non null child of a column
      const HMatrix<T> *childA = a->isLeaf() ? a : nullptr;
      char tA = transA;
      for (int iA = 0; !childA && iA < row_a; iA++) {
        tA = transA;
        childA = a->getChildForGEMM(tA, iA, jA);
      }
      // This column contains only null children, skip it
      if(!childA)
        continue;
      if (axisB == Axis::ROW) {
        // Columns of f(a) against rows of f(b)
        for (int iB = 0; iB < row_b; iB++) {
          for (int jB = 0; jB < col_b; jB++) {
            char tB = transB;
            const HMatrix<T> *childB = b->isLeaf() ? b : b->getChildForGEMM(tB, iB, jB);
            if(childB) {
              result[jA * row_b + iB] = (tA == 'N' ? childA->cols() : childA->rows())->intersects(*(tB == 'N' ? childB->rows() : childB->cols()));
              break;
            }
          }
        }
      } else {
        // Columns of f(a) against columns of f(b)
        for (int jB = 0; jB < col_b; jB++) {
          for (int iB = 0; iB < row_b; iB++) {
            char tB = transB;
            const HMatrix<T> *childB = b->isLeaf() ? b : b->getChildForGEMM(tB, iB, jB);
            if(childB) {
              result[jA * col_b + jB] = (tA == 'N' ? childA->cols() : childA->rows())->intersects(*(tB == 'N' ? childB->cols() : childB->rows()));
              break;
            }
          }
        }
      }
    }
  }
  return result;
}
1237
/**
 * @brief this += alpha * f(a) * f(b), recursing on the children when none of
 * the three matrices is a leaf.
 *
 * f(x) = x if its trans flag is 'N', transposed otherwise. As soon as one of
 * 'this', 'a' or 'b' is a leaf the work is handed over to uncompatibleGemm().
 * Nothing is done when 'this' or 'a' is void.
 *
 * @param transA 'N' to use a as is, anything else to use it transposed
 * @param transB same as transA, for b
 * @param alpha  scaling factor of the product
 * @param a      left operand
 * @param b      right operand
 */
template<typename T> void
HMatrix<T>::recursiveGemm(char transA, char transB, T alpha, const HMatrix<T>* a, const HMatrix<T>*b) {
  // Computing a(m,0) * b(0,n) here may give wrong results because of format conversions, exit early
  if(isVoid() || a->isVoid())
    return;

  // None of the matrices is a leaf
  if (!this->isLeaf() && !a->isLeaf() && !b->isLeaf()) {
    // Block counts of f(a), f(b) and c (=this)
    int row_a = transA == 'N' ? a->nrChildRow() : a->nrChildCol();
    int col_a = transA == 'N' ? a->nrChildCol() : a->nrChildRow();
    int row_b = transB == 'N' ? b->nrChildRow() : b->nrChildCol();
    int col_b = transB == 'N' ? b->nrChildCol() : b->nrChildRow();
    int row_c = nrChildRow();
    int col_c = nrChildCol();

    // There are 6 nested loops, this may be an issue if there are more
    // than 2 children in each direction; precompute compatibility between
    // blocks to improve performance:
    //   + columns of a and rows of b
    unsigned char * is_compatible_a_b = compatibilityGridForGEMM(a, Axis::COL, transA, b, Axis::ROW, transB);
    //   + rows of a and rows of c
    unsigned char * is_compatible_a_c = compatibilityGridForGEMM(a, Axis::ROW, transA, this, Axis::ROW, 'N');
    //   + columns of b and columns of c
    unsigned char * is_compatible_b_c = compatibilityGridForGEMM(b, Axis::COL, transB, this, Axis::COL, 'N');
    // With these arrays, we can exit early from loops on iA, jB and l
    // when blocks are not compatible, and thus there are only 3 real
    // loops (on i, j, k) and performance penalty should be negligible.
    for (int i = 0; i < row_c; i++) {
      for (int j = 0; j < col_c; j++) {
        HMatrix<T>* child = get(i, j);
        if (!child) { // symmetric/triangular case or empty block coming from symbolic factorisation of sparse matrices
          continue;
        }

        for (int iA = 0; iA < row_a; iA++) {
          // Block row iA of f(a) must overlap block row i of c
          if (!is_compatible_a_c[iA * row_c + i])
            continue;
          for (int jB = 0; jB < col_b; jB++) {
            // Block column jB of f(b) must overlap block column j of c
            if (!is_compatible_b_c[jB * col_c + j])
              continue;
            for (int k = 0; k < col_a; k++) {
              // tA/tB are passed by reference to getChildForGEMM() and then
              // forwarded to the child product
              char tA = transA;
              const HMatrix<T> * childA = a->getChildForGEMM(tA, iA, k);
              if(!childA)
                continue;
              for (int l = 0; l < row_b; l++) {
                // Block column k of f(a) must overlap block row l of f(b)
                if (!is_compatible_a_b[k * row_b + l])
                  continue;
                char tB = transB;
                const HMatrix<T> * childB = b->getChildForGEMM(tB, l, jB);
                if(childB)
                  child->gemm(tA, tB, alpha, childA, childB, Constants<T>::pone);
              }
            }
          }
        }
      }
    }
    // The grids were allocated with new[] by compatibilityGridForGEMM()
    delete [] is_compatible_a_b;
    delete [] is_compatible_a_c;
    delete [] is_compatible_b_c;
  } // if (!this->isLeaf() && !a->isLeaf() && !b->isLeaf())
  else
    uncompatibleGemm(transA, transB, alpha, a, b);
}
1303
1304 /**
1305 * @brief product of 2 H-matrix to a full block.
1306 *
1307 * Full blocks may be null and it's that case it's not possible to take
1308 * a subset from them. We don't want to allocate them before the recursion
1309 * because the recursion may show that allocation is not needed. This function
1310 * allow to recurse and bypass the uncompatibleGemm method which would create
1311 * invalid subset.
1312 */
fullHHGemm(HMatrix<T> * c,char transA,char transB,T alpha,const HMatrix<T> * a,const HMatrix<T> * b)1313 template<typename T> void fullHHGemm(HMatrix<T> *c, char transA, char transB, T alpha, const HMatrix<T>* a, const HMatrix<T>*b) {
1314 assert(c->isLeaf());
1315 assert(!c->isRkMatrix());
1316 if(!a->isLeaf() && !b->isLeaf()) {
1317 for (int i = 0; i < (transA=='N' ? a->nrChildRow() : a->nrChildCol()) ; i++) {
1318 for (int j = 0; j < (transB=='N' ? b->nrChildCol() : b->nrChildRow()) ; j++) {
1319 const HMatrix<T> *childA, *childB;
1320 for (int k = 0; k < (transA=='N' ? a->nrChildCol() : a->nrChildRow()) ; k++) {
1321 char tA = transA;
1322 char tB = transB;
1323 childA = a->getChildForGEMM(tA, i, k);
1324 childB = b->getChildForGEMM(tB, k, j);
1325 if(childA && childB)
1326 fullHHGemm(c, tA, tB, alpha, childA, childB);
1327 }
1328 }
1329 }
1330 } else if(!a->isRecursivelyNull() && !b->isRecursivelyNull()) {
1331 if(c->full() == NULL)
1332 c->full(new FullMatrix<T>(c->rows(), c->cols()));
1333 c->gemm(transA, transB, alpha, a, b, Constants<T>::pone);
1334 }
1335 }
1336
/**
 * @brief this += alpha * f(a) * f(b) when at least one of 'this', 'a', 'b' is
 * a leaf and the three matrices have matching index sets.
 *
 * f(x) = x if its trans flag is 'N', transposed otherwise. The product is
 * computed as an RkMatrix or a FullMatrix depending on the format of the
 * operands, then added to 'this'.
 *
 * @param transA 'N' to use a as is, anything else to use it transposed
 * @param transB same as transA, for b
 * @param alpha  scaling factor of the product
 * @param a      left operand
 * @param b      right operand
 */
template<typename T> void
HMatrix<T>::leafGemm(char transA, char transB, T alpha, const HMatrix<T>* a, const HMatrix<T>*b) {
  assert((transA == 'N' ? *a->cols() : *a->rows()) == ( transB == 'N' ? *b->rows() : *b->cols())); // for the product A*B
  assert((transA == 'N' ? *a->rows() : *a->cols()) == *this->rows()); // compatibility of A*B + this : Rows
  assert((transB == 'N' ? *b->cols() : *b->rows()) == *this->cols()); // compatibility of A*B + this : columns

  // One of the matrices is a leaf
  assert(this->isLeaf() || a->isLeaf() || b->isLeaf());

  // the resulting matrix is not a leaf.
  if (!this->isLeaf()) {
    // If the resulting matrix is subdivided then at least one of the matrices of the product is a leaf.
    // One matrix is a RkMatrix
    if (a->isRkMatrix() || b->isRkMatrix()) {
      // A null Rk operand gives a null product: nothing to add
      if ((a->isRkMatrix() && a->isNull())
          || (b->isRkMatrix() && b->isNull())) {
        return;
      }
      RkMatrix<T>* rkMat = HMatrix<T>::multiplyRkMatrix(lowRankEpsilon(), transA, transB, a, b);
      axpy(alpha, rkMat);
      delete rkMat;
    } else {
      // None of the matrices of the product is a Rk-matrix so one of them is
      // a full matrix so as the result.
      assert(a->isFullMatrix() || b->isFullMatrix());
      FullMatrix<T>* fullMat = HMatrix<T>::multiplyFullMatrix(transA, transB, a, b);
      // multiplyFullMatrix() returns NULL for a null product
      if(fullMat) {
        axpy(alpha, fullMat);
        delete fullMat;
      }
    }
    return;
  }

  if (isRkMatrix()) {
    // The resulting matrix is a RkMatrix leaf.
    // At least one of the matrix is not a leaf.
    // The different cases are :
    //  a. R += H * H
    //  b. R += H * R
    //  c. R += R * H
    //  d. R += H * M
    //  e. R += M * H
    //  f. R += M * M

    // Cases a, b and c give an Hmatrix which has to be hierarchically converted into a Rkmatrix.
    // Cases c, d, e and f give a RkMatrix
    // NOTE(review): case c is listed in both groups above - confirm which one is meant.
    assert(isRkMatrix());
    assert((transA == 'N' ? *a->cols() : *a->rows()) == (transB == 'N' ? *b->rows() : *b->cols()));
    assert(*rows() == (transA == 'N' ? *a->rows() : *a->cols()));
    assert(*cols() == (transB == 'N' ? *b->cols() : *b->rows()));
    // Start from an empty Rk block when none has been allocated yet
    if(rk() == NULL)
        rk(new RkMatrix<T>(NULL, rows(), NULL, cols()));
    rk()->gemmRk(lowRankEpsilon(), transA, transB, alpha, a, b);
    rank_ = rk()->rank();
    return;
  }

  // a, b are H matrices and 'this' is full
  if ( this->isLeaf() && ((!a->isLeaf() && !b->isLeaf()) || isNull()) ) {
    fullHHGemm(this, transA, transB, alpha, a, b);
    return;
  }

  // The resulting matrix is a full matrix
  FullMatrix<T>* fullMat;
  if (a->isRkMatrix() || b->isRkMatrix()) {
    assert(a->isRkMatrix() || b->isRkMatrix());
    if ((a->isRkMatrix() && a->isNull())
        || (b->isRkMatrix() && b->isNull())) {
      return;
    }
    // Compute the product in Rk form, then expand it to a full matrix
    RkMatrix<T>* rkMat = HMatrix<T>::multiplyRkMatrix(lowRankEpsilon(), transA, transB, a, b);
    fullMat = rkMat->eval();
    delete rkMat;
  } else if(a->isLeaf() && b->isLeaf() && isFullMatrix()){
    // Three full blocks: accumulate in place without any temporary
    full()->gemm(transA, transB, alpha, a->full(), b->full(), Constants<T>::pone);
    return;
  } else {
    // if a or b is a leaf, it is Full (since Rk have been treated before)
    fullMat = HMatrix<T>::multiplyFullMatrix(transA, transB, a, b);
  }

  // It's not optimal to consider that the result is a FullMatrix but
  // this is a H*F case and it almost never happen
  if(fullMat) {
    if (isFullMatrix()) {
      full()->axpy(alpha, fullMat);
      delete fullMat;
    } else {
      // No data yet: the scaled product becomes the block itself
      full(fullMat);
      fullMat->scale(alpha);
    }
  }
}
1432
/**
 * @brief this = beta * this + alpha * f(a) * f(b)
 *
 * f(x) = x if its trans flag is 'N', transposed otherwise. Nothing is done
 * when 'this' or 'a' is void. Two shortcuts handle the case where 'this' and
 * one operand are Rk blocks sharing a panel; otherwise 'this' is scaled by
 * beta and the product is accumulated by recursiveGemm().
 *
 * @param transA 'N' to use a as is, anything else to use it transposed
 * @param transB same as transA, for b
 * @param alpha  scaling factor of the product
 * @param a      left operand
 * @param b      right operand
 * @param beta   scaling factor applied to 'this' before the accumulation
 *
 * The last, unnamed MainOp parameter is not used by this function.
 */
template<typename T>
void HMatrix<T>::gemm(char transA, char transB, T alpha, const HMatrix<T>* a, const HMatrix<T>* b, T beta, MainOp) {
  // Computing a(m,0) * b(0,n) here may give wrong results because of format conversions, exit early
  if(isVoid() || a->isVoid())
    return;

  // This and B are Rk matrices with the same panel 'b' -> the gemm is only applied on the panels 'a'
  if(isRkMatrix() && !isNull() && b->isRkMatrix() && !b->isNull() && rk()->b == b->rk()->b) {
    // Ca * CbT = beta * Ca * CbT + alpha * A * Ba * BbT
    // As Cb = Bb we get
    // Ca = beta * Ca + alpha A * Ba with only Ca and Ba scalar arrays
    // We support C and B not compatible (larger) with A so we first slice them
    assert(transB == 'N');
    const IndexSet * r = transA == 'N' ? a->rows() : a->cols();
    const IndexSet * c = transA == 'N' ? a->cols() : a->rows();
    // Offsets are expressed relatively to the first index of each panel
    ScalarArray<T> cSubset(rk()->a->rowsSubset( r->offset() - rows()->offset(), r->size()));
    ScalarArray<T> bSubset(b->rk()->a->rowsSubset( c->offset() -  b->rows()->offset(), c->size()));
    a->gemv(transA, alpha, &bSubset, beta, &cSubset);
    return;
  }

  // This and A are Rk matrices with the same panel 'a' -> the gemm is only applied on the panels 'b'
  if(isRkMatrix() && !isNull() && a->isRkMatrix() && !a->isNull() && rk()->a == a->rk()->a) {
    // Ca * CbT = beta * Ca * CbT + alpha * Aa * AbT * B
    // As Ca = Aa we get
    // CbT = beta * CbT + alpha AbT * B with only Cb and Ab scalar arrays
    // we transpose:
    // Cb = beta * Cb + alpha BT * Ab
    // We support C and A not compatible (larger) with B so we first slice them
    assert(transA == 'N');
    assert(transB != 'C');
    const IndexSet * r = transB == 'N' ? b->rows() : b->cols();
    const IndexSet * c = transB == 'N' ? b->cols() : b->rows();
    ScalarArray<T> cSubset(rk()->b->rowsSubset( c->offset() - cols()->offset(), c->size()));
    ScalarArray<T> aSubset(a->rk()->b->rowsSubset( r->offset() - a->cols()->offset(), r->size()));
    // B is applied transposed with respect to transB, see the formula above
    b->gemv(transB == 'N' ? 'T' : 'N', alpha, &aSubset, beta, &cSubset);
    return;
  }

  // General case: apply beta first, the product is then only accumulated
  this->scale(beta);

  // A leaf operand which is null or not assembled gives a null product
  if((a->isLeaf() && (!a->isAssembled() || a->isNull())) ||
     (b->isLeaf() && (!b->isAssembled() || b->isNull()))) {
      // An unassembled leaf result receives an empty Rk block
      if(!isAssembled() && this->isLeaf())
          rk(new RkMatrix<T>(NULL, rows(), NULL, cols()));
      return;
  }

  // The scaling by beta has been done above, so recursiveGemm() only
  // has to accumulate alpha * f(a) * f(b).
  recursiveGemm(transA, transB, alpha, a, b);
}
1485
1486 template<typename T>
multiplyFullH(char transM,char transH,const FullMatrix<T> * mat,const HMatrix<T> * h)1487 FullMatrix<T>* multiplyFullH(char transM, char transH,
1488 const FullMatrix<T>* mat,
1489 const HMatrix<T>* h) {
1490 assert(transH != 'C');
1491 FullMatrix<T>* resultT;
1492 if(transM == 'C') {
1493 // R = M* * H = (H^t * conj(M))^t
1494 FullMatrix<T>* matT = mat->copy();
1495 matT->conjugate();
1496 resultT = multiplyHFull(transH == 'N' ? 'T' : 'N',
1497 'N', h, matT);
1498 delete matT;
1499 } else {
1500 // R = M * H = (H^t * M^t*)^t
1501 resultT = multiplyHFull(transH == 'N' ? 'T' : 'N',
1502 transM == 'N' ? 'T' : 'N',
1503 h, mat);
1504 }
1505 if(resultT != NULL)
1506 resultT->transpose();
1507 return resultT;
1508 }
1509
isRecursivelyNull() const1510 template<typename T> bool HMatrix<T>::isRecursivelyNull() const {
1511 if(this->isLeaf())
1512 return isNull();
1513 else for(int i = 0; i < this->nrChild(); i++) {
1514 if(this->getChild(i) && !this->getChild(i)->isRecursivelyNull())
1515 return false;
1516 }
1517 return true;
1518 }
1519
1520 template<typename T>
multiplyHFull(char transH,char transM,const HMatrix<T> * h,const FullMatrix<T> * mat)1521 FullMatrix<T>* multiplyHFull(char transH, char transM,
1522 const HMatrix<T>* h,
1523 const FullMatrix<T>* mat) {
1524 assert((transH == 'N' ? h->cols()->size() : h->rows()->size())
1525 == (transM == 'N' ? mat->rows() : mat->cols()));
1526 if(h->isRecursivelyNull())
1527 return NULL;
1528 FullMatrix<T>* result =
1529 new FullMatrix<T>((transH == 'N' ? h->rows() : h->cols()),
1530 (transM == 'N' ? mat->cols_ : mat->rows_));
1531 if (transM == 'N') {
1532 h->gemv(transH, Constants<T>::pone, mat, Constants<T>::zero, result);
1533 } else {
1534 FullMatrix<T>* matT = mat->copyAndTranspose();
1535 if (transM == 'C') {
1536 matT->conjugate();
1537 }
1538 h->gemv(transH, Constants<T>::pone, matT, Constants<T>::zero, result);
1539 delete matT;
1540 }
1541 return result;
1542 }
1543
1544 template<typename T>
multiplyRkMatrix(double epsilon,char transA,char transB,const HMatrix<T> * a,const HMatrix<T> * b)1545 RkMatrix<T>* HMatrix<T>::multiplyRkMatrix(double epsilon, char transA, char transB, const HMatrix<T>* a, const HMatrix<T>* b){
1546 // We know that one of the matrices is a RkMatrix
1547 assert(a->isRkMatrix() || b->isRkMatrix());
1548 RkMatrix<T> *rk = NULL;
1549 // Matrices range compatibility
1550 if((transA == 'N') && (transB == 'N'))
1551 assert(a->cols()->size() == b->rows()->size());
1552 if((transA != 'N') && (transB == 'N'))
1553 assert(a->rows()->size() == b->rows()->size());
1554 if((transA == 'N') && (transB != 'N'))
1555 assert(a->cols()->size() == b->cols()->size());
1556
1557 // The cases are:
1558 // - A Rk, B H
1559 // - A H, B Rk
1560 // - A Rk, B Rk
1561 // - A Rk, B F
1562 // - A F, B Rk
1563 if (a->isRkMatrix() && !b->isLeaf()) {
1564 rk = RkMatrix<T>::multiplyRkH(transA, transB, a->rk(), b);
1565 HMAT_ASSERT(rk);
1566 }
1567 else if (!a->isLeaf() && b->isRkMatrix()) {
1568 rk = RkMatrix<T>::multiplyHRk(transA, transB, a, b->rk());
1569 HMAT_ASSERT(rk);
1570 }
1571 else if (a->isRkMatrix() && b->isRkMatrix()) {
1572 rk = RkMatrix<T>::multiplyRkRk(transA, transB, a->rk(), b->rk(), epsilon);
1573 HMAT_ASSERT(rk);
1574 }
1575 else if (a->isRkMatrix() && b->isFullMatrix()) {
1576 rk = RkMatrix<T>::multiplyRkFull(transA, transB, a->rk(), b->full());
1577 HMAT_ASSERT(rk);
1578 }
1579 else if (a->isFullMatrix() && b->isRkMatrix()) {
1580 rk = RkMatrix<T>::multiplyFullRk(transA, transB, a->full(), b->rk());
1581 HMAT_ASSERT(rk);
1582 } else if(a->isNull() || b->isNull()) {
1583 return new RkMatrix<T>(NULL, transA ? a->cols() : a->rows(),
1584 NULL, transB ? b->rows() : b->cols());
1585 } else {
1586 // None of the above cases, impossible.
1587 HMAT_ASSERT(false);
1588 }
1589 return rk;
1590 }
1591
1592 template<typename T>
multiplyFullMatrix(char transA,char transB,const HMatrix<T> * a,const HMatrix<T> * b)1593 FullMatrix<T>* HMatrix<T>::multiplyFullMatrix(char transA, char transB,
1594 const HMatrix<T>* a,
1595 const HMatrix<T>* b) {
1596 // At least one full matrix, and not RkMatrix.
1597 assert(a->isFullMatrix() || b->isFullMatrix());
1598 assert(!(a->isRkMatrix() || b->isRkMatrix()));
1599 FullMatrix<T> *result = NULL;
1600 // The cases are:
1601 // - A H, B F
1602 // - A F, B H
1603 // - A F, B F
1604 if (!a->isLeaf() && b->isFullMatrix()) {
1605 result = multiplyHFull(transA, transB, a, b->full());
1606 } else if (a->isFullMatrix() && !b->isLeaf()) {
1607 result = multiplyFullH(transA, transB, a->full(), b);
1608 } else if (a->isFullMatrix() && b->isFullMatrix()) {
1609 const IndexSet* aRows = (transA == 'N')? a->rows() : a->cols();
1610 const IndexSet* bCols = (transB == 'N')? b->cols() : b->rows();
1611 result = new FullMatrix<T>(aRows, bCols);
1612 result->gemm(transA, transB, Constants<T>::pone, a->full(), b->full(),
1613 Constants<T>::zero);
1614 } else if(a->isNull() || b->isNull()) {
1615 return NULL;
1616 } else {
1617 // None of above, impossible
1618 HMAT_ASSERT(false);
1619 }
1620 return result;
1621 }
1622
1623 template<typename T>
multiplyWithDiag(const HMatrix<T> * d,Side side,bool inverse) const1624 void HMatrix<T>::multiplyWithDiag(const HMatrix<T>* d, Side side, bool inverse) const {
1625 assert(*d->rows() == *d->cols());
1626 assert(side == Side::LEFT || (*cols() == *d->rows()));
1627 assert(side == Side::RIGHT || (*rows() == *d->cols()));
1628
1629 if (isVoid()) return;
1630
1631 // The symmetric matrix must be taken into account: lower or upper
1632 if (!this->isLeaf()) {
1633 if (d->isLeaf()) {
1634 for (int i=0 ; i<std::min(nrChildRow(), nrChildCol()) ; i++)
1635 if(get(i,i))
1636 get(i,i)->multiplyWithDiag(d, side, inverse);
1637 for (int i=0 ; i<nrChildRow() ; i++)
1638 for (int j=0 ; j<nrChildCol() ; j++)
1639 if (i!=j && get(i,j)) {
1640 get(i,j)->multiplyWithDiag(d, side, inverse);
1641 }
1642 return;
1643 }
1644
1645 // First the diagonal, then the rest...
1646 for (int i=0 ; i<std::min(nrChildRow(), nrChildCol()) ; i++)
1647 if(get(i,i))
1648 get(i,i)->multiplyWithDiag(d->get(i,i), side, inverse);
1649 for (int i=0 ; i<nrChildRow() ; i++)
1650 for (int j=0 ; j<nrChildCol() ; j++)
1651 if (i!=j && get(i,j)) {
1652 int k = side == Side::LEFT ? i : j;
1653 get(i,j)->multiplyWithDiag(d->get(k,k), side, inverse);
1654 // TODO couldn't we handle this case with the previous one, using getChildForGEMM(d,i,i) that returns 'd' itself when 'd' is a leaf ?
1655 }
1656 } else if (isRkMatrix() && !isNull()) {
1657 rk()->multiplyWithDiagOrDiagInv(d, inverse, side);
1658 } else if(isFullMatrix()){
1659 if (d->isFullMatrix()) {
1660 full()->multiplyWithDiagOrDiagInv(d->full()->diagonal, inverse, side);
1661 } else {
1662 Vector<T> diag(d->rows()->size());
1663 d->extractDiagonal(diag.ptr());
1664 full()->multiplyWithDiagOrDiagInv(&diag, inverse, side);
1665 }
1666 } else {
1667 // this is a null matrix (either full of Rk) so nothing to do
1668 }
1669 }
1670
transposeMeta(bool temporaryOnly)1671 template<typename T> void HMatrix<T>::transposeMeta(bool temporaryOnly) {
1672 if(temporaryOnly && !temporary_)
1673 return;
1674 // called by HMatrix<T>::transpose() and HMatrixHandle<T>::transpose()
1675 // if the matrix is symmetric, inverting it(Upper/Lower)
1676 if (isLower || isUpper) {
1677 isLower = !isLower;
1678 isUpper = !isUpper;
1679 }
1680 // if the matrix is triangular, on inverting it (isTriUpper/isTriLower)
1681 if (isTriLower || isTriUpper) {
1682 isTriLower = !isTriLower;
1683 isTriUpper = !isTriUpper;
1684 }
1685 // Warning: nrChildRow() uses keepSameRows and rows_
1686 bool tmp = keepSameCols; // can't use swap on bitfield so manual swap...
1687 keepSameCols = keepSameRows;
1688 keepSameRows = tmp;
1689 swap(rows_, cols_);
1690 RecursionMatrix<T, HMatrix<T> >::transposeMeta(temporaryOnly);
1691 }
1692
transposeData()1693 template <typename T> void HMatrix<T>::transposeData() {
1694 if (this->isLeaf()) {
1695 if (isRkMatrix() && rk()) {
1696 rk()->transpose();
1697 } else if (isFullMatrix()) {
1698 full()->transpose();
1699 }
1700 } else {
1701 for (int i = 0; i < this->nrChild(); i++)
1702 if (this->getChild(i))
1703 this->getChild(i)->transposeData();
1704 }
1705 }
1706
transpose()1707 template<typename T> void HMatrix<T>::transpose() {
1708 transposeData();
1709 transposeMeta();
1710 }
1711
conjugate()1712 template<> void HMatrix<S_t>::conjugate() {}
conjugate()1713 template<> void HMatrix<D_t>::conjugate() {}
conjugate()1714 template<typename T> void HMatrix<T>::conjugate() {
1715 std::vector<const HMatrix<T> *> stack;
1716 stack.push_back(this);
1717 while(!stack.empty()) {
1718 const HMatrix<T> * m = stack.back();
1719 stack.pop_back();
1720 if(!m->isLeaf()) {
1721 for(int i = 0; i < m->nrChild(); i++) {
1722 if(m->getChild(i) != NULL)
1723 stack.push_back(m->getChild(i));
1724 }
1725 } else if(m->isNull()) {
1726 // nothing to do
1727 } else if(m->isRkMatrix()) {
1728 m->rk()->conjugate();
1729 } else {
1730 m->full()->conjugate();
1731 }
1732 }
1733 }
1734
1735 template<typename T>
copyAndTranspose(const HMatrix<T> * o)1736 void HMatrix<T>::copyAndTranspose(const HMatrix<T>* o) {
1737 assert(o);
1738 assert(*this->rows() == *o->cols());
1739 assert(*this->cols() == *o->rows());
1740 assert(this->isLeaf() == o->isLeaf());
1741
1742 if (this->isLeaf()) {
1743 if (o->isRkMatrix()) {
1744 assert(!isFullMatrix());
1745 if (rk()) {
1746 delete rk();
1747 }
1748 RkMatrix<T>* newRk = o->rk()->copy();
1749 newRk->transpose();
1750 rk(newRk);
1751 } else {
1752 if (isFullMatrix()) {
1753 delete full();
1754 }
1755 const FullMatrix<T>* oF = o->full();
1756 if(oF == NULL) {
1757 full(NULL);
1758 } else {
1759 full(oF->copyAndTranspose());
1760 if (oF->diagonal) {
1761 if (!full()->diagonal) {
1762 full()->diagonal = new Vector<T>(oF->rows());
1763 HMAT_ASSERT(full()->diagonal);
1764 }
1765 oF->diagonal->copy(full()->diagonal);
1766 }
1767 }
1768 }
1769 } else {
1770 for (int i=0 ; i<nrChildRow() ; i++)
1771 for (int j=0 ; j<nrChildCol() ; j++)
1772 if (get(i,j) && o->get(j, i))
1773 get(i, j)->copyAndTranspose(o->get(j, i));
1774 }
1775 }
1776
1777 template<typename T>
truncate()1778 void HMatrix<T>::truncate() {
1779 if (this->isLeaf()) {
1780 if (this->isRkMatrix()) {
1781 if (rk()) {
1782 rk()->truncate(localSettings.epsilon_);
1783 rank_ = rk()->rank();
1784 }
1785 }
1786 } else {
1787 for (int i = 0; i < this->nrChild(); i++) {
1788 HMatrix<T>* child = this->getChild(i);
1789 if (child) {
1790 child->truncate();
1791 }
1792 }
1793 }
1794 }
1795
1796 template<typename T>
rows() const1797 const ClusterData* HMatrix<T>::rows() const {
1798 return &(rows_->data);
1799 }
1800
1801 template<typename T>
cols() const1802 const ClusterData* HMatrix<T>::cols() const {
1803 return &(cols_->data);
1804 }
1805
copy() const1806 template<typename T> HMatrix<T>* HMatrix<T>::copy() const {
1807 HMatrix<T>* M=Zero(this);
1808 M->copy(this);
1809 return M;
1810 }
1811
// Copy the data of 'o' into 'this'
// The structure of both H-matrices is supposed to be already similar
// (same rows and columns, asserted below; same children layout on internal nodes).
//
// The symmetry / triangular flags and approximateRank_ are always copied.
// On a leaf, the Full or Rk content of 'o' is copied, allocating the storage
// of 'this' when it does not exist yet. On an internal node the copy is
// forwarded child by child.
template<typename T>
void HMatrix<T>::copy(const HMatrix<T>* o) {
  DECLARE_CONTEXT;

  assert(*rows() == *o->rows());
  assert(*cols() == *o->cols());

  isLower = o->isLower;
  isUpper = o->isUpper;
  isTriUpper = o->isTriUpper;
  isTriLower = o->isTriLower;
  approximateRank_ = o->approximateRank_;
  if (this->isLeaf()) {
    assert(o->isLeaf());
    // Both sides are null and 'this' is already assembled: no data to transfer
    if (isAssembled() && isNull() && o->isNull()) {
      return;
    }
    // When the matrix was not allocated but only the structure
    if (o->isFullMatrix() && isFullMatrix()) {
      // Both full: copy into the existing storage of 'this'
      o->full()->copy(full());
    } else if(o->isFullMatrix()) {
      // 'this' has no full storage yet: take a fresh copy of the one of 'o'
      assert(!isAssembled() || isNull());
      full(o->full()->copy());
    } else if (o->isRkMatrix() && !rk()) {
      // 'this' has no Rk storage yet: create an empty Rk block on the same
      // rows/cols as the one of 'o'; it is filled just below.
      // NOTE(review): o->rk() is dereferenced here without a NULL check.
      rk(new RkMatrix<T>(NULL, o->rk()->rows, NULL, o->rk()->cols));
    }
    // From here on both leaves must be of the same kind
    assert((isRkMatrix() == o->isRkMatrix())
           && (isFullMatrix() == o->isFullMatrix()));
    if (o->isRkMatrix()) {
      rk()->copy(o->rk());
      // keep the cached rank in sync with the copied block
      rank_ = rk()->rank();
    }
  } else {
    assert(o->rank_==NONLEAF_BLOCK);
    rank_ = o->rank_;
    // Children must exist (or be absent) at the same positions on both sides
    for (int i = 0; i < o->nrChild(); i++) {
      if (o->getChild(i)) {
        assert(this->getChild(i));
        this->getChild(i)->copy(o->getChild(i));
      } else {
        assert(!this->getChild(i));
      }
    }
  }
}
1859
1860 template<typename T>
lowRankEpsilon(double epsilon,bool recursive)1861 void HMatrix<T>::lowRankEpsilon(double epsilon, bool recursive) {
1862 localSettings.epsilon_ = epsilon;
1863 if(recursive && !this->isLeaf()) {
1864 for (int i = 0; i < this->nrChild(); i++) {
1865 HMatrix<T>* child = this->getChild(i);
1866 if (child)
1867 child->lowRankEpsilon(epsilon);
1868 }
1869 }
1870 }
1871
clear()1872 template<typename T> void HMatrix<T>::clear() {
1873 if(!this->isLeaf()) {
1874 for (int i = 0; i < this->nrChild(); i++) {
1875 HMatrix<T>* child = this->getChild(i);
1876 if (child)
1877 child->clear();
1878 }
1879 } else if(isRkMatrix()) {
1880 if(rk())
1881 delete rk();
1882 rk(NULL);
1883 } else if(isFullMatrix()) {
1884 delete full();
1885 full(NULL);
1886 }
1887 }
1888
/*! \brief In-place inversion of this matrix.

  A leaf must be a full matrix and is inverted directly. An internal node is
  inverted block-wise (see the description in the body).

  NOTE(review): the HMAT_ASSERT_MSG below rejects matrices flagged isLower, so
  the 'if (isLower)' branch further down looks unreachable whenever that check
  is active; only recursiveInverseNosym() would then be executed. Confirm
  before relying on or removing the symmetric code path.
 */
template<typename T>
void HMatrix<T>::inverse() {
  DECLARE_CONTEXT;

  HMAT_ASSERT_MSG(!isLower, "HMatrix::inverse not available for symmetric matrices");

  if (this->isLeaf()) {
    assert(isFullMatrix());
    full()->inverse();
  } else {

    // Matrix inversion:
    // The idea to inverse M is to consider the extended matrix obtained by putting Identity next to M :
    //
    //  [ M11 | M12 |  I  |  0  ]
    //  [ ----+-----+-----+---- ]
    //  [ M21 | M22 |  0  |  I  ]
    //
    // We then apply operations on the line of this matrix (matrix multiplication of an entire line,
    // linear combination of lines)
    // to transform the 'M' part into identity. Doing so, the identity part will at the end contain M-1.
    // We loop on the column of M.
    // At the end of loop 'k', the 'k' first columns of 'M' are now identity,
    // and the 'k' first columns of Identity have changed (it's no longer identity, it's not yet M-1).
    // The matrix 'this' stores the first 'k' block of the identity part of the extended matrix, and the last n-k blocks of the M part
    // At the end, 'this' contains M-1

    if (isLower) {

      // Temporary blocks for the current pivot 'k'; entries start as NULL
      // and are released at the end of each iteration on 'k'.
      vector<HMatrix<T>*> TM(nrChildCol());
      for (int k=0 ; k<nrChildRow() ; k++){
        // Inverse M_kk
        get(k,k)->inverse();
        // Update line 'k' = left-multiplied by M_kk-1
        for (int j=0 ; j<nrChildCol() ; j++) {

          // Mkj <- Mkk^-1 Mkj we use a temp matrix X because this type of product is not allowed with gemm (beta=0 erases Mkj before using it !)
          if (j<k) { // under the diag we store TMj=Mkj
            TM[j] = get(k,j)->copy();
            get(k,j)->gemm('N', 'N', Constants<T>::pone, get(k,k), TM[j], Constants<T>::zero);
          } else if (j>k) { // above the diag : Mkj = t Mjk, we store TMj=-Mjk.tMkk-1 = -Mjk.Mkk-1 (Mkk is symmetric)
            TM[j] = Zero(get(j,k));
            TM[j]->gemm('N', 'T', Constants<T>::mone, get(j,k), get(k,k), Constants<T>::zero);
          }
        }
        // Update the rest of matrix M
        for (int i=0 ; i<nrChildRow() ; i++)
          // line 'i' -= Mik x line 'k' (which has just been multiplied by Mkk-1)
          for (int j=0 ; j<nrChildCol() ; j++)
            if (i!=k && j!=k && j<=i) {
              // Mij <- Mij - Mik (Mkk^-1 Mkj) (with Mkk-1.Mkj already stored in Mkj and TMj=Mjk.tMkk-1)
              // case k < j < i: Mkj does not exist, we use t{TMj} = -Mkk-1Mkj
              if (k<j)
                get(i,j)->gemm('N', 'T', Constants<T>::pone, get(i,k), TM[j], Constants<T>::pone);
              // case j < k < i: all the matrices exist below the diagonal
              else if (k<i)
                get(i,j)->gemm('N', 'N', Constants<T>::mone, get(i,k), get(k,j), Constants<T>::pone);
              // case j < i < k: Mik does not exist, we use TM[i] = Mki
              else
                get(i,j)->gemm('T', 'N', Constants<T>::mone, TM[i], get(k,j), Constants<T>::pone);
            }
        // Update column 'k' = right-multiplied by -M_kk-1
        for (int i=0 ; i<nrChildRow() ; i++)
          if (i>k) {
            // Mik <- - Mik Mkk^-1
            get(i,k)->copy(TM[i]);
          }
        // Release the temporaries of this iteration (TM[k] is never set, deleting NULL is harmless)
        for (int j=0 ; j<nrChildCol() ; j++) {
          delete TM[j];
          TM[j]=NULL;
        }
      }
    } else {
      this->recursiveInverseNosym();
    }
  }
}
1966
1967 template<typename T>
solveLowerTriangularLeft(HMatrix<T> * b,Factorization algo,Diag diag,Uplo uplo,MainOp) const1968 void HMatrix<T>::solveLowerTriangularLeft(HMatrix<T>* b, Factorization algo, Diag diag, Uplo uplo, MainOp) const {
1969 DECLARE_CONTEXT;
1970 if (isVoid()) return;
1971 // At first, the recursion one (simple case)
1972 if (!this->isLeaf() && !b->isLeaf()) {
1973 this->recursiveSolveLowerTriangularLeft(b, algo, diag, uplo);
1974 } else if(!b->isLeaf()) {
1975 // B isn't a leaf, then 'this' is one
1976 assert(this->isLeaf());
1977 // Evaluate B as a full matrix, solve, and restore in the matrix
1978 // TODO: check if it's not too bad
1979 FullMatrix<T> bFull(b->rows(), b->cols());
1980 b->evalPart(&bFull, b->rows(), b->cols());
1981 this->solveLowerTriangularLeft(&bFull, algo, diag, uplo);
1982 b->clear();
1983 b->axpy(Constants<T>::pone, &bFull);
1984 } else if(b->isNull()) {
1985 // nothing to do
1986 } else {
1987 if (b->isFullMatrix()) {
1988 this->solveLowerTriangularLeft(b->full(), algo, diag, uplo);
1989 } else {
1990 assert(b->isRkMatrix());
1991 HMatrix<T> * bSubset = b->subset(uplo == Uplo::LOWER ? this->cols() : this->rows(), b->cols());
1992 this->solveLowerTriangularLeft(bSubset->rk()->a, algo, diag, uplo);
1993 if(bSubset != b)
1994 delete bSubset;
1995 }
1996 }
1997 }
1998
1999 template<typename T>
solveLowerTriangularLeft(ScalarArray<T> * b,Factorization algo,Diag diag,Uplo uplo) const2000 void HMatrix<T>::solveLowerTriangularLeft(ScalarArray<T>* b, Factorization algo, Diag diag, Uplo uplo) const {
2001 DECLARE_CONTEXT;
2002 assert(*rows() == *cols());
2003 assert(cols()->size() == b->rows);
2004 if (isVoid()) return;
2005 if (this->isLeaf()) {
2006 assert(this->isFullMatrix());
2007 full()->solveLowerTriangularLeft(b, algo, diag, uplo);
2008 } else {
2009 // Forward substitution:
2010 // [ L11 | 0 ] [ X1 ] [ b1 ]
2011 // [ ----+---- ] * [----] = [ -- ]
2012 // [ L21 | L22 ] [ X2 ] [ b2 ]
2013 //
2014 // L11 * X1 = b1 (by recursive forward substitution)
2015 // L21 * X1 + L22 * X2 = b2 (forward substitution of L22*X2=b2-L21*X1)
2016 //
2017
2018 int offset(0);
2019 vector<ScalarArray<T> > sub;
2020 for (int i=0 ; i<nrChildRow() ; i++) {
2021 // Create sub[i] = a ScalarArray (without copy of data) for the rows in front of the i-th matrix block
2022 sub.push_back(ScalarArray<T>(*b, offset, get(i, i)->cols()->size(), 0, b->cols));
2023 offset += get(i, i)->cols()->size();
2024 // Update sub[i] with the contribution of the solutions already computed sub[j] j<i
2025 for (int j=0 ; j<i ; j++) {
2026 const HMatrix<T>* u_ji = (uplo == Uplo::LOWER ? get(i, j) : get(j, i));
2027 if (u_ji)
2028 u_ji->gemv(uplo == Uplo::LOWER ? 'N' : 'T', Constants<T>::mone, &sub[j], Constants<T>::pone, &sub[i]);
2029 }
2030 // Solve the i-th diagonal system
2031 get(i, i)->solveLowerTriangularLeft(&sub[i], algo, diag, uplo);
2032 }
2033 }
2034 }
2035
2036 template<typename T>
solveLowerTriangularLeft(FullMatrix<T> * b,Factorization algo,Diag diag,Uplo uplo) const2037 void HMatrix<T>::solveLowerTriangularLeft(FullMatrix<T>* b, Factorization algo, Diag diag, Uplo uplo) const {
2038 solveLowerTriangularLeft(&b->data, algo, diag, uplo);
2039 }
2040
2041 template<typename T>
solveUpperTriangularRight(HMatrix<T> * b,Factorization algo,Diag diag,Uplo uplo) const2042 void HMatrix<T>::solveUpperTriangularRight(HMatrix<T>* b, Factorization algo, Diag diag, Uplo uplo) const {
2043 DECLARE_CONTEXT;
2044 if (rows()->size() == 0 || cols()->size() == 0) return;
2045 // The recursion one (simple case)
2046 if (!this->isLeaf() && !b->isLeaf()) {
2047 this->recursiveSolveUpperTriangularRight(b, algo, diag, uplo);
2048 } else if(!b->isLeaf()) {
2049 // B isn't a leaf, then 'this' is one
2050 assert(this->isLeaf());
2051 assert(isFullMatrix());
2052 // Evaluate B, solve by column and restore all in the matrix
2053 // TODO: check if it's not too bad
2054 FullMatrix<T> bFull(b->rows(), b->cols());
2055 b->evalPart(&bFull, b->rows(), b->cols());
2056 this->solveUpperTriangularRight(&bFull, algo, diag, uplo);
2057 b->clear();
2058 b->axpy(Constants<T>::pone, &bFull);
2059 } else if(b->isNull()) {
2060 // nothing to do
2061 } else {
2062 if (b->isFullMatrix()) {
2063 this->solveUpperTriangularRight(b->full(), algo, diag, uplo);
2064 } else {
2065 assert(b->isRkMatrix());
2066 // Xa Xb^t U = Ba Bb^t
2067 // - Xa = Ba
2068 // - Xb^t U = Bb^t
2069 // Xb is stored without being transposed, thus we solve
2070 // U^t Xb = Bb instead
2071 HMatrix<T> * tmp = b->subset(b->rows(), uplo == Uplo::LOWER ? this->cols() : this->rows());
2072 this->solveLowerTriangularLeft(tmp->rk()->b, algo, diag, uplo);
2073 if(tmp != b)
2074 delete tmp;
2075 }
2076 }
2077 }
2078
2079 template<typename T>
solveUpperTriangularRight(ScalarArray<T> * b,Factorization algo,Diag diag,Uplo uplo) const2080 void HMatrix<T>::solveUpperTriangularRight(ScalarArray<T>* b, Factorization algo, Diag diag, Uplo uplo) const {
2081 DECLARE_CONTEXT;
2082 assert(*rows() == *cols());
2083 assert(rows()->size() == b->cols);
2084 if (isVoid()) return;
2085 if (this->isLeaf()) {
2086 assert(this->isFullMatrix());
2087 full()->solveUpperTriangularRight(b, algo, diag, uplo);
2088 } else {
2089 // Forward substitution:
2090 // [ U11 | U12 ]
2091 // [ X1 | X2 ] * [ ----+---- ] = [ b1 | b2 ]
2092 // [ 0 | U22 ]
2093 //
2094 // X1 * U11 = b1 (by recursive forward substitution)
2095 // X1 * U12 + X2 * U22 = b2 (forward substitution of X2*U22=b2-X1*U12)
2096 //
2097
2098 int offset(0);
2099 vector<ScalarArray<T> > sub;
2100 for (int i=0 ; i<nrChildCol() ; i++) {
2101 // Create sub[i] = a ScalarArray (without copy of data) for the columns in front of the i-th matrix block
2102 sub.push_back(ScalarArray<T>(*b, 0, b->rows, offset, get(i, i)->rows()->size()));
2103 offset += get(i, i)->rows()->size();
2104 // Update sub[i] with the contribution of the solutions already computed sub[j]
2105 for (int j=0 ; j<i ; j++) {
2106 const HMatrix<T>* u_ji = (uplo == Uplo::LOWER ? get(i, j) : get(j, i));
2107 if (u_ji)
2108 u_ji->gemv(uplo == Uplo::LOWER ? 'T' : 'N', Constants<T>::mone, &sub[j], Constants<T>::pone, &sub[i], Side::RIGHT);
2109 }
2110 // Solve the i-th diagonal system
2111 get(i, i)->solveUpperTriangularRight(&sub[i], algo, diag, uplo);
2112 }
2113 }
2114 }
2115
2116 template<typename T>
solveUpperTriangularRight(FullMatrix<T> * b,Factorization algo,Diag diag,Uplo uplo) const2117 void HMatrix<T>::solveUpperTriangularRight(FullMatrix<T>* b, Factorization algo, Diag diag, Uplo uplo) const {
2118 solveUpperTriangularRight(&b->data, algo, diag, uplo);
2119 }
2120
2121 /* Resolve U.X=B, solution saved in B, with B Hmat
2122 Only called by luDecomposition
2123 */
2124 template<typename T>
solveUpperTriangularLeft(HMatrix<T> * b,Factorization algo,Diag diag,Uplo uplo,MainOp) const2125 void HMatrix<T>::solveUpperTriangularLeft(HMatrix<T>* b, Factorization algo, Diag diag, Uplo uplo, MainOp) const {
2126 DECLARE_CONTEXT;
2127 if (rows()->size() == 0 || cols()->size() == 0) return;
2128 // At first, the recursion one (simple case)
2129 if (!this->isLeaf() && !b->isLeaf()) {
2130 this->recursiveSolveUpperTriangularLeft(b, algo, diag, uplo);
2131 } else if(!b->isLeaf()) {
2132 // B isn't a leaf, then 'this' is one
2133 assert(this->isLeaf());
2134 // Evaluate B, solve by column, and restore in the matrix
2135 // TODO: check if it's not too bad
2136 FullMatrix<T> bFull(b->rows(), b->cols());
2137 b->evalPart(&bFull, b->rows(), b->cols());
2138 this->solveUpperTriangularLeft(&bFull, algo, diag, uplo);
2139 b->clear();
2140 b->axpy(Constants<T>::pone, &bFull);
2141 } else if(b->isNull()) {
2142 // nothing to do
2143 } else {
2144 if (b->isFullMatrix()) {
2145 this->solveUpperTriangularLeft(b->full(), algo, diag, uplo);
2146 } else {
2147 assert(b->isRkMatrix());
2148 HMatrix * bSubset = b->subset(uplo == Uplo::LOWER ? this->rows() : this->cols(), b->cols());
2149 this->solveUpperTriangularLeft(bSubset->rk()->a, algo, diag, uplo);
2150 if(bSubset != b)
2151 delete bSubset;
2152 }
2153 }
2154 }
2155
2156 template<typename T>
solveUpperTriangularLeft(ScalarArray<T> * b,Factorization algo,Diag diag,Uplo uplo) const2157 void HMatrix<T>::solveUpperTriangularLeft(ScalarArray<T>* b, Factorization algo, Diag diag, Uplo uplo) const {
2158 DECLARE_CONTEXT;
2159 assert(*rows() == *cols());
2160 assert(rows()->size() == b->rows || uplo == Uplo::UPPER);
2161 assert(cols()->size() == b->rows || uplo == Uplo::LOWER);
2162 if (rows()->size() == 0 || cols()->size() == 0) return;
2163 if (this->isLeaf()) {
2164 full()->solveUpperTriangularLeft(b, algo, diag, uplo);
2165 } else {
2166 // Backward substitution:
2167 // [ U11 | U12 ] [ X1 ] [ b1 ]
2168 // [ ----+---- ] * [----] = [ -- ]
2169 // [ 0 | U22 ] [ X2 ] [ b2 ]
2170 //
2171 // U22 * X2 = b12(by recursive backward substitution)
2172 // U11 * X1 + U12 * X2 = b1 (backward substitution of U11*X1=b1-U12*X2)
2173 //
2174
2175 int offset(0);
2176 vector<ScalarArray<T> > sub;
2177 for (int i=0 ; i<nrChildRow() ; i++) {
2178 // Create sub[i] = a ScalarArray (without copy of data) for the rows in front of the i-th matrix block
2179 sub.push_back(b->rowsSubset(offset, get(i, i)->cols()->size()));
2180 offset += get(i, i)->cols()->size();
2181 }
2182 for (int i=nrChildRow()-1 ; i>=0 ; i--) {
2183 // Solve the i-th diagonal system
2184 get(i, i)->solveUpperTriangularLeft(&sub[i], algo, diag, uplo);
2185 // Update sub[j] j<i with the contribution of the solutions just computed sub[i]
2186 for (int j=0 ; j<i ; j++) {
2187 const HMatrix<T>* u_ji = (uplo == Uplo::LOWER ? get(i, j) : get(j, i));
2188 if (u_ji)
2189 u_ji->gemv(uplo == Uplo::LOWER ? 'T' : 'N', Constants<T>::mone, &sub[i], Constants<T>::pone, &sub[j]);
2190 }
2191 }
2192 }
2193 }
2194
2195 template<typename T>
solveUpperTriangularLeft(FullMatrix<T> * b,Factorization algo,Diag diag,Uplo uplo) const2196 void HMatrix<T>::solveUpperTriangularLeft(FullMatrix<T>* b, Factorization algo, Diag diag, Uplo uplo) const {
2197 solveUpperTriangularLeft(&b->data, algo, diag, uplo);
2198 }
2199
lltDecomposition(hmat_progress_t * progress)2200 template<typename T> void HMatrix<T>::lltDecomposition(hmat_progress_t * progress) {
2201
2202 assertLower(this);
2203 if (isVoid()) {
2204 // nothing to do
2205 } else if(this->isLeaf()) {
2206 full()->lltDecomposition();
2207 if(progress != NULL) {
2208 progress->current= rows()->offset() + rows()->size();
2209 progress->update(progress);
2210 }
2211 } else {
2212 HMAT_ASSERT(isLower);
2213 this->recursiveLltDecomposition(progress);
2214 }
2215 isTriLower = true;
2216 isLower = false;
2217 }
2218
2219 template<typename T>
luDecomposition(hmat_progress_t * progress)2220 void HMatrix<T>::luDecomposition(hmat_progress_t * progress) {
2221 DECLARE_CONTEXT;
2222
2223 if (rows()->size() == 0 || cols()->size() == 0) return;
2224 if (this->isLeaf()) {
2225 assert(isFullMatrix());
2226 full()->luDecomposition();
2227 full()->checkNan();
2228 if(progress != NULL) {
2229 progress->current= rows()->offset() + rows()->size();
2230 progress->update(progress);
2231 }
2232 } else {
2233 this->recursiveLuDecomposition(progress);
2234 }
2235 }
2236
2237 template<typename T>
mdntProduct(const HMatrix<T> * m,const HMatrix<T> * d,const HMatrix<T> * n)2238 void HMatrix<T>::mdntProduct(const HMatrix<T>* m, const HMatrix<T>* d, const HMatrix<T>* n) {
2239 DECLARE_CONTEXT;
2240
2241 HMatrix<T>* x = m->copy();
2242 x->multiplyWithDiag(d); // x=M.D
2243 this->gemm('N', 'T', Constants<T>::mone, x, n, Constants<T>::pone); // this -= M.D.tN
2244 delete x;
2245 }
2246
/*! \brief this <- this - M * D * M^T

  The computation is dispatched on the kind of 'this' (internal node or full
  leaf) and on the kind of 'm' (internal node, Rk leaf, full leaf or null).

  \param m the matrix M
  \param d the matrix holding the diagonal D (see the warning below)
 */
template<typename T>
void HMatrix<T>::mdmtProduct(const HMatrix<T>* m, const HMatrix<T>* d) {
  DECLARE_CONTEXT;
  if (isVoid() || d->isVoid() || m->isVoid()) return;
  // this <- this - M * D * M^T
  //
  // D is stored separately in full matrix of diagonal leaves (see full_matrix.hpp).
  // this is symmetric and stored as lower triangular.
  // Warning: d must be the result of an ldlt factorization
  assertLower(this);
  assert(*d->rows() == *d->cols());       // D is square
  assert(*this->rows() == *this->cols()); // this is square
  assert(*m->cols() == *d->rows());       // Check that the products M*D and D*M^T are possible
  assert(*this->rows() == *m->rows());

  if(!this->isLeaf()) {
    // 'this' is an internal node
    if (!m->isLeaf()) {
      this->recursiveMdmtProduct(m, d);
    } else if (m->isRkMatrix() && !m->isNull()) {
      // m is a non-null Rk leaf: build (M.D).M^T as an Rk matrix and subtract it
      HMatrix<T>* m_copy = m->copy();

      assert(*m->cols() == *d->rows());
      assert(*m_copy->rk()->cols == *d->rows());
      m_copy->multiplyWithDiag(d); // right multiplication by D
      RkMatrix<T>* rkMat = RkMatrix<T>::multiplyRkRk('N', 'T', m_copy->rk(), m->rk(), m->lowRankEpsilon());
      delete m_copy;

      this->axpy(Constants<T>::mone, rkMat);
      delete rkMat;
    } else if(m->isFullMatrix()){
      // m is a full leaf: build (M.D).M^T as a full matrix and subtract it
      HMatrix<T>* copy_m = m->copy();
      HMAT_ASSERT(copy_m);
      copy_m->multiplyWithDiag(d); // right multiplication by D

      FullMatrix<T>* fullMat = HMatrix<T>::multiplyFullMatrix('N', 'T', copy_m, m);
      HMAT_ASSERT(fullMat);
      delete copy_m;

      this->axpy(Constants<T>::mone, fullMat);
      delete fullMat;
    } else {
      // m is a null matrix (either Rk or Full) so nothing to do.
    }
  } else {
    // 'this' is a leaf, which must be full
    assert(isFullMatrix());
    if (m->isRkMatrix()) {
      // this : full
      // m    : rk
      // Strategy: compute mdm^T as FullMatrix and then do this<-this - mdm^T

      // 1) copy m = AB^T : m_copy
      // 2) m_copy <- m_copy * D (multiplyWithDiag)
      // 3) rkMat <- multiplyRkRk ( m_copy , m^T)
      // 4) fullMat <- evaluation as a FullMatrix of the product rkMat = (A*(D*B)^T) * (A*B^T)^T
      // 5) this <- this - fullMat
      if (!m->isNull()) {
        HMatrix<T>* m_copy = m->copy();
        m_copy->multiplyWithDiag(d);

        RkMatrix<T>* rkMat = RkMatrix<T>::multiplyRkRk('N', 'T', m_copy->rk(), m->rk(), m->lowRankEpsilon());
        FullMatrix<T>* fullMat = rkMat->eval();
        delete m_copy;
        delete rkMat;
        full()->axpy(Constants<T>::mone, fullMat);
        delete fullMat;
      }
    } else if (m->isFullMatrix()) {
      // S <- S - M*D*M^T
      assert(!full()->isTriUpper());
      assert(!full()->isTriLower());
      assert(!m->full()->isTriUpper());
      assert(!m->full()->isTriLower());
      // mTmp = M.D, computed on a temporary copy so that m is not modified
      FullMatrix<T> mTmp(m->rows(), m->cols());
      mTmp.copyMatrixAtOffset(m->full(), 0, 0);
      if (d->isFullMatrix()) {
        // d is a full leaf: its diagonal member is used directly
        mTmp.multiplyWithDiagOrDiagInv(d->full()->diagonal, false, Side::RIGHT);
      } else {
        // otherwise the diagonal is first gathered into a temporary vector
        Vector<T> diag(d->cols()->size());
        d->extractDiagonal(diag.ptr());
        mTmp.multiplyWithDiagOrDiagInv(&diag, false, Side::RIGHT);
      }
      full()->gemm('N', 'T', Constants<T>::mone, &mTmp, m->full(), Constants<T>::pone);
    } else if (!m->isLeaf()){
      // m is an internal node: evaluate it as a full matrix (mTmp), keep an
      // unscaled copy (mTmpCopy) for the M^T factor, then proceed as above
      FullMatrix<T> mTmp(m->rows(), m->cols());
      m->evalPart(&mTmp, m->rows(), m->cols());
      FullMatrix<T> mTmpCopy(m->rows(), m->cols());
      mTmpCopy.copyMatrixAtOffset(&mTmp, 0, 0);
      if (d->isFullMatrix()) {
        mTmp.multiplyWithDiagOrDiagInv(d->full()->diagonal, false, Side::RIGHT);
      } else {
        Vector<T> diag(d->cols()->size());
        d->extractDiagonal(diag.ptr());
        mTmp.multiplyWithDiagOrDiagInv(&diag, false, Side::RIGHT);
      }
      full()->gemm('N', 'T', Constants<T>::mone, &mTmp, &mTmpCopy, Constants<T>::pone);
    }
  }
}
2345
assertLdlt(const HMatrix<T> * me)2346 template<typename T> void assertLdlt(const HMatrix<T> * me) {
2347 // Void block (row & col)
2348 if (me->rows()->size() == 0 && me->cols()->size() == 0) return;
2349 #ifdef DEBUG_LDLT
2350 assert(me->isTriLower);
2351 if (me->isLeaf()) {
2352 assert(me->isFullMatrix());
2353 assert(me->full()->diagonal);
2354 } else {
2355 for (int i=0 ; i<me->nrChildRow() ; i++)
2356 assertLdlt(me->get(i,i));
2357 }
2358 #else
2359 ignore_unused_arg(me);
2360 #endif
2361 }
2362
assertLower(const HMatrix<T> * me)2363 template<typename T> void assertLower(const HMatrix<T> * me) {
2364 #ifdef DEBUG_LDLT
2365 if (me->isLeaf()) {
2366 return;
2367 } else {
2368 assert(me->isLower);
2369 for (int i=0 ; i<me->nrChildRow() ; i++)
2370 for (int j=0 ; j<me->nrChildCol() ; j++) {
2371 if (i<j) /* NULL above diag */
2372 assert(!me->get(i,j));
2373 if (i==j) /* Lower on diag */
2374 assertLower(me->get(i,i));
2375 }
2376 }
2377 #else
2378 ignore_unused_arg(me);
2379 #endif
2380 }
2381
assertUpper(const HMatrix<T> * me)2382 template<typename T> void assertUpper(const HMatrix<T> * me) {
2383 #ifdef DEBUG_LDLT
2384 if (me->isLeaf()) {
2385 return;
2386 } else {
2387 assert(me->isUpper);
2388 for (int i=0 ; i<me->nrChildRow() ; i++)
2389 for (int j=0 ; j<me->nrChildCol() ; j++) {
2390 if (i==j) /* Upper on diag */
2391 assertUpper(me->get(i,i));
2392 }
2393 }
2394 #else
2395 ignore_unused_arg(me);
2396 #endif
2397 }
2398
2399 template<typename T>
ldltDecomposition(hmat_progress_t * progress)2400 void HMatrix<T>::ldltDecomposition(hmat_progress_t * progress) {
2401 DECLARE_CONTEXT;
2402 assertLower(this);
2403
2404 if (isVoid()) {
2405 // nothing to do
2406 } else if (this->isLeaf()) {
2407 //The basic case of the recursion is necessarily a full matrix leaf
2408 //since the recursion is done with *rows() == *cols().
2409
2410 assert(isFullMatrix());
2411 full()->ldltDecomposition();
2412 if(progress != NULL) {
2413 progress->current= rows()->offset() + rows()->size();
2414 progress->update(progress);
2415 }
2416 assert(full()->diagonal);
2417 } else {
2418 this->recursiveLdltDecomposition(progress);
2419 }
2420 isTriLower = true;
2421 isLower = false;
2422 }
2423
2424 template<typename T>
solve(ScalarArray<T> * b) const2425 void HMatrix<T>::solve(ScalarArray<T>* b) const {
2426 DECLARE_CONTEXT;
2427 // Solve (LU) X = b
2428 // First compute L Y = b
2429 this->solveLowerTriangularLeft(b, Factorization::LU, Diag::UNIT, Uplo::LOWER);
2430 // Then compute U X = Y
2431 this->solveUpperTriangularLeft(b, Factorization::LU, Diag::NONUNIT, Uplo::UPPER);
2432 }
2433
2434 template<typename T>
solve(FullMatrix<T> * b) const2435 void HMatrix<T>::solve(FullMatrix<T>* b) const {
2436 solve(&b->data);
2437 }
2438
2439 template<typename T>
trsm(char side,char uplo,char trans,char diag,T alpha,HMatrix<T> * B) const2440 void HMatrix<T>::trsm( char side, char uplo, char trans, char diag,
2441 T alpha, HMatrix<T>* B ) const {
2442
2443 bool upper = (uplo == 'u') || (uplo == 'U');
2444 bool left = (side == 'l') || (side == 'L');
2445 Diag unit = (diag == 'u' || diag == 'U') ? Diag::UNIT : Diag::NONUNIT;
2446 bool notrans = (trans == 'n') || (trans == 'N');
2447
2448 /* Upper case */
2449 if ( upper ) {
2450 if ( left ) {
2451 if ( notrans ) {
2452 /* LUN */
2453 solveUpperTriangularLeft( B, Factorization::LU, unit, Uplo::UPPER );
2454 }
2455 else {
2456 /* LUT */
2457 HMAT_ASSERT_MSG( 0, "ERROR: TRSM LUT case is for now missing !!!" );
2458 }
2459 }
2460 else {
2461 if ( notrans ) {
2462 /* RUN */
2463 solveUpperTriangularRight( B, Factorization::LU, unit, Uplo::UPPER );
2464 }
2465 else {
2466 /* RUT */
2467 HMAT_ASSERT_MSG( 0, "ERROR: TRSM RUT case is for now missing !!!" );
2468 }
2469 }
2470 }
2471 else {
2472 if ( left ) {
2473 if ( notrans ) {
2474 /* LLN */
2475 solveLowerTriangularLeft( B, Factorization::LU, unit, Uplo::LOWER );
2476 }
2477 else {
2478 /* LLT */
2479 solveUpperTriangularLeft( B, Factorization::LU, unit, Uplo::LOWER );
2480 }
2481 }
2482 else {
2483 if ( notrans ) {
2484 /* RLN */
2485 HMAT_ASSERT_MSG( 0, "ERROR: TRSM RLN case is for now missing !!!" );
2486 }
2487 else {
2488 /* RLT */
2489 solveUpperTriangularRight( B, Factorization::LU, unit, Uplo::LOWER );
2490 }
2491 }
2492 }
2493 }
2494
2495 template<typename T>
trsm(char side,char uplo,char trans,char diag,T alpha,ScalarArray<T> * B) const2496 void HMatrix<T>::trsm( char side, char uplo, char trans, char diag,
2497 T alpha, ScalarArray<T>* B ) const {
2498
2499 bool upper = (uplo == 'u') || (uplo == 'U');
2500 bool left = (side == 'l') || (side == 'L');
2501 Diag unit = (diag == 'u' || diag == 'U') ? Diag::UNIT : Diag::NONUNIT;
2502 bool notrans = (trans == 'n') || (trans == 'N');
2503
2504 /* Upper case */
2505 if ( upper ) {
2506 if ( left ) {
2507 if ( notrans ) {
2508 /* LUN */
2509 solveUpperTriangularLeft( B, Factorization::LU, unit, Uplo::UPPER );
2510 }
2511 else {
2512 /* LUT */
2513 HMAT_ASSERT_MSG( 0, "ERROR: TRSM LUT case is for now missing !!!" );
2514 }
2515 }
2516 else {
2517 if ( notrans ) {
2518 /* RUN */
2519 solveUpperTriangularRight( B, Factorization::LU, unit, Uplo::UPPER );
2520 }
2521 else {
2522 /* RUT */
2523 HMAT_ASSERT_MSG( 0, "ERROR: TRSM RUT case is for now missing !!!" );
2524 }
2525 }
2526 }
2527 else {
2528 if ( left ) {
2529 if ( notrans ) {
2530 /* LLN */
2531 solveLowerTriangularLeft( B, Factorization::LU, unit, Uplo::LOWER );
2532 }
2533 else {
2534 /* LLT */
2535 solveUpperTriangularLeft( B, Factorization::LU, unit, Uplo::LOWER );
2536 }
2537 }
2538 else {
2539 if ( notrans ) {
2540 /* RLN */
2541 HMAT_ASSERT_MSG( 0, "ERROR: TRSM RLN case is for now missing !!!" );
2542 }
2543 else {
2544 /* RLT */
2545 solveUpperTriangularRight( B, Factorization::LU, unit, Uplo::LOWER );
2546 }
2547 }
2548 }
2549 }
2550
2551 template<typename T>
extractDiagonal(T * diag) const2552 void HMatrix<T>::extractDiagonal(T* diag) const {
2553 DECLARE_CONTEXT;
2554 if (rows()->size() == 0 || cols()->size() == 0) return;
2555 if(this->isLeaf()) {
2556 assert(isFullMatrix());
2557 if(full()->diagonal) {
2558 // LDLt
2559 memcpy(diag, full()->diagonal->const_ptr(), full()->rows() * sizeof(T));
2560 } else {
2561 // LLt
2562 for (int i = 0; i < full()->rows(); ++i)
2563 diag[i] = full()->get(i,i);
2564 }
2565 } else {
2566 for (int i=0 ; i<nrChildRow() ; i++) {
2567 get(i,i)->extractDiagonal(diag);
2568 diag += get(i,i)->rows()->size();
2569 }
2570 }
2571 }
2572
2573 /* Solve M.X=B with M hmat LU factorized*/
solve(HMatrix<T> * b,Factorization algo) const2574 template<typename T> void HMatrix<T>::solve(
2575 HMatrix<T>* b,
2576 Factorization algo) const {
2577 DECLARE_CONTEXT;
2578 switch(algo) {
2579 case Factorization::LU:
2580 /* Solve LX=B, result in B */
2581 this->solveLowerTriangularLeft(b, algo, Diag::UNIT, Uplo::LOWER);
2582 /* Solve UX=B, result in B */
2583 this->solveUpperTriangularLeft(b, algo, Diag::NONUNIT, Uplo::UPPER);
2584 break;
2585 case Factorization::LDLT:
2586 /* Solve LX=B, result in B */
2587 this->solveLowerTriangularLeft(b, algo, Diag::UNIT, Uplo::LOWER);
2588 /* Solve DX=B, result in B */
2589 b->multiplyWithDiag(this, Side::LEFT, true);
2590 /* Solve L^tX=B, result in B */
2591 this->solveUpperTriangularLeft(b, algo, Diag::UNIT, Uplo::LOWER);
2592 break;
2593 case Factorization::LLT:
2594 /* Solve LX=B, result in B */
2595 this->solveLowerTriangularLeft(b, algo, Diag::NONUNIT, Uplo::LOWER);
2596 /* Solve L^tX=B, result in B */
2597 this->solveUpperTriangularLeft(b, algo, Diag::NONUNIT, Uplo::UPPER);
2598 break;
2599 default:
2600 HMAT_ASSERT(false);
2601 }
2602 }
2603
solveDiagonal(ScalarArray<T> * b) const2604 template<typename T> void HMatrix<T>::solveDiagonal(ScalarArray<T>* b) const {
2605 // Solve D*X = B and store result into B
2606 // Diagonal extraction
2607 if (rows()->size() == 0 || cols()->size() == 0) return;
2608 if(isFullMatrix() && full()->diagonal) {
2609 // LDLt
2610 b->multiplyWithDiagOrDiagInv(full()->diagonal, true, Side::LEFT); // multiply to the left by the inverse
2611 } else {
2612 // LLt
2613 Vector<T>* diag = new Vector<T>(cols()->size());
2614 extractDiagonal(diag->ptr());
2615 b->multiplyWithDiagOrDiagInv(diag, true, Side::LEFT); // multiply to the left by the inverse
2616 delete diag;
2617 }
2618 }
2619
solveDiagonal(FullMatrix<T> * b) const2620 template<typename T> void HMatrix<T>::solveDiagonal(FullMatrix<T>* b) const {
2621 solveDiagonal(&b->data);
2622 }
2623
2624 template<typename T>
solveLdlt(ScalarArray<T> * b) const2625 void HMatrix<T>::solveLdlt(ScalarArray<T>* b) const {
2626 DECLARE_CONTEXT;
2627 assertLdlt(this);
2628 // L*D*L^T * X = B
2629 // B <- solution of L * Y = B : Y = D*L^T * X
2630 this->solveLowerTriangularLeft(b, Factorization::LDLT, Diag::UNIT, Uplo::LOWER);
2631
2632 // B <- D^{-1} Y : solution of D*Y = B : Y = L^T * X
2633 this->solveDiagonal(b);
2634
2635 // B <- solution of L^T X = B : the solution X we are looking for is stored in B
2636 this->solveUpperTriangularLeft(b, Factorization::LDLT, Diag::UNIT, Uplo::LOWER);
2637 }
2638
2639 template<typename T>
solveLdlt(FullMatrix<T> * b) const2640 void HMatrix<T>::solveLdlt(FullMatrix<T>* b) const {
2641 solveLdlt(&b->data);
2642 }
2643
2644 template<typename T>
solveLlt(ScalarArray<T> * b) const2645 void HMatrix<T>::solveLlt(ScalarArray<T>* b) const {
2646 DECLARE_CONTEXT;
2647 // L*L^T * X = B
2648 // B <- solution of L * Y = B : Y = L^T * X
2649 this->solveLowerTriangularLeft(b, Factorization::LLT, Diag::NONUNIT, Uplo::LOWER);
2650
2651 // B <- solution of L^T X = B : the solution X we are looking for is stored in B
2652 this->solveUpperTriangularLeft(b, Factorization::LLT, Diag::NONUNIT, Uplo::LOWER);
2653 }
2654
2655 template<typename T>
solveLlt(FullMatrix<T> * b) const2656 void HMatrix<T>::solveLlt(FullMatrix<T>* b) const {
2657 solveLlt(&b->data);
2658 }
2659
2660 template<typename T>
checkStructure() const2661 void HMatrix<T>::checkStructure() const {
2662 #if 0
2663 if (this->isLeaf()) {
2664 return;
2665 }
2666 for (int i = 0; i < this->nrChild(); i++) {
2667 HMatrix<T>* child = this->getChild(i);
2668 if (child) {
2669 assert(child->rows()->isSubset(*(this->rows())) && child->cols()->isSubset(*(this->cols())));
2670 child->checkStructure();
2671 }
2672 }
2673 #endif
2674 }
2675
2676 template<typename T>
checkNan() const2677 void HMatrix<T>::checkNan() const {
2678 #if 0
2679 if (this->isLeaf()) {
2680 if (isFullMatrix()) {
2681 full()->checkNan();
2682 }
2683 if (isRkMatrix()) {
2684 rk()->checkNan();
2685 }
2686 } else {
2687 for (int i = 0; i < this->nrChild(); i++) {
2688 if (this->getChild(i)) {
2689 this->getChild(i)->checkNan();
2690 }
2691 }
2692 }
2693 #endif
2694 }
2695
setTriLower(bool value)2696 template<typename T> void HMatrix<T>::setTriLower(bool value)
2697 {
2698 isTriLower = value;
2699 if(!this->isLeaf())
2700 {
2701 for (int i = 0; i < nrChildRow(); i++)
2702 get(i, i)->setTriLower(value);
2703 }
2704 }
2705
setLower(bool value)2706 template<typename T> void HMatrix<T>::setLower(bool value)
2707 {
2708 isLower = value;
2709 if(!this->isLeaf())
2710 {
2711 for (int i = 0; i < nrChildRow(); i++)
2712 get(i, i)->setLower(value);
2713 }
2714 }
2715
rk(const ScalarArray<T> * a,const ScalarArray<T> * b)2716 template<typename T> void HMatrix<T>::rk(const ScalarArray<T> * a, const ScalarArray<T> * b) {
2717 if(!isAssembled())
2718 rk(NULL);
2719 assert(isRkMatrix());
2720 if(a == NULL && isNull())
2721 return;
2722 delete rk_;
2723 rk(new RkMatrix<T>(a == NULL ? NULL : a->copy(), rows(),
2724 b == NULL ? NULL : b->copy(), cols()));
2725 }
2726
listAllLeaves(std::deque<const HMatrix<T> * > & out) const2727 template<typename T> void HMatrix<T>::listAllLeaves(std::deque<const HMatrix<T> *> & out) const {
2728 std::vector<const HMatrix<T> *> stack;
2729 stack.push_back(this);
2730 while(!stack.empty()) {
2731 const HMatrix<T> * m = stack.back();
2732 stack.pop_back();
2733 if(m->isLeaf()) {
2734 out.push_back(m);
2735 } else {
2736 for(int i = 0; i < m->nrChild(); i++) {
2737 if(m->getChild(i) != nullptr)
2738 stack.push_back(m->getChild(i));
2739 }
2740 }
2741 }
2742 }
2743
2744 // No way to avoid copy/past of the const version
listAllLeaves(std::deque<HMatrix<T> * > & out)2745 template<typename T> void HMatrix<T>::listAllLeaves(std::deque<HMatrix<T> *> & out) {
2746 std::vector<HMatrix<T> *> stack;
2747 stack.push_back(this);
2748 while(!stack.empty()) {
2749 HMatrix<T> * m = stack.back();
2750 stack.pop_back();
2751 if(m->isLeaf()) {
2752 out.push_back(m);
2753 } else {
2754 for(int i = 0; i < m->nrChild(); i++) {
2755 if(m->getChild(i) != nullptr)
2756 stack.push_back(m->getChild(i));
2757 }
2758 }
2759 }
2760 }
2761
toString() const2762 template<typename T> std::string HMatrix<T>::toString() const {
2763 std::deque<const HMatrix<T> *> leaves;
2764 this->listAllLeaves(leaves);
2765 int nbAssembled = 0;
2766 int nbNullFull = 0;
2767 int nbNullRk = 0;
2768 double diagNorm = 0;
2769 for(unsigned int i = 0; i < leaves.size(); i++) {
2770 const HMatrix<T> * l = leaves[i];
2771 if(l->isAssembled()) {
2772 nbAssembled++;
2773 if(l->isNull()) {
2774 if(l->isRkMatrix())
2775 nbNullRk++;
2776 else
2777 nbNullFull++;
2778 }
2779 else if(l->isFullMatrix() && l->full()->diagonal) {
2780 diagNorm += l->full()->diagonal->normSqr();
2781 }
2782 }
2783 }
2784 diagNorm = sqrt(diagNorm);
2785 std::stringstream sstm;
2786 sstm << "HMatrix(rows=[" << rows()->offset() << ", " << rows()->size() <<
2787 "], cols=[" << cols()->offset() << ", " << cols()->size() <<
2788 "], pointer=" << (void*)this << ", leaves=" << leaves.size() <<
2789 ", assembled=" << isAssembled() << ", assembledLeaves=" << nbAssembled <<
2790 ", nullFull=" << nbNullFull << ", nullRk=" << nbNullRk <<
2791 ", rank=" << rank_ << ", diagNorm=" << diagNorm << ")";
2792 return sstm.str();
2793 }
2794
2795 template<typename T>
unmarshall(const MatrixSettings * settings,int rank,int approxRank,char bitfield,double epsilon)2796 HMatrix<T> * HMatrix<T>::unmarshall(const MatrixSettings * settings, int rank, int approxRank, char bitfield, double epsilon) {
2797 HMatrix<T> * m = new HMatrix<T>(settings);
2798 m->rank_ = rank;
2799 m->isUpper = (bitfield & 1 << 0 ? true : false);
2800 m->isLower = (bitfield & 1 << 1 ? true : false);
2801 m->isTriUpper = (bitfield & 1 << 2 ? true : false);
2802 m->isTriLower = (bitfield & 1 << 3 ? true : false);
2803 m->keepSameRows = (bitfield & 1 << 4 ? true : false);
2804 m->keepSameCols = (bitfield & 1 << 5 ? true : false);
2805 m->approximateRank_ = approxRank;
2806 m->lowRankEpsilon(epsilon, false);
2807 return m;
2808 }
2809
2810 /** Create a temporary block from a list of children */
2811 template<typename T>
HMatrix(const ClusterTree * rows,const ClusterTree * cols,std::vector<HMatrix * > & _children)2812 HMatrix<T>::HMatrix(const ClusterTree * rows, const ClusterTree * cols,
2813 std::vector<HMatrix*> & _children):
2814 Tree<HMatrix<T> >(NULL, 0), rows_(rows), cols_(cols),
2815 rk_(NULL), rank_(UNINITIALIZED_BLOCK),
2816 approximateRank_(UNINITIALIZED_BLOCK), isUpper(false), isLower(false),
2817 keepSameRows(false), keepSameCols(false), temporary_(true), ownRowsClusterTree_(false),
2818 ownColsClusterTree_(false), localSettings(_children[0]->localSettings.global, -1.0) {
2819 this->children = _children;
2820 }
2821
/** Overwrite the stored rank (rank_) of this block.
 *
 * Two conditions are asserted first: the current rank_ must be non-negative,
 * and if an RkMatrix is attached and still has its \c a panel, its rank must
 * already equal \a rank.
 *
 * \param rank new value for rank_
 */
template<typename T> void HMatrix<T>::rank(int rank) {
    // A negative rank_ is rejected: per the message below, only Rk blocks are accepted
    HMAT_ASSERT_MSG(rank_ >= 0, "HMatrix::rank can only be used on Rk blocks");
    // Accept: no RkMatrix, an RkMatrix without its `a` panel, or a matching rank
    HMAT_ASSERT_MSG(!rk() || rk()->a == NULL || rk()->rank() == rank,
        "HMatrix::rank can only be used on evicted blocks");
    rank_ = rank;
}
2828
2829
temporary(bool b)2830 template<typename T> void HMatrix<T>::temporary(bool b) {
2831 temporary_ = b;
2832 for (int i=0; i<this->nrChild(); i++) {
2833 if (this->getChild(i))
2834 this->getChild(i)->temporary(b);
2835 }
2836 }
2837
// Explicit template instantiations: HMatrix and the free function templates
// below are instantiated here for the four scalar types S_t, D_t, C_t and Z_t.
template class HMatrix<S_t>;
template class HMatrix<D_t>;
template class HMatrix<C_t>;
template class HMatrix<Z_t>;

// reorderVector, one instantiation per scalar type
template void reorderVector(ScalarArray<S_t>* v, int* indices, int axis);
template void reorderVector(ScalarArray<D_t>* v, int* indices, int axis);
template void reorderVector(ScalarArray<C_t>* v, int* indices, int axis);
template void reorderVector(ScalarArray<Z_t>* v, int* indices, int axis);

// restoreVectorOrder, one instantiation per scalar type
template void restoreVectorOrder(ScalarArray<S_t>* v, int* indices, int axis);
template void restoreVectorOrder(ScalarArray<D_t>* v, int* indices, int axis);
template void restoreVectorOrder(ScalarArray<C_t>* v, int* indices, int axis);
template void restoreVectorOrder(ScalarArray<Z_t>* v, int* indices, int axis);

// compatibilityGridForGEMM, one instantiation per scalar type
template unsigned char * compatibilityGridForGEMM(const HMatrix<S_t>* a, Axis axisA, char transA, const HMatrix<S_t>* b, Axis axisB, char transB);
template unsigned char * compatibilityGridForGEMM(const HMatrix<D_t>* a, Axis axisA, char transA, const HMatrix<D_t>* b, Axis axisB, char transB);
template unsigned char * compatibilityGridForGEMM(const HMatrix<C_t>* a, Axis axisA, char transA, const HMatrix<C_t>* b, Axis axisB, char transB);
template unsigned char * compatibilityGridForGEMM(const HMatrix<Z_t>* a, Axis axisA, char transA, const HMatrix<Z_t>* b, Axis axisB, char transB);
2858
2859 } // end namespace hmat
2860
2861 #include "recursion.cpp"
2862
namespace hmat {

// Explicit instantiation of RecursionMatrix for HMatrix with each of the four
// scalar types. This comes after the #include of "recursion.cpp" above,
// presumably so that the template member definitions are visible here.
template class RecursionMatrix<S_t, HMatrix<S_t> >;
template class RecursionMatrix<C_t, HMatrix<C_t> >;
template class RecursionMatrix<D_t, HMatrix<D_t> >;
template class RecursionMatrix<Z_t, HMatrix<Z_t> >;

}  // end namespace hmat
2872