1 /*
2 HMat-OSS (HMatrix library, open source software)
3
4 Copyright (C) 2014-2015 Airbus Group SAS
5
6 This program is free software; you can redistribute it and/or
7 modify it under the terms of the GNU General Public License
8 as published by the Free Software Foundation; either version 2
9 of the License, or (at your option) any later version.
10
11 This program is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with this program; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
19
20 http://github.com/jeromerobert/hmat-oss
21 */
22
#include "rk_matrix.hpp"
#include "h_matrix.hpp"
#include "cluster_tree.hpp"
#include <algorithm> // std::min, std::swap
#include <cstdlib> // getenv
#include <cstring> // memcpy, strcmp
#include <cfloat> // DBL_MAX
#include <vector>
#include "data_types.hpp"
#include "lapack_operations.hpp"
#include "blas_overloads.hpp"
#include "lapack_overloads.hpp"
#include "common/context.hpp"
#include "common/my_assert.h"
#include "common/timeline.hpp"
#include "lapack_exception.hpp"
36
37 namespace hmat {
38
39 /** RkApproximationControl */
40 template<typename T> RkApproximationControl RkMatrix<T>::approx;
41
42 /** RkMatrix */
RkMatrix(ScalarArray<T> * _a,const IndexSet * _rows,ScalarArray<T> * _b,const IndexSet * _cols)43 template<typename T> RkMatrix<T>::RkMatrix(ScalarArray<T>* _a, const IndexSet* _rows,
44 ScalarArray<T>* _b, const IndexSet* _cols)
45 : rows(_rows),
46 cols(_cols),
47 a(_a),
48 b(_b)
49 {
50
51 // We make a special case for empty matrices.
52 if ((!a) && (!b)) {
53 return;
54 }
55 assert(a->rows == rows->size());
56 assert(b->rows == cols->size());
57 }
58
~RkMatrix()59 template<typename T> RkMatrix<T>::~RkMatrix() {
60 clear();
61 }
62
63
evalArray(ScalarArray<T> * result) const64 template<typename T> ScalarArray<T>* RkMatrix<T>::evalArray(ScalarArray<T>* result) const {
65 if(result==NULL)
66 result = new ScalarArray<T>(rows->size(), cols->size());
67 if (rank())
68 result->gemm('N', 'T', Constants<T>::pone, a, b, Constants<T>::zero);
69 else
70 result->clear();
71 return result;
72 }
73
eval() const74 template<typename T> FullMatrix<T>* RkMatrix<T>::eval() const {
75 FullMatrix<T>* result = new FullMatrix<T>(rows, cols, false);
76 evalArray(&result->data);
77 return result;
78 }
79
80 // Compute squared Frobenius norm
normSqr() const81 template<typename T> double RkMatrix<T>::normSqr() const {
82 return a->norm_abt_Sqr(*b);
83 }
84
scale(T alpha)85 template<typename T> void RkMatrix<T>::scale(T alpha) {
86 // We need just to scale the first matrix, A.
87 if (a) {
88 a->scale(alpha);
89 }
90 }
91
transpose()92 template<typename T> void RkMatrix<T>::transpose() {
93 std::swap(a, b);
94 std::swap(rows, cols);
95 }
96
clear()97 template<typename T> void RkMatrix<T>::clear() {
98 delete a;
99 delete b;
100 a = NULL;
101 b = NULL;
102 }
103
104 template<typename T>
gemv(char trans,T alpha,const ScalarArray<T> * x,T beta,ScalarArray<T> * y,Side side) const105 void RkMatrix<T>::gemv(char trans, T alpha, const ScalarArray<T>* x, T beta, ScalarArray<T>* y, Side side) const {
106 if (rank() == 0) {
107 if (beta != Constants<T>::pone) {
108 y->scale(beta);
109 }
110 return;
111 }
112 if (side == Side::LEFT) {
113 if (trans == 'N') {
114 // Compute Y <- Y + alpha * A * B^T * X
115 ScalarArray<T> z(b->cols, x->cols);
116 z.gemm('T', 'N', Constants<T>::pone, b, x, Constants<T>::zero);
117 y->gemm('N', 'N', alpha, a, &z, beta);
118 } else if (trans == 'T') {
119 // Compute Y <- Y + alpha * B * A^T * X
120 ScalarArray<T> z(a->cols, x->cols);
121 z.gemm('T', 'N', Constants<T>::pone, a, x, Constants<T>::zero);
122 y->gemm('N', 'N', alpha, b, &z, beta);
123 } else {
124 assert(trans == 'C');
125 // Compute Y <- Y + alpha * (A*B^T)^H * X = Y + alpha * conj(B) * A^H * X
126 ScalarArray<T> z(a->cols, x->cols);
127 z.gemm('C', 'N', Constants<T>::pone, a, x, Constants<T>::zero);
128 ScalarArray<T> * newB = b->copy();
129 newB->conjugate();
130 y->gemm('N', 'N', alpha, newB, &z, beta);
131 delete newB;
132 }
133 } else {
134 if (trans == 'N') {
135 // Compute Y <- Y + alpha * X * A * B^T
136 ScalarArray<T> z(x->rows, a->cols);
137 z.gemm('N', 'N', Constants<T>::pone, x, a, Constants<T>::zero);
138 y->gemm('N', 'T', alpha, &z, b, beta);
139 } else if (trans == 'T') {
140 // Compute Y <- Y + alpha * X * B * A^T
141 ScalarArray<T> z(x->rows, b->cols);
142 z.gemm('N', 'N', Constants<T>::pone, x, b, Constants<T>::zero);
143 y->gemm('N', 'T', alpha, &z, a, beta);
144 } else {
145 assert(trans == 'C');
146 // Compute Y <- Y + alpha * X * (A*B^T)^H = Y + alpha * X * conj(B) * A^H
147 ScalarArray<T> * newB = b->copy();
148 newB->conjugate();
149 ScalarArray<T> z(x->rows, b->cols);
150 z.gemm('N', 'N', Constants<T>::pone, x, newB, Constants<T>::zero);
151 delete newB;
152 y->gemm('N', 'C', alpha, &z, a, beta);
153 }
154 }
155 }
156
subset(const IndexSet * subRows,const IndexSet * subCols) const157 template<typename T> const RkMatrix<T>* RkMatrix<T>::subset(const IndexSet* subRows,
158 const IndexSet* subCols) const {
159 assert(subRows->isSubset(*rows));
160 assert(subCols->isSubset(*cols));
161 ScalarArray<T>* subA = NULL;
162 ScalarArray<T>* subB = NULL;
163 if(rank() > 0) {
164 // The offset in the matrix, and not in all the indices
165 int rowsOffset = subRows->offset() - rows->offset();
166 int colsOffset = subCols->offset() - cols->offset();
167 subA = new ScalarArray<T>(*a, rowsOffset, subRows->size(), 0, rank());
168 subB = new ScalarArray<T>(*b, colsOffset, subCols->size(), 0, rank());
169 }
170 return new RkMatrix<T>(subA, subRows, subB, subCols);
171 }
172
truncatedSubset(const IndexSet * subRows,const IndexSet * subCols,double epsilon) const173 template<typename T> RkMatrix<T>* RkMatrix<T>::truncatedSubset(const IndexSet* subRows,
174 const IndexSet* subCols,
175 double epsilon) const {
176 assert(subRows->isSubset(*rows));
177 assert(subCols->isSubset(*cols));
178 RkMatrix<T> * r = new RkMatrix<T>(NULL, subRows, NULL, subCols);
179 if(rank() > 0) {
180 r->a = ScalarArray<T>(*a, subRows->offset() - rows->offset(),
181 subRows->size(), 0, rank()).copy();
182 r->b = ScalarArray<T>(*b, subCols->offset() - cols->offset(),
183 subCols->size(), 0, rank()).copy();
184 if(epsilon >= 0)
185 r->truncate(epsilon);
186 }
187 return r;
188 }
189
compressedSize()190 template<typename T> size_t RkMatrix<T>::compressedSize() {
191 return ((size_t)rows->size()) * rank() + ((size_t)cols->size()) * rank();
192 }
193
uncompressedSize()194 template<typename T> size_t RkMatrix<T>::uncompressedSize() {
195 return ((size_t)rows->size()) * cols->size();
196 }
197
addRand(double epsilon)198 template<typename T> void RkMatrix<T>::addRand(double epsilon) {
199 DECLARE_CONTEXT;
200 a->addRand(epsilon);
201 b->addRand(epsilon);
202 return;
203 }
204
205 /**
206 * @brief Truncate the A or B block of a RkMatrix.
207 * This is a utilitary function for RkMatrix<T>::truncated and RkMatrix<T>::truncate
208 * @param ab The A or B block to truncate
209 * @param indexSet The index set of the A or B block to truncate (rows for A and cols for B)
210 * @param newK The rank to truncate to (output of the SVD)
211 * @param uv The U or V matrix (output of the SVD)
212 * @return the truncated A or B block
213 */
214 template <typename T>
truncatedAB(ScalarArray<T> * ab,const IndexSet * indexSet,int newK,ScalarArray<T> * uv,bool useInitPivot=false,int initialPivot=0)215 ScalarArray<T> *truncatedAB(ScalarArray<T> *ab, const IndexSet *indexSet,
216 int newK, ScalarArray<T> *uv,
217 bool useInitPivot = false, int initialPivot = 0) {
218 // We need to calculate Qa * u
219 ScalarArray<T>* newAB = new ScalarArray<T>(indexSet->size(), newK);
220 if (useInitPivot && initialPivot) {
221 // If there is an initial pivot, we must compute the product by Q in two parts
222 // first the column >= initialPivotA, obtained from lapack GETRF, will overwrite newA when calling UNMQR
223 // then the first initialPivotA columns, with a classical GEMM, will add the result in newA
224
225 // create subset of a (columns>=initialPivotA) and u (rows>=initialPivotA)
226 ScalarArray<T> sub_ab(*ab, 0, ab->rows, initialPivot, ab->cols-initialPivot);
227 ScalarArray<T> sub_uv(* uv, initialPivot, uv->rows-initialPivot, 0, uv->cols);
228 newAB->copyMatrixAtOffset(&sub_uv, 0, 0);
229 // newA <- Qa * newA (with newA = u)
230 sub_ab.productQ('L', 'N', newAB);
231
232 // then add the regular part of the product by Q
233 ScalarArray<T> sub_ab2(*ab, 0, ab->rows, 0, initialPivot);
234 ScalarArray<T> sub_uv2(* uv, 0, initialPivot, 0, uv->cols);
235 newAB->gemm('N', 'N', Constants<T>::pone, &sub_ab2, &sub_uv2, Constants<T>::pone);
236 } else {
237 // If no initialPivotA, then no gemm, just a productQ()
238 newAB->copyMatrixAtOffset( uv, 0, 0);
239 // newA <- Qa * newA
240 ab->productQ('L', 'N', newAB);
241 }
242
243 newAB->setOrtho( uv->getOrtho());
244 delete uv;
245 return newAB;
246 }
247
truncate(double epsilon,int initialPivotA,int initialPivotB)248 template<typename T> void RkMatrix<T>::truncate(double epsilon, int initialPivotA, int initialPivotB) {
249 DECLARE_CONTEXT;
250
251 if (rank() == 0) {
252 assert(!(a || b));
253 return;
254 }
255
256 assert(rows->size() >= rank());
257 // Case: more columns than one dimension of the matrix.
258 // In this case, the calculation of the SVD of the matrix "R_a R_b^t" is more
259 // expensive than computing the full SVD matrix. We make then a full matrix conversion,
260 // and compress it with RkMatrix::fromMatrix().
261 if (rank() > std::min(rows->size(), cols->size())) {
262 FullMatrix<T>* tmp = eval();
263 RkMatrix<T>* rk = truncatedSvd(tmp, epsilon); // TODO compress with something else than SVD (rank() can still be quite large) ?
264 delete tmp;
265 // "Move" rk into this, and delete the old "this".
266 swap(*rk);
267 delete rk;
268 return;
269 }
270
271 static bool usedRecomp = getenv("HMAT_RECOMPRESS") && strcmp(getenv("HMAT_RECOMPRESS"), "MGS") == 0 ;
272 if (usedRecomp){
273 mGSTruncate(epsilon, initialPivotA, initialPivotB);
274 return;
275 }
276
277 /* To recompress an Rk-matrix to Rk-matrix, we need :
278 - A = Q_a R_A (QR decomposition)
279 - B = Q_b R_b (QR decomposition)
280 - Calculate the SVD of R_a R_b^t = U S V^t
281 - Make truncation U, S, and V in the same way as for
282 compression of a full rank matrix, ie:
283 - Restrict U to its newK first columns U_tilde
284 - Restrict S to its newK first values (diagonal matrix) S_tilde
285 - Restrict V to its newK first columns V_tilde
286 - Put A = Q_a U_tilde SQRT (S_tilde)
287 B = Q_b V_tilde SQRT(S_tilde)
288
289 The sizes of the matrices are:
290 - Qa : rows x k
291 - Ra : k x k
292 - Qb : cols x k
293 - Rb : k x k
294 So:
295 - Ra Rb^t: k x k
296 - U : k * k
297 - S : k (diagonal)
298 - V^t: k * k
299 Hence:
300 - newA: rows x newK
301 - newB: cols x newK
302
303 */
304
305 // truncated SVD of Ra Rb^t (allows failure)
306 ScalarArray<T> *u = NULL, *v = NULL;
307 int newK;
308 // context block to release ra, rb, r ASAP
309 {
310 // QR decomposition of A and B
311 ScalarArray<T> ra(rank(), rank());
312 a->qrDecomposition(&ra, initialPivotA); // A contains Qa and tau_a
313 ScalarArray<T> rb(rank(), rank());
314 b->qrDecomposition(&rb, initialPivotB); // B contains Qb and tau_b
315
316 // R <- Ra Rb^t
317 ScalarArray<T> r(rank(), rank());
318 r.gemm('N','T', Constants<T>::pone, &ra, &rb , Constants<T>::zero);
319
320 // truncated SVD of Ra Rb^t (allows failure)
321 newK = r.truncatedSvdDecomposition(&u, &v, epsilon, true); // TODO use something else than SVD ?
322 }
323 if (newK == 0) {
324 clear();
325 return;
326 }
327 // We need to know if qrDecomposition has used initPivot...
328 // (Not so great, because HMAT_TRUNC_INITPIV is checked at 2 different locations)
329 static char *useInitPivot = getenv("HMAT_TRUNC_INITPIV");
330 ScalarArray<T>* newA = truncatedAB(a, rows, newK, u, useInitPivot, initialPivotA);
331 delete a;
332 a = newA;
333 ScalarArray<T>* newB = truncatedAB(b, cols, newK, v, useInitPivot, initialPivotB);
334 delete b;
335 b = newB;
336 }
337
mGSTruncate(double epsilon,int initialPivotA,int initialPivotB)338 template<typename T> void RkMatrix<T>::mGSTruncate(double epsilon, int initialPivotA, int initialPivotB) {
339 DECLARE_CONTEXT;
340 if (rank() == 0) {
341 assert(!(a || b));
342 return;
343 }
344
345 int kA, kB, newK;
346
347 int krank = rank();
348
349 // Gram-Schmidt on a
350 // On input, a0(m,k)
351 // On output, a(m,kA), ra(kA,k) such that a0 = a * ra
352 ScalarArray<T> ra(krank, krank);
353 kA = a->modifiedGramSchmidt( &ra, epsilon, initialPivotA);
354 if (kA==0) {
355 clear();
356 return;
357 }
358
359 // Gram-Schmidt on b
360 // On input, b0(p,k)
361 // On output, b(p,kB), rb(kB,k) such that b0 = b * rb
362 ScalarArray<T> rb(krank, krank);
363 kB = b->modifiedGramSchmidt( &rb, epsilon, initialPivotB);
364 if (kB==0) {
365 clear();
366 return;
367 }
368
369 // M = a0*b0^T = a*(ra*rb^T)*b^T
370 // We perform an SVD on ra*rb^T:
371 // (ra*rb^T) = U*S*S*Vt
372 // and M = (a*U*S)*(S*Vt*b^T) = (a*U*S)*(b*(S*Vt)^T)^T
373 ScalarArray<T> matR(kA, kB);
374 matR.gemm('N','T', Constants<T>::pone, &ra, &rb , Constants<T>::zero);
375
376 // truncatedSVD (allows failure)
377 ScalarArray<T>* ur = NULL;
378 ScalarArray<T>* vr = NULL;
379 newK = matR.truncatedSvdDecomposition(&ur, &vr, epsilon, true);
380 // On output, ur->rows = kA, vr->rows = kB
381
382 if (newK == 0) {
383 clear();
384 return;
385 }
386
387 /* Multiplication by orthogonal matrix Q: no or/un-mqr as
388 this comes from Gram-Schmidt procedure not Householder
389 */
390 ScalarArray<T> *newA = new ScalarArray<T>(a->rows, newK);
391 newA->gemm('N', 'N', Constants<T>::pone, a, ur, Constants<T>::zero);
392
393 ScalarArray<T> *newB = new ScalarArray<T>(b->rows, newK);
394 newB->gemm('N', 'N', Constants<T>::pone, b, vr, Constants<T>::zero);
395
396 newA->setOrtho(ur->getOrtho());
397 newB->setOrtho(vr->getOrtho());
398 delete ur;
399 delete vr;
400
401 delete a;
402 a = newA;
403 delete b;
404 b = newB;
405 }
406
407 // Swap members with members from another instance.
swap(RkMatrix<T> & other)408 template<typename T> void RkMatrix<T>::swap(RkMatrix<T>& other)
409 {
410 assert(*rows == *other.rows);
411 assert(*cols == *other.cols);
412 std::swap(a, other.a);
413 std::swap(b, other.b);
414 }
415
axpy(double epsilon,T alpha,const FullMatrix<T> * mat)416 template<typename T> void RkMatrix<T>::axpy(double epsilon, T alpha, const FullMatrix<T>* mat) {
417 formattedAddParts(epsilon, &alpha, &mat, 1);
418 }
419
axpy(double epsilon,T alpha,const RkMatrix<T> * mat)420 template<typename T> void RkMatrix<T>::axpy(double epsilon, T alpha, const RkMatrix<T>* mat) {
421 formattedAddParts(epsilon, &alpha, &mat, 1);
422 }
423
424 /*! \brief Try to optimize the order of the Rk matrix to maximize initialPivot
425
426 We re-order the Rk matrices in usedParts[] and the associated constants in usedAlpha[]
427 in order to maximize the number of orthogonal columns starting at column 0 in the A and B
428 panels.
429 \param[in] notNullParts number of elements in usedParts[] and usedAlpha[]
430 \param[in,out] usedParts Array of Rk matrices
431 \param[in,out] usedAlpha Array of constants
432 \param[out] initialPivotA Number of orthogonal columns starting at column 0 in panel A
433 \param[out] initialPivotB Number of orthogonal columns starting at column 0 in panel A
434 */
435 template<typename T>
optimizeRkArray(int notNullParts,const RkMatrix<T> ** usedParts,T * usedAlpha,int & initialPivotA,int & initialPivotB)436 static void optimizeRkArray(int notNullParts, const RkMatrix<T>** usedParts, T *usedAlpha, int &initialPivotA, int &initialPivotB){
437 // 1st optim: Put in first position the Rk matrix with orthogonal panels AND maximum rank
438 int bestRk=-1, bestGain=-1;
439 for (int i=0 ; i<notNullParts ; i++) {
440 // Roughly, the gain from an initial pivot 'p' in a QR factorisation 'm x n' is to reduce the flops
441 // from 2mn^2 to 2m(n^2-p^2), so the gain grows like p^2 for each panel
442 // hence the gain formula : number of orthogonal panels x rank^2
443 int gain = (usedParts[i]->a->getOrtho() + usedParts[i]->b->getOrtho())*usedParts[i]->rank()*usedParts[i]->rank();
444 if (gain > bestGain) {
445 bestGain = gain;
446 bestRk = i;
447 }
448 }
449 if (bestRk > 0) {
450 std::swap(usedParts[0], usedParts[bestRk]) ;
451 std::swap(usedAlpha[0], usedAlpha[bestRk]) ;
452 }
453 initialPivotA = usedParts[0]->a->getOrtho() ? usedParts[0]->rank() : 0;
454 initialPivotB = usedParts[0]->b->getOrtho() ? usedParts[0]->rank() : 0;
455
456 // 2nd optim:
457 // When coallescing Rk from childs toward parent, it is possible to "merge" Rk from (extra-)diagonal
458 // childs because with non-intersecting rows and cols we will extend orthogonality between separate Rk.
459 int best_i1=-1, best_i2=-1, best_rkA=-1, best_rkB=-1;
460 for (int i1=0 ; i1<notNullParts ; i1++)
461 for (int i2=0 ; i2<notNullParts ; i2++)
462 if (i1 != i2) {
463 const RkMatrix<T>* Rk1 = usedParts[i1];
464 const RkMatrix<T>* Rk2 = usedParts[i2];
465 // compute the gain expected from puting Rk1-Rk2 in first position
466 // Orthogonality of Rk2->a is useful only if Rk1->a is ortho AND rows dont intersect (cols for panel b)
467 int rkA = Rk1->a->getOrtho() ? Rk1->rank() + (Rk2->a->getOrtho() && !Rk1->rows->intersects(*Rk2->rows) ? Rk2->rank() : 0) : 0;
468 int rkB = Rk1->b->getOrtho() ? Rk1->rank() + (Rk2->b->getOrtho() && !Rk1->cols->intersects(*Rk2->cols) ? Rk2->rank() : 0) : 0;
469 int gain = rkA*rkA + rkB*rkB ;
470 if (gain > bestGain) {
471 bestGain = gain;
472 best_i1 = i1;
473 best_i2 = i2;
474 best_rkA = rkA;
475 best_rkB = rkB;
476 }
477 }
478
479 if (best_i1 >= 0) {
480 // put i1 in first position, i2 in second
481 std::swap(usedParts[0], usedParts[best_i1]) ;
482 std::swap(usedAlpha[0], usedAlpha[best_i1]) ;
483 if (best_i2==0) best_i2 = best_i1; // handles the case where best_i2 was usedParts[0] which has just been moved
484 std::swap(usedParts[1], usedParts[best_i2]) ;
485 std::swap(usedAlpha[1], usedAlpha[best_i2]) ;
486 initialPivotA = best_rkA;
487 initialPivotB = best_rkB;
488 }
489
490 }
491
allSame(const RkMatrix<T> ** rks,int n)492 template<typename T> bool allSame(const RkMatrix<T>** rks, int n) {
493 for(int i = 1; i < n; i++) {
494 if(!(*rks[0]->rows == *rks[i]->rows) || !(*rks[0]->cols == *rks[i]->cols))
495 return false;
496 }
497 return true;
498 }
499
500 template<typename T>
formattedAddParts(double epsilon,const T * alpha,const RkMatrix<T> * const * parts,const int n,bool hook)501 void RkMatrix<T>::formattedAddParts(double epsilon, const T* alpha, const RkMatrix<T>* const * parts,
502 const int n, bool hook) {
503 if(hook && formatedAddPartsHook && formatedAddPartsHook(this, epsilon, alpha, parts, n))
504 return;
505 // TODO check if formattedAddParts() actually uses sometimes this 'alpha' parameter (or is it always 1 ?)
506 DECLARE_CONTEXT;
507
508 /* List of non-null and non-empty Rk matrices to coalesce, and the corresponding scaling coefficients */
509 const RkMatrix<T>* usedParts[n+1];
510 T usedAlpha[n+1];
511 /* Number of elements in usedParts[] */
512 int notNullParts = 0;
513 /* Sum of the ranks */
514 int rankTotal = 0;
515
516 // If needed, put 'this' in first position in usedParts[]
517 if (rank()) {
518 usedAlpha[0] = Constants<T>::pone ;
519 usedParts[notNullParts++] = this ;
520 rankTotal += rank();
521 }
522
523 for (int i = 0; i < n; i++) {
524 // exclude the NULL and 0-rank matrices
525 if (!parts[i] || parts[i]->rank() == 0 || parts[i]->rows->size() == 0 || parts[i]->cols->size() == 0 || alpha[i]==Constants<T>::zero)
526 continue;
527 // Check that partial RkMatrix indices are subsets of their global indices set.
528 assert(parts[i]->rows->isSubset(*rows));
529 assert(parts[i]->cols->isSubset(*cols));
530 // Add this Rk to the list
531 rankTotal += parts[i]->rank();
532 usedAlpha[notNullParts] = alpha[i] ;
533 usedParts[notNullParts] = parts[i] ;
534 notNullParts++;
535 }
536
537 if(notNullParts == 0)
538 return;
539
540 // In case the sum of the ranks of the sub-matrices is greater than
541 // the matrix size, it is more efficient to put everything in a
542 // full matrix.
543 if (rankTotal >= std::min(rows->size(), cols->size())) {
544 const FullMatrix<T>** fullParts = new const FullMatrix<T>*[notNullParts];
545 fullParts[0] = NULL ;
546 for (int i = rank() ? 1 : 0 ; i < notNullParts; i++) // exclude usedParts[0] if it is 'this'
547 fullParts[i] = usedParts[i]->eval();
548 formattedAddParts(std::abs(epsilon), usedAlpha, fullParts, notNullParts);
549 for (int i = 0; i < notNullParts; i++)
550 delete fullParts[i];
551 delete[] fullParts;
552 return;
553 }
554
555 // Find if the QR factorisation can be accelerated using orthogonality information
556 int initialPivotA = usedParts[0]->a->getOrtho() ? usedParts[0]->rank() : 0;
557 int initialPivotB = usedParts[0]->b->getOrtho() ? usedParts[0]->rank() : 0;
558
559 // Try to optimize the order of the Rk matrix to maximize initialPivot
560 static char *useBestRk = getenv("HMAT_MGS_BESTRK");
561 if (useBestRk)
562 optimizeRkArray(notNullParts, usedParts, usedAlpha, initialPivotA, initialPivotB);
563
564 // According to the indices organization, the sub-matrices are
565 // contiguous blocks in the "big" matrix whose columns offset is
566 // kOffset = usedParts[0]->k + ... + usedParts[i-1]->k
567 // rows offset is
568 // usedParts[i]->rows->offset - rows->offset
569 // rows size
570 // usedParts[i]->rows->size x usedParts[i]->k (rows x columns)
571 // Same for columns.
572
573 // when possible realloc this a & b arrays to limit memory usage and avoid a copy
574 bool useRealloc = usedParts[0] == this && allSame(usedParts, notNullParts);
575 // concatenate a(i) then b(i) to limite memory usage
576 ScalarArray<T>* resultA, *resultB;
577 int rankOffset;
578 if(useRealloc) {
579 resultA = a;
580 rankOffset = a->cols;
581 a->resize(rankTotal);
582 }
583 else {
584 rankOffset = 0;
585 resultA = new ScalarArray<T>(rows->size(), rankTotal);
586 }
587
588 for (int i = useRealloc ? 1 : 0; i < notNullParts; i++) {
589 // Copy 'a' at position rowOffset, kOffset
590 int rowOffset = usedParts[i]->rows->offset() - rows->offset();
591 resultA->copyMatrixAtOffset(usedParts[i]->a, rowOffset, rankOffset);
592 // Scaling the matrix already in place inside resultA
593 if (usedAlpha[i] != Constants<T>::pone) {
594 ScalarArray<T> tmp(*resultA, rowOffset, usedParts[i]->a->rows, rankOffset, usedParts[i]->a->cols);
595 tmp.scale(usedAlpha[i]);
596 }
597 // Update the rank offset
598 rankOffset += usedParts[i]->rank();
599 }
600 assert(rankOffset==rankTotal);
601
602 if(!useRealloc && a != NULL)
603 delete a;
604 a = resultA;
605
606 if(useRealloc) {
607 resultB = b;
608 rankOffset = b->cols;
609 b->resize(rankTotal);
610 }
611 else {
612 rankOffset = 0;
613 resultB = new ScalarArray<T>(cols->size(), rankTotal);
614 }
615
616 for (int i = useRealloc ? 1 : 0; i < notNullParts; i++) {
617 // Copy 'b' at position colOffset, kOffset
618 int colOffset = usedParts[i]->cols->offset() - cols->offset();
619 resultB->copyMatrixAtOffset(usedParts[i]->b, colOffset, rankOffset);
620 // Update the rank offset
621 rankOffset += usedParts[i]->b->cols;
622 }
623
624 if(!useRealloc && b != NULL)
625 delete b;
626 b = resultB;
627
628 assert(rankOffset==rankTotal);
629 // If only one of the parts is non-zero, then the recompression is not necessary
630 if (notNullParts > 1 && epsilon >= 0)
631 truncate(epsilon, initialPivotA, initialPivotB);
632 }
633
634 template<typename T>
formattedAddParts(double epsilon,const T * alpha,const FullMatrix<T> * const * parts,int n)635 void RkMatrix<T>::formattedAddParts(double epsilon, const T* alpha, const FullMatrix<T>* const * parts, int n) {
636 DECLARE_CONTEXT;
637 FullMatrix<T>* me = eval();
638 HMAT_ASSERT(me);
639
640 // TODO: here, we convert Rk->Full, Update the Full with parts[], and Full->Rk. We could also
641 // create a new empty Full, update, convert to Rk and add it to 'this'.
642 // If the parts[] are smaller than 'this', convert them to Rk and add them could be less expensive
643 for (int i = 0; i < n; i++) {
644 if (!parts[i])
645 continue;
646 const IndexSet *rows_full = parts[i]->rows_;
647 const IndexSet *cols_full = parts[i]->cols_;
648 assert(rows_full->isSubset(*rows));
649 assert(cols_full->isSubset(*cols));
650 int rowOffset = rows_full->offset() - rows->offset();
651 int colOffset = cols_full->offset() - cols->offset();
652 int maxCol = cols_full->size();
653 int maxRow = rows_full->size();
654 ScalarArray<T> sub(me->data, rowOffset, maxRow, colOffset, maxCol);
655 sub.axpy(alpha[i], &parts[i]->data);
656 }
657 RkMatrix<T>* result = truncatedSvd(me, epsilon); // TODO compress with something else than SVD
658 delete me;
659 swap(*result);
660 delete result;
661 }
662
663
/**
 * Product of an Rk matrix and a dense matrix: op(R) * op(M), returned as a
 * new Rk matrix (caller owns it). op is identity/transpose/conjugate-transpose
 * for 'N'/'T'/'C'.
 *
 * With op(R) = L * Rt^T (L, Rt being A, B, or B, A when transR != 'N'), the
 * result keeps L (conjugated when transR == 'C') as its A panel. Only the B
 * panel involves M:
 *   transR != 'C':  'N' -> M^T * Rt,          'T' -> M * Rt,        'C' -> conj(M * conj(Rt))
 *   transR == 'C':  'N' -> conj(M^H * Rt),    'T' -> M * conj(Rt),  'C' -> conj(M * Rt)
 * A rank-0 rk yields an empty Rk matrix.
 */
template <typename T>
RkMatrix<T>* RkMatrix<T>::multiplyRkFull(char transR, char transM,
                                         const RkMatrix<T>* rk,
                                         const FullMatrix<T>* m) {
  DECLARE_CONTEXT;

  assert(((transR == 'N') ? rk->cols->size() : rk->rows->size()) == ((transM == 'N') ? m->rows() : m->cols()));
  const IndexSet* outRows = (transR == 'N') ? rk->rows : rk->cols;
  const IndexSet* outCols = (transM == 'N') ? m->cols_ : m->rows_;

  if (rk->rank() == 0)
    return new RkMatrix<T>(NULL, outRows, NULL, outCols);

  // Transposing R swaps the roles of its panels. The conjugation needed for
  // 'C' cannot be done in place because rk is const; it is applied to copies.
  const bool swapped = (transR != 'N');
  ScalarArray<T>* left = swapped ? rk->b : rk->a;
  ScalarArray<T>* right = swapped ? rk->a : rk->b;
  const ScalarArray<T>* mat = &m->data;

  ScalarArray<T>* newA = left->copy();
  ScalarArray<T>* newB = new ScalarArray<T>(transM == 'N' ? m->cols() : m->rows(), right->cols);

  if (transR == 'C') {
    newA->conjugate();
    switch (transM) {
    case 'N':
      // M^T * conj(right) = conj(M^H * right)
      newB->gemm('C', 'N', Constants<T>::pone, mat, right, Constants<T>::zero);
      newB->conjugate();
      break;
    case 'T': {
      // M * conj(right)
      ScalarArray<T>* conjRight = right->copy();
      conjRight->conjugate();
      newB->gemm('N', 'N', Constants<T>::pone, mat, conjRight, Constants<T>::zero);
      delete conjRight;
      break;
    }
    default:
      assert(transM == 'C');
      // conj(M) * conj(right) = conj(M * right)
      newB->gemm('N', 'N', Constants<T>::pone, mat, right, Constants<T>::zero);
      newB->conjugate();
      break;
    }
  } else {
    switch (transM) {
    case 'N':
      // M^T * right
      newB->gemm('T', 'N', Constants<T>::pone, mat, right, Constants<T>::zero);
      break;
    case 'T':
      // M * right
      newB->gemm('N', 'N', Constants<T>::pone, mat, right, Constants<T>::zero);
      break;
    default: {
      assert(transM == 'C');
      // conj(M) * right = conj(M * conj(right))
      ScalarArray<T>* conjRight = right->copy();
      conjRight->conjugate();
      newB->gemm('N', 'N', Constants<T>::pone, mat, conjRight, Constants<T>::zero);
      newB->conjugate();
      delete conjRight;
      break;
    }
    }
  }
  return new RkMatrix<T>(newA, outRows, newB, outCols);
}
729
730 template<typename T>
multiplyFullRk(char transM,char transR,const FullMatrix<T> * m,const RkMatrix<T> * rk)731 RkMatrix<T>* RkMatrix<T>::multiplyFullRk(char transM, char transR,
732 const FullMatrix<T>* m,
733 const RkMatrix<T>* rk) {
734 DECLARE_CONTEXT;
735 // If transM is 'N' and transR is 'N', we compute
736 // M * A * B^T ==> newA = M * A, newB = B
737 // We can deduce all other cases from this one:
738 // * if transR is 'T', all we have to do is to swap A and B
739 // * if transR is 'C', we swap A and B, but they must also
740 // be conjugate; let us look at the different cases:
741 // + if transM is 'N', newA = M * conj(A)
742 // + if transM is 'T', newA = M^T * conj(A) = conj(M^H * A)
743 // + if transM is 'C', newA = M^H * conj(A) = conj(M^T * A)
744 ScalarArray<T> *newA, *newB;
745 ScalarArray<T>* ra = rk->a;
746 ScalarArray<T>* rb = rk->b;
747 if (transR != 'N') { // permutation to transpose the matrix Rk
748 std::swap(ra, rb);
749 }
750 const IndexSet *rkCols = ((transR == 'N')? rk->cols : rk->rows);
751 const IndexSet *mRows = ((transM == 'N')? m->rows_ : m->cols_);
752
753 newA = new ScalarArray<T>(mRows->size(), rb->cols);
754 newB = rb->copy();
755 if (transR == 'C') {
756 newB->conjugate();
757 if (transM == 'N') {
758 ScalarArray<T> *conjA = ra->copy();
759 conjA->conjugate();
760 newA->gemm('N', 'N', Constants<T>::pone, &m->data, conjA, Constants<T>::zero);
761 delete conjA;
762 } else if (transM == 'T') {
763 newA->gemm('C', 'N', Constants<T>::pone, &m->data, ra, Constants<T>::zero);
764 newA->conjugate();
765 } else {
766 assert(transM == 'C');
767 newA->gemm('T', 'N', Constants<T>::pone, &m->data, ra, Constants<T>::zero);
768 newA->conjugate();
769 }
770 } else {
771 newA->gemm(transM, 'N', Constants<T>::pone, &m->data, ra, Constants<T>::zero);
772 }
773 RkMatrix<T>* result = new RkMatrix<T>(newA, mRows, newB, rkCols);
774 return result;
775 }
776
777 template<typename T>
multiplyRkH(char transR,char transH,const RkMatrix<T> * rk,const HMatrix<T> * h)778 RkMatrix<T>* RkMatrix<T>::multiplyRkH(char transR, char transH,
779 const RkMatrix<T>* rk, const HMatrix<T>* h) {
780 DECLARE_CONTEXT;
781 assert(((transR == 'N') ? *rk->cols : *rk->rows) == ((transH == 'N')? *h->rows() : *h->cols()));
782
783 const IndexSet* rkRows = ((transR == 'N')? rk->rows : rk->cols);
784
785 // If transR == 'N'
786 // transM == 'N': (A*B^T)*M = A*(M^T*B)^T
787 // transM == 'T': (A*B^T)*M^T = A*(M*B)^T
788 // transM == 'C': (A*B^T)*M^H = A*(conj(M)*B)^T = A*conj(M*conj(B))^T
789 // If transR == 'T', we only have to swap A and B
790 // If transR == 'C', we swap A and B, then
791 // transM == 'N': R^H*M = conj(A)*(M^T*conj(B))^T = conj(A)*conj(M^H*B)^T
792 // transM == 'T': R^H*M^T = conj(A)*(M*conj(B))^T
793 // transM == 'C': R^H*M^H = conj(A)*conj(M*B)^T
794 //
795 // Size of the HMatrix is n x m,
796 // So H^t size is m x n and the product is m x cols(B)
797 // and the number of columns of B is k.
798 ScalarArray<T> *newA, *newB;
799 ScalarArray<T>* ra = rk->a;
800 ScalarArray<T>* rb = rk->b;
801 if (transR != 'N') { // permutation to transpose the matrix Rk
802 std::swap(ra, rb);
803 }
804
805 const IndexSet *newCols = ((transH == 'N' )? h->cols() : h->rows());
806
807 newA = ra->copy();
808 newB = new ScalarArray<T>(transH == 'N' ? h->cols()->size() : h->rows()->size(), rb->cols);
809 if (transR == 'C') {
810 newA->conjugate();
811 if (transH == 'N') {
812 h->gemv('C', Constants<T>::pone, rb, Constants<T>::zero, newB);
813 newB->conjugate();
814 } else if (transH == 'T') {
815 ScalarArray<T> *conjB = rb->copy();
816 conjB->conjugate();
817 h->gemv('N', Constants<T>::pone, conjB, Constants<T>::zero, newB);
818 delete conjB;
819 } else {
820 assert(transH == 'C');
821 h->gemv('N', Constants<T>::pone, rb, Constants<T>::zero, newB);
822 newB->conjugate();
823 }
824 } else {
825 if (transH == 'N') {
826 h->gemv('T', Constants<T>::pone, rb, Constants<T>::zero, newB);
827 } else if (transH == 'T') {
828 h->gemv('N', Constants<T>::pone, rb, Constants<T>::zero, newB);
829 } else {
830 assert(transH == 'C');
831 ScalarArray<T> *conjB = rb->copy();
832 conjB->conjugate();
833 h->gemv('N', Constants<T>::pone, conjB, Constants<T>::zero, newB);
834 delete conjB;
835 newB->conjugate();
836 }
837 }
838 RkMatrix<T>* result = new RkMatrix<T>(newA, rkRows, newB, newCols);
839 return result;
840 }
841
842 template<typename T>
multiplyHRk(char transH,char transR,const HMatrix<T> * h,const RkMatrix<T> * rk)843 RkMatrix<T>* RkMatrix<T>::multiplyHRk(char transH, char transR,
844 const HMatrix<T>* h, const RkMatrix<T>* rk) {
845
846 DECLARE_CONTEXT;
847 if (rk->rank() == 0) {
848 const IndexSet* newRows = ((transH == 'N') ? h-> rows() : h->cols());
849 const IndexSet* newCols = ((transR == 'N') ? rk->cols : rk->rows);
850 return new RkMatrix<T>(NULL, newRows, NULL, newCols);
851 }
852
853 // If transH is 'N' and transR is 'N', we compute
854 // M * A * B^T ==> newA = M * A, newB = B
855 // We can deduce all other cases from this one:
856 // * if transR is 'T', all we have to do is to swap A and B
857 // * if transR is 'C', we swap A and B, but they must also
858 // be conjugate; let us look at the different cases:
859 // + if transH is 'N', newA = M * conj(A)
860 // + if transH is 'T', newA = M^T * conj(A) = conj(M^H * A)
861 // + if transH is 'C', newA = M^H * conj(A) = conj(M^T * A)
862 ScalarArray<T> *newA, *newB;
863 ScalarArray<T>* ra = rk->a;
864 ScalarArray<T>* rb = rk->b;
865 if (transR != 'N') { // permutation to transpose the matrix Rk
866 std::swap(ra, rb);
867 }
868 const IndexSet *rkCols = ((transR == 'N')? rk->cols : rk->rows);
869 const IndexSet* newRows = ((transH == 'N')? h-> rows() : h->cols());
870
871 newA = new ScalarArray<T>(transH == 'N' ? h->rows()->size() : h->cols()->size(), rb->cols);
872 newB = rb->copy();
873 if (transR == 'C') {
874 newB->conjugate();
875 if (transH == 'N') {
876 ScalarArray<T> *conjA = ra->copy();
877 conjA->conjugate();
878 h->gemv('N', Constants<T>::pone, conjA, Constants<T>::zero, newA);
879 delete conjA;
880 } else if (transH == 'T') {
881 h->gemv('C', Constants<T>::pone, ra, Constants<T>::zero, newA);
882 newA->conjugate();
883 } else {
884 assert(transH == 'C');
885 h->gemv('T', Constants<T>::pone, ra, Constants<T>::zero, newA);
886 newA->conjugate();
887 }
888 } else {
889 h->gemv(transH, Constants<T>::pone, ra, Constants<T>::zero, newA);
890 }
891 RkMatrix<T>* result = new RkMatrix<T>(newA, newRows, newB, rkCols);
892 return result;
893 }
894
895 template<typename T>
multiplyRkRk(char trans1,char trans2,const RkMatrix<T> * r1,const RkMatrix<T> * r2,double epsilon)896 RkMatrix<T>* RkMatrix<T>::multiplyRkRk(char trans1, char trans2,
897 const RkMatrix<T>* r1, const RkMatrix<T>* r2, double epsilon) {
898 DECLARE_CONTEXT;
899 assert(((trans1 == 'N') ? *r1->cols : *r1->rows) == ((trans2 == 'N') ? *r2->rows : *r2->cols));
900 // It is possible to do the computation differently, yielding a
901 // different rank and a different amount of computation.
902 // TODO: choose the best order.
903 ScalarArray<T>* a1 = (trans1 == 'N' ? r1->a : r1->b);
904 ScalarArray<T>* b1 = (trans1 == 'N' ? r1->b : r1->a);
905 ScalarArray<T>* a2 = (trans2 == 'N' ? r2->a : r2->b);
906 ScalarArray<T>* b2 = (trans2 == 'N' ? r2->b : r2->a);
907
908 assert(b1->rows == a2->rows); // compatibility of the multiplication
909
910 // We want to compute the matrix a1.t^b1.a2.t^b2 and return an Rk matrix
911 // Usually, the best way is to start with tmp=t^b1.a2 which produces a 'small' matrix rank1 x rank2
912 //
913 // OLD version (default):
914 // Then we can either :
915 // - compute a1.tmp : the cost is rank1.rank2.row_a, the resulting Rk has rank rank2
916 // - compute tmp.t^b2 : the cost is rank1.rank2.col_b, the resulting Rk has rank rank1
917 // We use the solution which gives the lowest resulting rank.
918 // With this version, orthogonality is lost on one panel, it is preserved on the other.
919 //
920 // NEW version :
921 // Other solution: once we have the small matrix tmp=t^b1.a2, we can do a recompression on it for low cost
922 // using SVD + truncation. This also removes the choice above, since tmp=U.S.V is then applied on both sides
923 // This version is default, it can be deactivated by setting env. var. HMAT_OLD_RKRK
924 // With this version, orthogonality is lost on both panel.
925
926 ScalarArray<T> tmp(r1->rank(), r2->rank(), false);
927 if (trans1 == 'C' && trans2 == 'C') {
928 tmp.gemm('T', 'N', Constants<T>::pone, b1, a2, Constants<T>::zero);
929 tmp.conjugate();
930 } else if (trans1 == 'C') {
931 tmp.gemm('C', 'N', Constants<T>::pone, b1, a2, Constants<T>::zero);
932 } else if (trans2 == 'C') {
933 tmp.gemm('C', 'N', Constants<T>::pone, b1, a2, Constants<T>::zero);
934 tmp.conjugate();
935 } else {
936 tmp.gemm('T', 'N', Constants<T>::pone, b1, a2, Constants<T>::zero);
937 }
938
939 ScalarArray<T> *newA=NULL, *newB=NULL;
940 static char *oldRKRK = getenv("HMAT_OLD_RKRK"); // Option to use the OLD version, without SVD-in-the-middle
941 if (!oldRKRK) {
942 // NEW version
943 ScalarArray<T>* ur = NULL;
944 ScalarArray<T>* vr = NULL;
945 // truncated SVD tmp = ur.t^vr
946 int newK = tmp.truncatedSvdDecomposition(&ur, &vr, epsilon, true);
947 if (newK > 0) {
948 /* Now compute newA = a1.ur and newB = b2.vr */
949 newA = new ScalarArray<T>(a1->rows, newK, false);
950 if (trans1 == 'C') ur->conjugate();
951 newA->gemm('N', 'N', Constants<T>::pone, a1, ur, Constants<T>::zero);
952 if (trans1 == 'C') newA->conjugate();
953 newB = new ScalarArray<T>(b2->rows, newK, false);
954 if (trans2 == 'C') vr->conjugate();
955 newB->gemm('N', 'N', Constants<T>::pone, b2, vr, Constants<T>::zero);
956 if (trans2 == 'C') newB->conjugate();
957 delete ur;
958 delete vr;
959 }
960 } else {
961 // OLD version
962 if (r1->rank() < r2->rank()) {
963 // newA = a1, newB = b2.t^tmp
964 newA = a1->copy();
965 if (trans1 == 'C') newA->conjugate();
966 newB = new ScalarArray<T>(b2->rows, r1->rank());
967 if (trans2 == 'C') {
968 newB->gemm('N', 'C', Constants<T>::pone, b2, &tmp, Constants<T>::zero);
969 newB->conjugate();
970 } else {
971 newB->gemm('N', 'T', Constants<T>::pone, b2, &tmp, Constants<T>::zero);
972 }
973 } else { // newA = a1.tmp, newB = b2
974 newA = new ScalarArray<T>(a1->rows, r2->rank());
975 if (trans1 == 'C') tmp.conjugate(); // be careful if you re-use tmp after this...
976 newA->gemm('N', 'N', Constants<T>::pone, a1, &tmp, Constants<T>::zero);
977 if (trans1 == 'C') newA->conjugate();
978 newB = b2->copy();
979 if (trans2 == 'C') newB->conjugate();
980 }
981 }
982 return new RkMatrix<T>(newA, ((trans1 == 'N') ? r1->rows : r1->cols), newB, ((trans2 == 'N') ? r2->cols : r2->rows));
983 }
984
985 template<typename T>
computeRkRkMemorySize(char trans1,char trans2,const RkMatrix<T> * r1,const RkMatrix<T> * r2)986 size_t RkMatrix<T>::computeRkRkMemorySize(char trans1, char trans2,
987 const RkMatrix<T>* r1, const RkMatrix<T>* r2)
988 {
989 ScalarArray<T>* b2 = (trans2 == 'N' ? r2->b : r2->a);
990 ScalarArray<T>* a1 = (trans1 == 'N' ? r1->a : r1->b);
991 return b2 == NULL ? 0 : b2->memorySize() +
992 a1 == NULL ? 0 : a1->rows * r2->rank() * sizeof(T);
993 }
994
995 template<typename T>
multiplyWithDiagOrDiagInv(const HMatrix<T> * d,bool inverse,Side side)996 void RkMatrix<T>::multiplyWithDiagOrDiagInv(const HMatrix<T> * d, bool inverse, Side side) {
997 assert(*d->rows() == *d->cols());
998 assert(side == Side::RIGHT || (*rows == *d->cols()));
999 assert(side == Side::LEFT || (*cols == *d->rows()));
1000
1001 // extracting the diagonal
1002 Vector<T>* diag = new Vector<T>(d->cols()->size());
1003 d->extractDiagonal(diag->ptr());
1004
1005 // left multiplication by d of b (if M<-M*D : side = RIGHT) or a (if M<-D*M : side = LEFT)
1006 ScalarArray<T>* aOrB = (side == Side::LEFT ? a : b);
1007 aOrB->multiplyWithDiagOrDiagInv(diag, inverse, Side::LEFT);
1008
1009 delete diag;
1010 }
1011
gemmRk(double epsilon,char transHA,char transHB,T alpha,const HMatrix<T> * ha,const HMatrix<T> * hb)1012 template<typename T> void RkMatrix<T>::gemmRk(double epsilon, char transHA, char transHB,
1013 T alpha, const HMatrix<T>* ha, const HMatrix<T>* hb) {
1014 DECLARE_CONTEXT;
1015 if (!ha->isLeaf() && !hb->isLeaf()) {
1016 // Recursion case
1017 int nbRows = transHA == 'N' ? ha->nrChildRow() : ha->nrChildCol() ; /* Row blocks of the product */
1018 int nbCols = transHB == 'N' ? hb->nrChildCol() : hb->nrChildRow() ; /* Col blocks of the product */
1019 int nbCom = transHA == 'N' ? ha->nrChildCol() : ha->nrChildRow() ; /* Common dimension between A and B */
1020 int nSubRks = nbRows * nbCols;
1021 RkMatrix<T>* subRks[nSubRks];
1022 std::fill_n(subRks, nSubRks, nullptr);
1023 for (int i = 0; i < nbRows; i++) {
1024 for (int j = 0; j < nbCols; j++) {
1025 int p = i + j * nbRows;
1026 for (int k = 0; k < nbCom; k++) {
1027 // C_ij = A_ik * B_kj
1028 HMatrix<T>* a_ik = transHA == 'N' ? ha->get(i, k) : ha->get(k, i);
1029 HMatrix<T>* b_kj = transHB == 'N' ? hb->get(k, j) : hb->get(j, k);
1030 if (a_ik && b_kj) {
1031 if (subRks[p] == nullptr) {
1032 const IndexSet* subRows = transHA == 'N' ? a_ik->rows() : a_ik->cols();
1033 const IndexSet* subCols = transHB == 'N' ? b_kj->cols() : b_kj->rows();
1034 subRks[p] = new RkMatrix<T>(nullptr, subRows, nullptr, subCols);
1035 }
1036 subRks[p]->gemmRk(epsilon, transHA, transHB, alpha, a_ik, b_kj);
1037 }
1038 } // k loop
1039 } // j loop
1040 } // i loop
1041 // Reconstruction of C by adding the parts
1042 T alphaV[nSubRks];
1043 std::fill_n(alphaV, nSubRks, 1);
1044 formattedAddParts(epsilon, alphaV, subRks, nSubRks);
1045 for (int i = 0; i < nSubRks; i++) {
1046 delete subRks[i];
1047 }
1048 } else {
1049 RkMatrix<T>* rk = nullptr;
1050 // One of the product matrix is a leaf
1051 if ((ha->isLeaf() && ha->isNull()) || (hb->isLeaf() && hb->isNull())) {
1052 // Nothing to do
1053 } else if (ha->isRkMatrix() || hb->isRkMatrix()) {
1054 rk = HMatrix<T>::multiplyRkMatrix(epsilon, transHA, transHB, ha, hb);
1055 } else {
1056 assert(ha->isFullMatrix() || hb->isFullMatrix());
1057 FullMatrix<T>* fullMat = HMatrix<T>::multiplyFullMatrix(transHA, transHB, ha, hb);
1058 if(fullMat) {
1059 rk = truncatedSvd(fullMat, epsilon); // TODO compress with something else than SVD
1060 delete fullMat;
1061 }
1062 }
1063 if(rk) {
1064 if(rank() == 0) {
1065 // save memory by not allocating a temporary Rk
1066 rk->scale(alpha);
1067 swap(*rk);
1068 } else
1069 axpy(epsilon, alpha, rk);
1070 delete rk;
1071 }
1072 }
1073 }
1074
copy(const RkMatrix<T> * o)1075 template<typename T> void RkMatrix<T>::copy(const RkMatrix<T>* o) {
1076 delete a;
1077 delete b;
1078 rows = o->rows;
1079 cols = o->cols;
1080 a = (o->a ? o->a->copy() : NULL);
1081 b = (o->b ? o->b->copy() : NULL);
1082 }
1083
copy() const1084 template<typename T> RkMatrix<T>* RkMatrix<T>::copy() const {
1085 RkMatrix<T> *result = new RkMatrix<T>(NULL, rows, NULL, cols);
1086 result->copy(this);
1087 return result;
1088 }
1089
1090
checkNan() const1091 template<typename T> void RkMatrix<T>::checkNan() const {
1092 if (rank() == 0) {
1093 return;
1094 }
1095 a->checkNan();
1096 b->checkNan();
1097 }
1098
conjugate()1099 template<typename T> void RkMatrix<T>::conjugate() {
1100 if (a) a->conjugate();
1101 if (b) b->conjugate();
1102 }
1103
get(int i,int j) const1104 template<typename T> T RkMatrix<T>::get(int i, int j) const {
1105 return a->dot_aibj(i, *b, j);
1106 }
1107
writeArray(hmat_iostream writeFunc,void * userData) const1108 template<typename T> void RkMatrix<T>::writeArray(hmat_iostream writeFunc, void * userData) const{
1109 a->writeArray(writeFunc, userData);
1110 b->writeArray(writeFunc, userData);
1111 }
1112
1113 template <typename T>
1114 bool (*RkMatrix<T>::formatedAddPartsHook)(RkMatrix<T> *me, double epsilon, const T *alpha,
1115 const RkMatrix<T> *const *parts,
1116 const int n) = NULL;
1117
// Templates declaration: explicit instantiations of RkMatrix for the scalar
// types S_t, D_t, C_t and Z_t (presumably single/double precision real and
// complex, see data_types.hpp). The member definitions in this file are only
// emitted for these types.
template class RkMatrix<S_t>;
template class RkMatrix<D_t>;
template class RkMatrix<C_t>;
template class RkMatrix<Z_t>;
1123
1124 } // end namespace hmat
1125