1 /*
2   HMat-OSS (HMatrix library, open source software)
3 
4   Copyright (C) 2014-2015 Airbus Group SAS
5 
6   This program is free software; you can redistribute it and/or
7   modify it under the terms of the GNU General Public License
8   as published by the Free Software Foundation; either version 2
9   of the License, or (at your option) any later version.
10 
11   This program is distributed in the hope that it will be useful,
12   but WITHOUT ANY WARRANTY; without even the implied warranty of
13   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14   GNU General Public License for more details.
15 
16   You should have received a copy of the GNU General Public License
17   along with this program; if not, write to the Free Software
18   Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
19 
20   http://github.com/jeromerobert/hmat-oss
21 */
22 
#include "rk_matrix.hpp"
#include "h_matrix.hpp"
#include "cluster_tree.hpp"
#include <cstring> // memcpy
#include <cfloat> // DBL_MAX
#include <vector>
#include "data_types.hpp"
#include "lapack_operations.hpp"
#include "blas_overloads.hpp"
#include "lapack_overloads.hpp"
#include "common/context.hpp"
#include "common/my_assert.h"
#include "common/timeline.hpp"
#include "lapack_exception.hpp"
36 
37 namespace hmat {
38 
/** RkApproximationControl
    Out-of-class definition of the static data member RkMatrix<T>::approx: a single
    RkApproximationControl object per scalar type T, shared by all RkMatrix<T> instances
    (presumably the global approximation settings — see RkApproximationControl). */
template<typename T> RkApproximationControl RkMatrix<T>::approx;
41 
42 /** RkMatrix */
RkMatrix(ScalarArray<T> * _a,const IndexSet * _rows,ScalarArray<T> * _b,const IndexSet * _cols)43 template<typename T> RkMatrix<T>::RkMatrix(ScalarArray<T>* _a, const IndexSet* _rows,
44                                            ScalarArray<T>* _b, const IndexSet* _cols)
45   : rows(_rows),
46     cols(_cols),
47     a(_a),
48     b(_b)
49 {
50 
51   // We make a special case for empty matrices.
52   if ((!a) && (!b)) {
53     return;
54   }
55   assert(a->rows == rows->size());
56   assert(b->rows == cols->size());
57 }
58 
~RkMatrix()59 template<typename T> RkMatrix<T>::~RkMatrix() {
60   clear();
61 }
62 
63 
evalArray(ScalarArray<T> * result) const64 template<typename T> ScalarArray<T>* RkMatrix<T>::evalArray(ScalarArray<T>* result) const {
65   if(result==NULL)
66     result = new ScalarArray<T>(rows->size(), cols->size());
67   if (rank())
68     result->gemm('N', 'T', Constants<T>::pone, a, b, Constants<T>::zero);
69   else
70     result->clear();
71   return result;
72 }
73 
eval() const74 template<typename T> FullMatrix<T>* RkMatrix<T>::eval() const {
75   FullMatrix<T>* result = new FullMatrix<T>(rows, cols, false);
76   evalArray(&result->data);
77   return result;
78 }
79 
80 // Compute squared Frobenius norm
normSqr() const81 template<typename T> double RkMatrix<T>::normSqr() const {
82   return a->norm_abt_Sqr(*b);
83 }
84 
scale(T alpha)85 template<typename T> void RkMatrix<T>::scale(T alpha) {
86   // We need just to scale the first matrix, A.
87   if (a) {
88     a->scale(alpha);
89   }
90 }
91 
transpose()92 template<typename T> void RkMatrix<T>::transpose() {
93   std::swap(a, b);
94   std::swap(rows, cols);
95 }
96 
clear()97 template<typename T> void RkMatrix<T>::clear() {
98   delete a;
99   delete b;
100   a = NULL;
101   b = NULL;
102 }
103 
104 template<typename T>
gemv(char trans,T alpha,const ScalarArray<T> * x,T beta,ScalarArray<T> * y,Side side) const105 void RkMatrix<T>::gemv(char trans, T alpha, const ScalarArray<T>* x, T beta, ScalarArray<T>* y, Side side) const {
106   if (rank() == 0) {
107     if (beta != Constants<T>::pone) {
108       y->scale(beta);
109     }
110     return;
111   }
112   if (side == Side::LEFT) {
113     if (trans == 'N') {
114       // Compute Y <- Y + alpha * A * B^T * X
115       ScalarArray<T> z(b->cols, x->cols);
116       z.gemm('T', 'N', Constants<T>::pone, b, x, Constants<T>::zero);
117       y->gemm('N', 'N', alpha, a, &z, beta);
118     } else if (trans == 'T') {
119       // Compute Y <- Y + alpha * B * A^T * X
120       ScalarArray<T> z(a->cols, x->cols);
121       z.gemm('T', 'N', Constants<T>::pone, a, x, Constants<T>::zero);
122       y->gemm('N', 'N', alpha, b, &z, beta);
123     } else {
124       assert(trans == 'C');
125       // Compute Y <- Y + alpha * (A*B^T)^H * X = Y + alpha * conj(B) * A^H * X
126       ScalarArray<T> z(a->cols, x->cols);
127       z.gemm('C', 'N', Constants<T>::pone, a, x, Constants<T>::zero);
128       ScalarArray<T> * newB = b->copy();
129       newB->conjugate();
130       y->gemm('N', 'N', alpha, newB, &z, beta);
131       delete newB;
132     }
133   } else {
134     if (trans == 'N') {
135       // Compute Y <- Y + alpha * X * A * B^T
136       ScalarArray<T> z(x->rows, a->cols);
137       z.gemm('N', 'N', Constants<T>::pone, x, a, Constants<T>::zero);
138       y->gemm('N', 'T', alpha, &z, b, beta);
139     } else if (trans == 'T') {
140       // Compute Y <- Y + alpha * X * B * A^T
141       ScalarArray<T> z(x->rows, b->cols);
142       z.gemm('N', 'N', Constants<T>::pone, x, b, Constants<T>::zero);
143       y->gemm('N', 'T', alpha, &z, a, beta);
144     } else {
145       assert(trans == 'C');
146       // Compute Y <- Y + alpha * X * (A*B^T)^H = Y + alpha * X * conj(B) * A^H
147       ScalarArray<T> * newB = b->copy();
148       newB->conjugate();
149       ScalarArray<T> z(x->rows, b->cols);
150       z.gemm('N', 'N', Constants<T>::pone, x, newB, Constants<T>::zero);
151       delete newB;
152       y->gemm('N', 'C', alpha, &z, a, beta);
153     }
154   }
155 }
156 
subset(const IndexSet * subRows,const IndexSet * subCols) const157 template<typename T> const RkMatrix<T>* RkMatrix<T>::subset(const IndexSet* subRows,
158                                                             const IndexSet* subCols) const {
159   assert(subRows->isSubset(*rows));
160   assert(subCols->isSubset(*cols));
161   ScalarArray<T>* subA = NULL;
162   ScalarArray<T>* subB = NULL;
163   if(rank() > 0) {
164     // The offset in the matrix, and not in all the indices
165     int rowsOffset = subRows->offset() - rows->offset();
166     int colsOffset = subCols->offset() - cols->offset();
167     subA = new ScalarArray<T>(*a, rowsOffset, subRows->size(), 0, rank());
168     subB = new ScalarArray<T>(*b, colsOffset, subCols->size(), 0, rank());
169   }
170   return new RkMatrix<T>(subA, subRows, subB, subCols);
171 }
172 
truncatedSubset(const IndexSet * subRows,const IndexSet * subCols,double epsilon) const173 template<typename T> RkMatrix<T>* RkMatrix<T>::truncatedSubset(const IndexSet* subRows,
174                                                                const IndexSet* subCols,
175                                                                double epsilon) const {
176   assert(subRows->isSubset(*rows));
177   assert(subCols->isSubset(*cols));
178   RkMatrix<T> * r = new RkMatrix<T>(NULL, subRows, NULL, subCols);
179   if(rank() > 0) {
180     r->a = ScalarArray<T>(*a, subRows->offset() - rows->offset(),
181                           subRows->size(), 0, rank()).copy();
182     r->b = ScalarArray<T>(*b, subCols->offset() - cols->offset(),
183                           subCols->size(), 0, rank()).copy();
184     if(epsilon >= 0)
185       r->truncate(epsilon);
186   }
187   return r;
188 }
189 
compressedSize()190 template<typename T> size_t RkMatrix<T>::compressedSize() {
191     return ((size_t)rows->size()) * rank() + ((size_t)cols->size()) * rank();
192 }
193 
uncompressedSize()194 template<typename T> size_t RkMatrix<T>::uncompressedSize() {
195     return ((size_t)rows->size()) * cols->size();
196 }
197 
addRand(double epsilon)198 template<typename T> void RkMatrix<T>::addRand(double epsilon) {
199   DECLARE_CONTEXT;
200   a->addRand(epsilon);
201   b->addRand(epsilon);
202   return;
203 }
204 
/**
 * @brief Rebuild the A or B panel of a RkMatrix after the SVD of the small core.
 * This is a utility function for RkMatrix<T>::truncate.
 *
 * On entry, ab holds the output of qrDecomposition() for the A or B panel (the caller
 * describes it as "Qa and tau_a", i.e. the implicit Q factor), and uv holds the U or V
 * factor of the truncated SVD of R_a R_b^t. The result is Q * uv, a newly allocated
 * indexSet->size() x newK array owned by the caller.
 *
 * @param ab The A or B block holding the QR output (rows for A, cols for B)
 * @param indexSet The index set of the A or B block (rows for A and cols for B);
 *        its size is the number of rows of the result
 * @param newK The rank to truncate to (output of the SVD); number of columns of the result
 * @param uv The U or V matrix (output of the SVD). Ownership is transferred: uv is
 *        deleted by this function, after its ortho flag is copied to the result.
 * @param useInitPivot whether qrDecomposition() may have used an initial pivot
 * @param initialPivot number of leading columns of ab that the QR treated as already
 *        orthogonal; only used when useInitPivot is true and initialPivot != 0
 * @return the new A or B block
 */
template <typename T>
ScalarArray<T> *truncatedAB(ScalarArray<T> *ab, const IndexSet *indexSet,
                            int newK, ScalarArray<T> *uv,
                            bool useInitPivot = false, int initialPivot = 0) {
  // We need to calculate Q * uv
  ScalarArray<T>* newAB = new ScalarArray<T>(indexSet->size(), newK);
  if (useInitPivot && initialPivot) {
    // With an initial pivot, the product by Q is computed in two parts:
    // - first the columns >= initialPivot of ab (the Householder part produced by
    //   qrDecomposition()) are applied with productQ(), overwriting newAB;
    // - then the first initialPivot columns of ab are used directly as columns of Q
    //   through a classical GEMM, whose result is added to newAB.

    // Subsets of ab (columns >= initialPivot) and uv (rows >= initialPivot).
    ScalarArray<T> sub_ab(*ab, 0, ab->rows, initialPivot, ab->cols-initialPivot);
    ScalarArray<T> sub_uv(* uv, initialPivot,  uv->rows-initialPivot, 0,  uv->cols);
    // Only the top rows of newAB are written here; the rest is assumed to be zero
    // from the ScalarArray constructor (NOTE(review): confirm zero-initialization).
    newAB->copyMatrixAtOffset(&sub_uv, 0, 0);
    // newAB <- Q_householder * newAB
    sub_ab.productQ('L', 'N', newAB);

    // Then add the contribution of the leading (already orthogonal) columns:
    // newAB <- newAB + ab(:, 0:initialPivot) * uv(0:initialPivot, :)
    ScalarArray<T> sub_ab2(*ab, 0, ab->rows, 0, initialPivot);
    ScalarArray<T> sub_uv2(* uv, 0, initialPivot, 0,  uv->cols);
    newAB->gemm('N', 'N', Constants<T>::pone, &sub_ab2, &sub_uv2, Constants<T>::pone);
  } else {
    // Without an initial pivot there is no GEMM, just a productQ()
    newAB->copyMatrixAtOffset( uv, 0, 0);
    // newAB <- Q * newAB
    ab->productQ('L', 'N', newAB);
  }

  // Propagate the orthogonality flag of the SVD factor, then release it (owned here).
  newAB->setOrtho( uv->getOrtho());
  delete  uv;
  return newAB;
}
247 
truncate(double epsilon,int initialPivotA,int initialPivotB)248 template<typename T> void RkMatrix<T>::truncate(double epsilon, int initialPivotA, int initialPivotB) {
249   DECLARE_CONTEXT;
250 
251   if (rank() == 0) {
252     assert(!(a || b));
253     return;
254   }
255 
256   assert(rows->size() >= rank());
257   // Case: more columns than one dimension of the matrix.
258   // In this case, the calculation of the SVD of the matrix "R_a R_b^t" is more
259   // expensive than computing the full SVD matrix. We make then a full matrix conversion,
260   // and compress it with RkMatrix::fromMatrix().
261   if (rank() > std::min(rows->size(), cols->size())) {
262     FullMatrix<T>* tmp = eval();
263     RkMatrix<T>* rk = truncatedSvd(tmp, epsilon); // TODO compress with something else than SVD (rank() can still be quite large) ?
264     delete tmp;
265     // "Move" rk into this, and delete the old "this".
266     swap(*rk);
267     delete rk;
268     return;
269   }
270 
271   static bool usedRecomp = getenv("HMAT_RECOMPRESS") && strcmp(getenv("HMAT_RECOMPRESS"), "MGS") == 0 ;
272   if (usedRecomp){
273     mGSTruncate(epsilon, initialPivotA, initialPivotB);
274     return;
275   }
276 
277   /* To recompress an Rk-matrix to Rk-matrix, we need :
278       - A = Q_a R_A (QR decomposition)
279       - B = Q_b R_b (QR decomposition)
280       - Calculate the SVD of R_a R_b^t  = U S V^t
281       - Make truncation U, S, and V in the same way as for
282       compression of a full rank matrix, ie:
283       - Restrict U to its newK first columns U_tilde
284       - Restrict S to its newK first values (diagonal matrix) S_tilde
285       - Restrict V to its newK first columns V_tilde
286       - Put A = Q_a U_tilde SQRT (S_tilde)
287       B = Q_b V_tilde SQRT(S_tilde)
288 
289      The sizes of the matrices are:
290       - Qa : rows x k
291       - Ra : k x k
292       - Qb : cols x k
293       - Rb : k x k
294      So:
295       - Ra Rb^t: k x k
296       - U  : k * k
297       - S  : k (diagonal)
298       - V^t: k * k
299      Hence:
300       - newA: rows x newK
301       - newB: cols x newK
302 
303   */
304 
305   // truncated SVD of Ra Rb^t (allows failure)
306   ScalarArray<T> *u = NULL, *v = NULL;
307   int newK;
308   // context block to release ra, rb, r ASAP
309   {
310     // QR decomposition of A and B
311     ScalarArray<T> ra(rank(), rank());
312     a->qrDecomposition(&ra, initialPivotA); // A contains Qa and tau_a
313     ScalarArray<T> rb(rank(), rank());
314     b->qrDecomposition(&rb, initialPivotB); // B contains Qb and tau_b
315 
316     // R <- Ra Rb^t
317     ScalarArray<T> r(rank(), rank());
318     r.gemm('N','T', Constants<T>::pone, &ra, &rb , Constants<T>::zero);
319 
320     // truncated SVD of Ra Rb^t (allows failure)
321     newK = r.truncatedSvdDecomposition(&u, &v, epsilon, true); // TODO use something else than SVD ?
322   }
323   if (newK == 0) {
324     clear();
325     return;
326   }
327   // We need to know if qrDecomposition has used initPivot...
328   // (Not so great, because HMAT_TRUNC_INITPIV is checked at 2 different locations)
329   static char *useInitPivot = getenv("HMAT_TRUNC_INITPIV");
330   ScalarArray<T>* newA = truncatedAB(a, rows, newK, u, useInitPivot, initialPivotA);
331   delete a;
332   a = newA;
333   ScalarArray<T>* newB = truncatedAB(b, cols, newK, v, useInitPivot, initialPivotB);
334   delete b;
335   b = newB;
336 }
337 
mGSTruncate(double epsilon,int initialPivotA,int initialPivotB)338 template<typename T> void RkMatrix<T>::mGSTruncate(double epsilon, int initialPivotA, int initialPivotB) {
339   DECLARE_CONTEXT;
340   if (rank() == 0) {
341     assert(!(a || b));
342     return;
343   }
344 
345   int kA, kB, newK;
346 
347   int krank = rank();
348 
349   // Gram-Schmidt on a
350   // On input, a0(m,k)
351   // On output, a(m,kA), ra(kA,k) such that a0 = a * ra
352   ScalarArray<T> ra(krank, krank);
353   kA = a->modifiedGramSchmidt( &ra, epsilon, initialPivotA);
354   if (kA==0) {
355     clear();
356     return;
357   }
358 
359   // Gram-Schmidt on b
360   // On input, b0(p,k)
361   // On output, b(p,kB), rb(kB,k) such that b0 = b * rb
362   ScalarArray<T> rb(krank, krank);
363   kB = b->modifiedGramSchmidt( &rb, epsilon, initialPivotB);
364   if (kB==0) {
365     clear();
366     return;
367   }
368 
369   // M = a0*b0^T = a*(ra*rb^T)*b^T
370   // We perform an SVD on ra*rb^T:
371   //  (ra*rb^T) = U*S*S*Vt
372   // and M = (a*U*S)*(S*Vt*b^T) = (a*U*S)*(b*(S*Vt)^T)^T
373   ScalarArray<T> matR(kA, kB);
374   matR.gemm('N','T', Constants<T>::pone, &ra, &rb , Constants<T>::zero);
375 
376   // truncatedSVD (allows failure)
377   ScalarArray<T>* ur = NULL;
378   ScalarArray<T>* vr = NULL;
379   newK = matR.truncatedSvdDecomposition(&ur, &vr, epsilon, true);
380   // On output, ur->rows = kA, vr->rows = kB
381 
382   if (newK == 0) {
383     clear();
384     return;
385   }
386 
387   /* Multiplication by orthogonal matrix Q: no or/un-mqr as
388     this comes from Gram-Schmidt procedure not Householder
389   */
390   ScalarArray<T> *newA = new ScalarArray<T>(a->rows, newK);
391   newA->gemm('N', 'N', Constants<T>::pone, a, ur, Constants<T>::zero);
392 
393   ScalarArray<T> *newB = new ScalarArray<T>(b->rows, newK);
394   newB->gemm('N', 'N', Constants<T>::pone, b, vr, Constants<T>::zero);
395 
396   newA->setOrtho(ur->getOrtho());
397   newB->setOrtho(vr->getOrtho());
398   delete ur;
399   delete vr;
400 
401   delete a;
402   a = newA;
403   delete b;
404   b = newB;
405 }
406 
407 // Swap members with members from another instance.
swap(RkMatrix<T> & other)408 template<typename T> void RkMatrix<T>::swap(RkMatrix<T>& other)
409 {
410   assert(*rows == *other.rows);
411   assert(*cols == *other.cols);
412   std::swap(a, other.a);
413   std::swap(b, other.b);
414 }
415 
axpy(double epsilon,T alpha,const FullMatrix<T> * mat)416 template<typename T> void RkMatrix<T>::axpy(double epsilon, T alpha, const FullMatrix<T>* mat) {
417   formattedAddParts(epsilon, &alpha, &mat, 1);
418 }
419 
axpy(double epsilon,T alpha,const RkMatrix<T> * mat)420 template<typename T> void RkMatrix<T>::axpy(double epsilon, T alpha, const RkMatrix<T>* mat) {
421   formattedAddParts(epsilon, &alpha, &mat, 1);
422 }
423 
/*! \brief Try to optimize the order of the Rk matrices to maximize initialPivot

  We re-order the Rk matrices in usedParts[] and the associated constants in usedAlpha[]
  in order to maximize the number of orthogonal columns starting at column 0 in the A and B
  panels (the columns that the QR factorization can exploit, see the initialPivotA/B
  arguments of RkMatrix::truncate()).

  Two candidate orderings are evaluated with the same gain estimate (rkA^2 + rkB^2), and an
  ordering only replaces the current one when its gain is strictly larger:
  - a single Rk matrix with orthogonal panels and maximum rank put in first position;
  - a pair (Rk1, Rk2) put in first and second position, whose orthogonal panels extend each
    other because their rows (for A) or cols (for B) do not intersect.

  \param[in] notNullParts number of elements in usedParts[] and usedAlpha[]; must be >= 1
             since usedParts[0] is dereferenced unconditionally
  \param[in,out] usedParts Array of Rk matrices
  \param[in,out] usedAlpha Array of constants (permuted together with usedParts[])
  \param[out] initialPivotA Number of orthogonal columns starting at column 0 in panel A
  \param[out] initialPivotB Number of orthogonal columns starting at column 0 in panel B
*/
template<typename T>
static void optimizeRkArray(int notNullParts, const RkMatrix<T>** usedParts, T *usedAlpha, int &initialPivotA, int &initialPivotB){
  // 1st optim: Put in first position the Rk matrix with orthogonal panels AND maximum rank
  int bestRk=-1, bestGain=-1;
  for (int i=0 ; i<notNullParts ; i++) {
    // Roughly, the gain from an initial pivot 'p' in a QR factorisation 'm x n' is to reduce the flops
    // from 2mn^2 to 2m(n^2-p^2), so the gain grows like p^2 for each panel
    // hence the gain formula : number of orthogonal panels x rank^2
    int gain = (usedParts[i]->a->getOrtho() + usedParts[i]->b->getOrtho())*usedParts[i]->rank()*usedParts[i]->rank();
    if (gain > bestGain) {
      bestGain = gain;
      bestRk = i;
    }
  }
  if (bestRk > 0) {
    std::swap(usedParts[0], usedParts[bestRk]) ;
    std::swap(usedAlpha[0], usedAlpha[bestRk]) ;
  }
  initialPivotA = usedParts[0]->a->getOrtho() ? usedParts[0]->rank() : 0;
  initialPivotB = usedParts[0]->b->getOrtho() ? usedParts[0]->rank() : 0;

  // 2nd optim:
  // When coalescing Rk matrices from children toward their parent, it is possible to "merge" Rk from
  // (extra-)diagonal children: with non-intersecting rows (resp. cols), the orthogonality extends
  // across the separate Rk panels. The gain is compared against bestGain from the 1st optim.
  int best_i1=-1, best_i2=-1, best_rkA=-1, best_rkB=-1;
  for (int i1=0 ; i1<notNullParts ; i1++)
    for (int i2=0 ; i2<notNullParts ; i2++)
      if (i1 != i2) {
        const RkMatrix<T>* Rk1 = usedParts[i1];
        const RkMatrix<T>* Rk2 = usedParts[i2];
        // compute the gain expected from putting Rk1-Rk2 in first position
        // Orthogonality of Rk2->a is useful only if Rk1->a is ortho AND rows don't intersect (cols for panel b)
        int rkA = Rk1->a->getOrtho() ? Rk1->rank() + (Rk2->a->getOrtho() && !Rk1->rows->intersects(*Rk2->rows) ? Rk2->rank() : 0) : 0;
        int rkB = Rk1->b->getOrtho() ? Rk1->rank() + (Rk2->b->getOrtho() && !Rk1->cols->intersects(*Rk2->cols) ? Rk2->rank() : 0) : 0;
        int gain = rkA*rkA + rkB*rkB ;
        if (gain > bestGain) {
          bestGain = gain;
          best_i1 = i1;
          best_i2 = i2;
          best_rkA = rkA;
          best_rkB = rkB;
        }
      }

  if (best_i1 >= 0) {
    // put i1 in first position, i2 in second
    std::swap(usedParts[0], usedParts[best_i1]) ;
    std::swap(usedAlpha[0], usedAlpha[best_i1]) ;
    if (best_i2==0) best_i2 = best_i1; // handles the case where best_i2 was usedParts[0] which has just been moved
    std::swap(usedParts[1], usedParts[best_i2]) ;
    std::swap(usedAlpha[1], usedAlpha[best_i2]) ;
    initialPivotA = best_rkA;
    initialPivotB = best_rkB;
  }

}
491 
allSame(const RkMatrix<T> ** rks,int n)492 template<typename T> bool allSame(const RkMatrix<T>** rks, int n) {
493   for(int i = 1; i < n; i++) {
494     if(!(*rks[0]->rows == *rks[i]->rows) || !(*rks[0]->cols == *rks[i]->cols))
495       return false;
496   }
497   return true;
498 }
499 
500 template<typename T>
formattedAddParts(double epsilon,const T * alpha,const RkMatrix<T> * const * parts,const int n,bool hook)501 void RkMatrix<T>::formattedAddParts(double epsilon, const T* alpha, const RkMatrix<T>* const * parts,
502                                     const int n, bool hook) {
503   if(hook && formatedAddPartsHook && formatedAddPartsHook(this, epsilon, alpha, parts, n))
504     return;
505   // TODO check if formattedAddParts() actually uses sometimes this 'alpha' parameter (or is it always 1 ?)
506   DECLARE_CONTEXT;
507 
508   /* List of non-null and non-empty Rk matrices to coalesce, and the corresponding scaling coefficients */
509   const RkMatrix<T>* usedParts[n+1];
510   T usedAlpha[n+1];
511   /* Number of elements in usedParts[] */
512   int notNullParts = 0;
513   /* Sum of the ranks */
514   int rankTotal = 0;
515 
516   // If needed, put 'this' in first position in usedParts[]
517   if (rank()) {
518     usedAlpha[0] = Constants<T>::pone ;
519     usedParts[notNullParts++] = this ;
520     rankTotal += rank();
521   }
522 
523   for (int i = 0; i < n; i++) {
524     // exclude the NULL and 0-rank matrices
525     if (!parts[i] || parts[i]->rank() == 0 || parts[i]->rows->size() == 0 || parts[i]->cols->size() == 0 || alpha[i]==Constants<T>::zero)
526       continue;
527     // Check that partial RkMatrix indices are subsets of their global indices set.
528     assert(parts[i]->rows->isSubset(*rows));
529     assert(parts[i]->cols->isSubset(*cols));
530     // Add this Rk to the list
531     rankTotal += parts[i]->rank();
532     usedAlpha[notNullParts] = alpha[i] ;
533     usedParts[notNullParts] = parts[i] ;
534     notNullParts++;
535   }
536 
537   if(notNullParts == 0)
538     return;
539 
540   // In case the sum of the ranks of the sub-matrices is greater than
541   // the matrix size, it is more efficient to put everything in a
542   // full matrix.
543   if (rankTotal >= std::min(rows->size(), cols->size())) {
544     const FullMatrix<T>** fullParts = new const FullMatrix<T>*[notNullParts];
545     fullParts[0] = NULL ;
546     for (int i = rank() ? 1 : 0 ; i < notNullParts; i++) // exclude usedParts[0] if it is 'this'
547       fullParts[i] = usedParts[i]->eval();
548     formattedAddParts(std::abs(epsilon), usedAlpha, fullParts, notNullParts);
549     for (int i = 0; i < notNullParts; i++)
550       delete fullParts[i];
551     delete[] fullParts;
552     return;
553   }
554 
555   // Find if the QR factorisation can be accelerated using orthogonality information
556   int initialPivotA = usedParts[0]->a->getOrtho() ? usedParts[0]->rank() : 0;
557   int initialPivotB = usedParts[0]->b->getOrtho() ? usedParts[0]->rank() : 0;
558 
559   // Try to optimize the order of the Rk matrix to maximize initialPivot
560   static char *useBestRk = getenv("HMAT_MGS_BESTRK");
561   if (useBestRk)
562     optimizeRkArray(notNullParts, usedParts, usedAlpha, initialPivotA, initialPivotB);
563 
564   // According to the indices organization, the sub-matrices are
565   // contiguous blocks in the "big" matrix whose columns offset is
566   //      kOffset = usedParts[0]->k + ... + usedParts[i-1]->k
567   // rows offset is
568   //      usedParts[i]->rows->offset - rows->offset
569   // rows size
570   //      usedParts[i]->rows->size x usedParts[i]->k (rows x columns)
571   // Same for columns.
572 
573   // when possible realloc this a & b arrays to limit memory usage and avoid a copy
574   bool useRealloc = usedParts[0] == this && allSame(usedParts, notNullParts);
575   // concatenate a(i) then b(i) to limite memory usage
576   ScalarArray<T>* resultA, *resultB;
577   int rankOffset;
578   if(useRealloc) {
579     resultA = a;
580     rankOffset = a->cols;
581     a->resize(rankTotal);
582   }
583   else {
584     rankOffset = 0;
585     resultA = new ScalarArray<T>(rows->size(), rankTotal);
586   }
587 
588   for (int i = useRealloc ? 1 : 0; i < notNullParts; i++) {
589     // Copy 'a' at position rowOffset, kOffset
590     int rowOffset = usedParts[i]->rows->offset() - rows->offset();
591     resultA->copyMatrixAtOffset(usedParts[i]->a, rowOffset, rankOffset);
592     // Scaling the matrix already in place inside resultA
593     if (usedAlpha[i] != Constants<T>::pone) {
594       ScalarArray<T> tmp(*resultA, rowOffset, usedParts[i]->a->rows, rankOffset, usedParts[i]->a->cols);
595       tmp.scale(usedAlpha[i]);
596     }
597     // Update the rank offset
598     rankOffset += usedParts[i]->rank();
599   }
600   assert(rankOffset==rankTotal);
601 
602   if(!useRealloc && a != NULL)
603     delete a;
604   a = resultA;
605 
606   if(useRealloc) {
607     resultB = b;
608     rankOffset = b->cols;
609     b->resize(rankTotal);
610   }
611   else {
612     rankOffset = 0;
613     resultB = new ScalarArray<T>(cols->size(), rankTotal);
614   }
615 
616   for (int i = useRealloc ? 1 : 0; i < notNullParts; i++) {
617     // Copy 'b' at position colOffset, kOffset
618     int colOffset = usedParts[i]->cols->offset() - cols->offset();
619     resultB->copyMatrixAtOffset(usedParts[i]->b, colOffset, rankOffset);
620     // Update the rank offset
621     rankOffset += usedParts[i]->b->cols;
622   }
623 
624   if(!useRealloc && b != NULL)
625     delete b;
626   b = resultB;
627 
628   assert(rankOffset==rankTotal);
629   // If only one of the parts is non-zero, then the recompression is not necessary
630   if (notNullParts > 1 && epsilon >= 0)
631     truncate(epsilon, initialPivotA, initialPivotB);
632 }
633 
634 template<typename T>
formattedAddParts(double epsilon,const T * alpha,const FullMatrix<T> * const * parts,int n)635 void RkMatrix<T>::formattedAddParts(double epsilon, const T* alpha, const FullMatrix<T>* const * parts, int n) {
636   DECLARE_CONTEXT;
637   FullMatrix<T>* me = eval();
638   HMAT_ASSERT(me);
639 
640   // TODO: here, we convert Rk->Full, Update the Full with parts[], and Full->Rk. We could also
641   // create a new empty Full, update, convert to Rk and add it to 'this'.
642   // If the parts[] are smaller than 'this', convert them to Rk and add them could be less expensive
643   for (int i = 0; i < n; i++) {
644     if (!parts[i])
645       continue;
646     const IndexSet *rows_full = parts[i]->rows_;
647     const IndexSet *cols_full = parts[i]->cols_;
648     assert(rows_full->isSubset(*rows));
649     assert(cols_full->isSubset(*cols));
650     int rowOffset = rows_full->offset() - rows->offset();
651     int colOffset = cols_full->offset() - cols->offset();
652     int maxCol = cols_full->size();
653     int maxRow = rows_full->size();
654     ScalarArray<T> sub(me->data, rowOffset, maxRow, colOffset, maxCol);
655     sub.axpy(alpha[i], &parts[i]->data);
656   }
657   RkMatrix<T>* result = truncatedSvd(me, epsilon); // TODO compress with something else than SVD
658   delete me;
659   swap(*result);
660   delete result;
661 }
662 
663 
multiplyRkFull(char transR,char transM,const RkMatrix<T> * rk,const FullMatrix<T> * m)664 template<typename T> RkMatrix<T>* RkMatrix<T>::multiplyRkFull(char transR, char transM,
665                                                               const RkMatrix<T>* rk,
666                                                               const FullMatrix<T>* m) {
667   DECLARE_CONTEXT;
668 
669   assert(((transR == 'N') ? rk->cols->size() : rk->rows->size()) == ((transM == 'N') ? m->rows() : m->cols()));
670   const IndexSet *rkRows = ((transR == 'N')? rk->rows : rk->cols);
671   const IndexSet *mCols = ((transM == 'N')? m->cols_ : m->rows_);
672 
673   if(rk->rank() == 0) {
674       return new RkMatrix<T>(NULL, rkRows, NULL, mCols);
675   }
676   // If transM is 'N' and transR is 'N', we compute
677   //  A * B^T * M ==> newA = A, newB = M^T * B
678   // We can deduce all other cases from this one:
679   //   * if transR is 'T', all we have to do is to swap A and B
680   //   * if transR is 'C', we swap A and B, but they must also
681   //     be conjugate; let us look at the different cases:
682   //     + if transM is 'N', newB = M^T * conj(B) = conj(M^H * B)
683   //     + if transM is 'T', newB = M * conj(B)
684   //     + if transM is 'C', newB = conj(M) * conj(B) = conj(M * B)
685 
686   ScalarArray<T> *newA, *newB;
687   ScalarArray<T>* ra = rk->a;
688   ScalarArray<T>* rb = rk->b;
689   if (transR != 'N') {
690     // if transR == 'T', we permute ra and rb; if transR == 'C', they will
691     // also have to be conjugated, but this cannot be done here because rk
692     // is const, this will be performed below.
693     std::swap(ra, rb);
694   }
695   newA = ra->copy();
696   newB = new ScalarArray<T>(transM == 'N' ? m->cols() : m->rows(), rb->cols);
697   if (transR == 'C') {
698     newA->conjugate();
699     if (transM == 'N') {
700       newB->gemm('C', 'N', Constants<T>::pone, &m->data, rb, Constants<T>::zero);
701       newB->conjugate();
702     } else if (transM == 'T') {
703       ScalarArray<T> *conjB = rb->copy();
704       conjB->conjugate();
705       newB->gemm('N', 'N', Constants<T>::pone, &m->data, conjB, Constants<T>::zero);
706       delete conjB;
707     } else {
708       assert(transM == 'C');
709       newB->gemm('N', 'N', Constants<T>::pone, &m->data, rb, Constants<T>::zero);
710       newB->conjugate();
711     }
712   } else {
713     if (transM == 'N') {
714       newB->gemm('T', 'N', Constants<T>::pone, &m->data, rb, Constants<T>::zero);
715     } else if (transM == 'T') {
716       newB->gemm('N', 'N', Constants<T>::pone, &m->data, rb, Constants<T>::zero);
717     } else {
718       assert(transM == 'C');
719       ScalarArray<T> *conjB = rb->copy();
720       conjB->conjugate();
721       newB->gemm('N', 'N', Constants<T>::pone, &m->data, conjB, Constants<T>::zero);
722       newB->conjugate();
723       delete conjB;
724     }
725   }
726   RkMatrix<T>* result = new RkMatrix<T>(newA, rkRows, newB, mCols);
727   return result;
728 }
729 
730 template<typename T>
multiplyFullRk(char transM,char transR,const FullMatrix<T> * m,const RkMatrix<T> * rk)731 RkMatrix<T>* RkMatrix<T>::multiplyFullRk(char transM, char transR,
732                                          const FullMatrix<T>* m,
733                                          const RkMatrix<T>* rk) {
734   DECLARE_CONTEXT;
735   // If transM is 'N' and transR is 'N', we compute
736   //  M * A * B^T  ==> newA = M * A, newB = B
737   // We can deduce all other cases from this one:
738   //   * if transR is 'T', all we have to do is to swap A and B
739   //   * if transR is 'C', we swap A and B, but they must also
740   //     be conjugate; let us look at the different cases:
741   //     + if transM is 'N', newA = M * conj(A)
742   //     + if transM is 'T', newA = M^T * conj(A) = conj(M^H * A)
743   //     + if transM is 'C', newA = M^H * conj(A) = conj(M^T * A)
744   ScalarArray<T> *newA, *newB;
745   ScalarArray<T>* ra = rk->a;
746   ScalarArray<T>* rb = rk->b;
747   if (transR != 'N') { // permutation to transpose the matrix Rk
748     std::swap(ra, rb);
749   }
750   const IndexSet *rkCols = ((transR == 'N')? rk->cols : rk->rows);
751   const IndexSet *mRows = ((transM == 'N')? m->rows_ : m->cols_);
752 
753   newA = new ScalarArray<T>(mRows->size(), rb->cols);
754   newB = rb->copy();
755   if (transR == 'C') {
756     newB->conjugate();
757     if (transM == 'N') {
758       ScalarArray<T> *conjA = ra->copy();
759       conjA->conjugate();
760       newA->gemm('N', 'N', Constants<T>::pone, &m->data, conjA, Constants<T>::zero);
761       delete conjA;
762     } else if (transM == 'T') {
763       newA->gemm('C', 'N', Constants<T>::pone, &m->data, ra, Constants<T>::zero);
764       newA->conjugate();
765     } else {
766       assert(transM == 'C');
767       newA->gemm('T', 'N', Constants<T>::pone, &m->data, ra, Constants<T>::zero);
768       newA->conjugate();
769     }
770   } else {
771     newA->gemm(transM, 'N', Constants<T>::pone, &m->data, ra, Constants<T>::zero);
772   }
773   RkMatrix<T>* result = new RkMatrix<T>(newA, mRows, newB, rkCols);
774   return result;
775 }
776 
777 template<typename T>
multiplyRkH(char transR,char transH,const RkMatrix<T> * rk,const HMatrix<T> * h)778 RkMatrix<T>* RkMatrix<T>::multiplyRkH(char transR, char transH,
779                                       const RkMatrix<T>* rk, const HMatrix<T>* h) {
780   DECLARE_CONTEXT;
781   assert(((transR == 'N') ? *rk->cols : *rk->rows) == ((transH == 'N')? *h->rows() : *h->cols()));
782 
783   const IndexSet* rkRows = ((transR == 'N')? rk->rows : rk->cols);
784 
785   // If transR == 'N'
786   //    transM == 'N': (A*B^T)*M = A*(M^T*B)^T
787   //    transM == 'T': (A*B^T)*M^T = A*(M*B)^T
788   //    transM == 'C': (A*B^T)*M^H = A*(conj(M)*B)^T = A*conj(M*conj(B))^T
789   // If transR == 'T', we only have to swap A and B
790   // If transR == 'C', we swap A and B, then
791   //    transM == 'N': R^H*M = conj(A)*(M^T*conj(B))^T = conj(A)*conj(M^H*B)^T
792   //    transM == 'T': R^H*M^T = conj(A)*(M*conj(B))^T
793   //    transM == 'C': R^H*M^H = conj(A)*conj(M*B)^T
794   //
795   // Size of the HMatrix is n x m,
796   // So H^t size is m x n and the product is m x cols(B)
797   // and the number of columns of B is k.
798   ScalarArray<T> *newA, *newB;
799   ScalarArray<T>* ra = rk->a;
800   ScalarArray<T>* rb = rk->b;
801   if (transR != 'N') { // permutation to transpose the matrix Rk
802     std::swap(ra, rb);
803   }
804 
805   const IndexSet *newCols = ((transH == 'N' )? h->cols() : h->rows());
806 
807   newA = ra->copy();
808   newB = new ScalarArray<T>(transH == 'N' ? h->cols()->size() : h->rows()->size(), rb->cols);
809   if (transR == 'C') {
810     newA->conjugate();
811     if (transH == 'N') {
812       h->gemv('C', Constants<T>::pone, rb, Constants<T>::zero, newB);
813       newB->conjugate();
814     } else if (transH == 'T') {
815       ScalarArray<T> *conjB = rb->copy();
816       conjB->conjugate();
817       h->gemv('N', Constants<T>::pone, conjB, Constants<T>::zero, newB);
818       delete conjB;
819     } else {
820       assert(transH == 'C');
821       h->gemv('N', Constants<T>::pone, rb, Constants<T>::zero, newB);
822       newB->conjugate();
823     }
824   } else {
825     if (transH == 'N') {
826       h->gemv('T', Constants<T>::pone, rb, Constants<T>::zero, newB);
827     } else if (transH == 'T') {
828       h->gemv('N', Constants<T>::pone, rb, Constants<T>::zero, newB);
829     } else {
830       assert(transH == 'C');
831       ScalarArray<T> *conjB = rb->copy();
832       conjB->conjugate();
833       h->gemv('N', Constants<T>::pone, conjB, Constants<T>::zero, newB);
834       delete conjB;
835       newB->conjugate();
836     }
837   }
838   RkMatrix<T>* result = new RkMatrix<T>(newA, rkRows, newB, newCols);
839   return result;
840 }
841 
842 template<typename T>
multiplyHRk(char transH,char transR,const HMatrix<T> * h,const RkMatrix<T> * rk)843 RkMatrix<T>* RkMatrix<T>::multiplyHRk(char transH, char transR,
844                                       const HMatrix<T>* h, const RkMatrix<T>* rk) {
845 
846   DECLARE_CONTEXT;
847   if (rk->rank() == 0) {
848     const IndexSet* newRows = ((transH == 'N') ? h-> rows() : h->cols());
849     const IndexSet* newCols = ((transR == 'N') ? rk->cols : rk->rows);
850     return new RkMatrix<T>(NULL, newRows, NULL, newCols);
851   }
852 
853   // If transH is 'N' and transR is 'N', we compute
854   //  M * A * B^T  ==> newA = M * A, newB = B
855   // We can deduce all other cases from this one:
856   //   * if transR is 'T', all we have to do is to swap A and B
857   //   * if transR is 'C', we swap A and B, but they must also
858   //     be conjugate; let us look at the different cases:
859   //     + if transH is 'N', newA = M * conj(A)
860   //     + if transH is 'T', newA = M^T * conj(A) = conj(M^H * A)
861   //     + if transH is 'C', newA = M^H * conj(A) = conj(M^T * A)
862   ScalarArray<T> *newA, *newB;
863   ScalarArray<T>* ra = rk->a;
864   ScalarArray<T>* rb = rk->b;
865   if (transR != 'N') { // permutation to transpose the matrix Rk
866     std::swap(ra, rb);
867   }
868   const IndexSet *rkCols = ((transR == 'N')? rk->cols : rk->rows);
869   const IndexSet* newRows = ((transH == 'N')? h-> rows() : h->cols());
870 
871   newA = new ScalarArray<T>(transH == 'N' ? h->rows()->size() : h->cols()->size(), rb->cols);
872   newB = rb->copy();
873   if (transR == 'C') {
874     newB->conjugate();
875     if (transH == 'N') {
876       ScalarArray<T> *conjA = ra->copy();
877       conjA->conjugate();
878       h->gemv('N', Constants<T>::pone, conjA, Constants<T>::zero, newA);
879       delete conjA;
880     } else if (transH == 'T') {
881       h->gemv('C', Constants<T>::pone, ra, Constants<T>::zero, newA);
882       newA->conjugate();
883     } else {
884       assert(transH == 'C');
885       h->gemv('T', Constants<T>::pone, ra, Constants<T>::zero, newA);
886       newA->conjugate();
887     }
888   } else {
889     h->gemv(transH, Constants<T>::pone, ra, Constants<T>::zero, newA);
890   }
891   RkMatrix<T>* result = new RkMatrix<T>(newA, newRows, newB, rkCols);
892   return result;
893 }
894 
895 template<typename T>
multiplyRkRk(char trans1,char trans2,const RkMatrix<T> * r1,const RkMatrix<T> * r2,double epsilon)896 RkMatrix<T>* RkMatrix<T>::multiplyRkRk(char trans1, char trans2,
897                                        const RkMatrix<T>* r1, const RkMatrix<T>* r2, double epsilon) {
898   DECLARE_CONTEXT;
899   assert(((trans1 == 'N') ? *r1->cols : *r1->rows) == ((trans2 == 'N') ? *r2->rows : *r2->cols));
900   // It is possible to do the computation differently, yielding a
901   // different rank and a different amount of computation.
902   // TODO: choose the best order.
903   ScalarArray<T>* a1 = (trans1 == 'N' ? r1->a : r1->b);
904   ScalarArray<T>* b1 = (trans1 == 'N' ? r1->b : r1->a);
905   ScalarArray<T>* a2 = (trans2 == 'N' ? r2->a : r2->b);
906   ScalarArray<T>* b2 = (trans2 == 'N' ? r2->b : r2->a);
907 
908   assert(b1->rows == a2->rows); // compatibility of the multiplication
909 
910   // We want to compute the matrix a1.t^b1.a2.t^b2 and return an Rk matrix
911   // Usually, the best way is to start with tmp=t^b1.a2 which produces a 'small' matrix rank1 x rank2
912   //
913   // OLD version (default):
914   // Then we can either :
915   // - compute a1.tmp : the cost is rank1.rank2.row_a, the resulting Rk has rank rank2
916   // - compute tmp.t^b2 : the cost is rank1.rank2.col_b, the resulting Rk has rank rank1
917   // We use the solution which gives the lowest resulting rank.
918   // With this version, orthogonality is lost on one panel, it is preserved on the other.
919   //
920   // NEW version :
921   // Other solution: once we have the small matrix tmp=t^b1.a2, we can do a recompression on it for low cost
922   // using SVD + truncation. This also removes the choice above, since tmp=U.S.V is then applied on both sides
923   // This version is default, it can be deactivated by setting env. var. HMAT_OLD_RKRK
924   // With this version, orthogonality is lost on both panel.
925 
926   ScalarArray<T> tmp(r1->rank(), r2->rank(), false);
927   if (trans1 == 'C' && trans2 == 'C') {
928     tmp.gemm('T', 'N', Constants<T>::pone, b1, a2, Constants<T>::zero);
929     tmp.conjugate();
930   } else if (trans1 == 'C') {
931     tmp.gemm('C', 'N', Constants<T>::pone, b1, a2, Constants<T>::zero);
932   } else if (trans2 == 'C') {
933     tmp.gemm('C', 'N', Constants<T>::pone, b1, a2, Constants<T>::zero);
934     tmp.conjugate();
935   } else {
936     tmp.gemm('T', 'N', Constants<T>::pone, b1, a2, Constants<T>::zero);
937   }
938 
939   ScalarArray<T> *newA=NULL, *newB=NULL;
940   static char *oldRKRK = getenv("HMAT_OLD_RKRK"); // Option to use the OLD version, without SVD-in-the-middle
941   if (!oldRKRK) {
942     // NEW version
943     ScalarArray<T>* ur = NULL;
944     ScalarArray<T>* vr = NULL;
945     // truncated SVD tmp = ur.t^vr
946     int newK = tmp.truncatedSvdDecomposition(&ur, &vr, epsilon, true);
947     if (newK > 0) {
948       /* Now compute newA = a1.ur and newB = b2.vr */
949       newA = new ScalarArray<T>(a1->rows, newK, false);
950       if (trans1 == 'C') ur->conjugate();
951       newA->gemm('N', 'N', Constants<T>::pone, a1, ur, Constants<T>::zero);
952       if (trans1 == 'C') newA->conjugate();
953       newB = new ScalarArray<T>(b2->rows, newK, false);
954       if (trans2 == 'C') vr->conjugate();
955       newB->gemm('N', 'N', Constants<T>::pone, b2, vr, Constants<T>::zero);
956       if (trans2 == 'C') newB->conjugate();
957       delete ur;
958       delete vr;
959     }
960   } else {
961     // OLD version
962     if (r1->rank() < r2->rank()) {
963       // newA = a1, newB = b2.t^tmp
964       newA = a1->copy();
965       if (trans1 == 'C') newA->conjugate();
966       newB = new ScalarArray<T>(b2->rows, r1->rank());
967       if (trans2 == 'C') {
968         newB->gemm('N', 'C', Constants<T>::pone, b2, &tmp, Constants<T>::zero);
969         newB->conjugate();
970       } else {
971         newB->gemm('N', 'T', Constants<T>::pone, b2, &tmp, Constants<T>::zero);
972       }
973     } else { // newA = a1.tmp, newB = b2
974       newA = new ScalarArray<T>(a1->rows, r2->rank());
975       if (trans1 == 'C') tmp.conjugate(); // be careful if you re-use tmp after this...
976       newA->gemm('N', 'N', Constants<T>::pone, a1, &tmp, Constants<T>::zero);
977       if (trans1 == 'C') newA->conjugate();
978       newB = b2->copy();
979       if (trans2 == 'C') newB->conjugate();
980     }
981   }
982   return new RkMatrix<T>(newA, ((trans1 == 'N') ? r1->rows : r1->cols), newB, ((trans2 == 'N') ? r2->cols : r2->rows));
983 }
984 
985 template<typename T>
computeRkRkMemorySize(char trans1,char trans2,const RkMatrix<T> * r1,const RkMatrix<T> * r2)986 size_t RkMatrix<T>::computeRkRkMemorySize(char trans1, char trans2,
987                                                 const RkMatrix<T>* r1, const RkMatrix<T>* r2)
988 {
989     ScalarArray<T>* b2 = (trans2 == 'N' ? r2->b : r2->a);
990     ScalarArray<T>* a1 = (trans1 == 'N' ? r1->a : r1->b);
991     return b2 == NULL ? 0 : b2->memorySize() +
992            a1 == NULL ? 0 : a1->rows * r2->rank() * sizeof(T);
993 }
994 
995 template<typename T>
multiplyWithDiagOrDiagInv(const HMatrix<T> * d,bool inverse,Side side)996 void RkMatrix<T>::multiplyWithDiagOrDiagInv(const HMatrix<T> * d, bool inverse, Side side) {
997   assert(*d->rows() == *d->cols());
998   assert(side == Side::RIGHT || (*rows == *d->cols()));
999   assert(side == Side::LEFT  || (*cols == *d->rows()));
1000 
1001   // extracting the diagonal
1002   Vector<T>* diag = new Vector<T>(d->cols()->size());
1003   d->extractDiagonal(diag->ptr());
1004 
1005   // left multiplication by d of b (if M<-M*D : side = RIGHT) or a (if M<-D*M : side = LEFT)
1006   ScalarArray<T>* aOrB = (side == Side::LEFT ? a : b);
1007   aOrB->multiplyWithDiagOrDiagInv(diag, inverse, Side::LEFT);
1008 
1009   delete diag;
1010 }
1011 
gemmRk(double epsilon,char transHA,char transHB,T alpha,const HMatrix<T> * ha,const HMatrix<T> * hb)1012 template<typename T> void RkMatrix<T>::gemmRk(double epsilon, char transHA, char transHB,
1013                                               T alpha, const HMatrix<T>* ha, const HMatrix<T>* hb) {
1014   DECLARE_CONTEXT;
1015   if (!ha->isLeaf() && !hb->isLeaf()) {
1016     // Recursion case
1017     int nbRows = transHA == 'N' ? ha->nrChildRow() : ha->nrChildCol() ; /* Row blocks of the product */
1018     int nbCols = transHB == 'N' ? hb->nrChildCol() : hb->nrChildRow() ; /* Col blocks of the product */
1019     int nbCom  = transHA == 'N' ? ha->nrChildCol() : ha->nrChildRow() ; /* Common dimension between A and B */
1020     int nSubRks = nbRows * nbCols;
1021     RkMatrix<T>* subRks[nSubRks];
1022     std::fill_n(subRks, nSubRks, nullptr);
1023     for (int i = 0; i < nbRows; i++) {
1024       for (int j = 0; j < nbCols; j++) {
1025         int p = i + j * nbRows;
1026         for (int k = 0; k < nbCom; k++) {
1027           // C_ij = A_ik * B_kj
1028           HMatrix<T>* a_ik = transHA == 'N' ? ha->get(i, k) : ha->get(k, i);
1029           HMatrix<T>* b_kj = transHB == 'N' ? hb->get(k, j) : hb->get(j, k);
1030           if (a_ik && b_kj) {
1031             if (subRks[p] == nullptr) {
1032               const IndexSet* subRows = transHA == 'N' ? a_ik->rows() : a_ik->cols();
1033               const IndexSet* subCols = transHB == 'N' ? b_kj->cols() : b_kj->rows();
1034               subRks[p] = new RkMatrix<T>(nullptr, subRows, nullptr, subCols);
1035             }
1036             subRks[p]->gemmRk(epsilon, transHA, transHB, alpha, a_ik, b_kj);
1037           }
1038         } // k loop
1039       } // j loop
1040     } // i loop
1041     // Reconstruction of C by adding the parts
1042     T alphaV[nSubRks];
1043     std::fill_n(alphaV, nSubRks, 1);
1044     formattedAddParts(epsilon, alphaV, subRks, nSubRks);
1045     for (int i = 0; i < nSubRks; i++) {
1046       delete subRks[i];
1047     }
1048   } else {
1049     RkMatrix<T>* rk = nullptr;
1050     // One of the product matrix is a leaf
1051     if ((ha->isLeaf() && ha->isNull()) || (hb->isLeaf() && hb->isNull())) {
1052       // Nothing to do
1053     } else if (ha->isRkMatrix() || hb->isRkMatrix()) {
1054       rk = HMatrix<T>::multiplyRkMatrix(epsilon, transHA, transHB, ha, hb);
1055     } else {
1056       assert(ha->isFullMatrix() || hb->isFullMatrix());
1057       FullMatrix<T>* fullMat = HMatrix<T>::multiplyFullMatrix(transHA, transHB, ha, hb);
1058       if(fullMat) {
1059         rk = truncatedSvd(fullMat, epsilon); // TODO compress with something else than SVD
1060         delete fullMat;
1061       }
1062     }
1063     if(rk) {
1064       if(rank() == 0) {
1065         // save memory by not allocating a temporary Rk
1066         rk->scale(alpha);
1067         swap(*rk);
1068       } else
1069         axpy(epsilon, alpha, rk);
1070       delete rk;
1071     }
1072   }
1073 }
1074 
copy(const RkMatrix<T> * o)1075 template<typename T> void RkMatrix<T>::copy(const RkMatrix<T>* o) {
1076   delete a;
1077   delete b;
1078   rows = o->rows;
1079   cols = o->cols;
1080   a = (o->a ? o->a->copy() : NULL);
1081   b = (o->b ? o->b->copy() : NULL);
1082 }
1083 
copy() const1084 template<typename T> RkMatrix<T>* RkMatrix<T>::copy() const {
1085   RkMatrix<T> *result = new RkMatrix<T>(NULL, rows, NULL, cols);
1086   result->copy(this);
1087   return result;
1088 }
1089 
1090 
checkNan() const1091 template<typename T> void RkMatrix<T>::checkNan() const {
1092   if (rank() == 0) {
1093     return;
1094   }
1095   a->checkNan();
1096   b->checkNan();
1097 }
1098 
conjugate()1099 template<typename T> void RkMatrix<T>::conjugate() {
1100   if (a) a->conjugate();
1101   if (b) b->conjugate();
1102 }
1103 
get(int i,int j) const1104 template<typename T> T RkMatrix<T>::get(int i, int j) const {
1105   return a->dot_aibj(i, *b, j);
1106 }
1107 
writeArray(hmat_iostream writeFunc,void * userData) const1108 template<typename T> void RkMatrix<T>::writeArray(hmat_iostream writeFunc, void * userData) const{
1109   a->writeArray(writeFunc, userData);
1110   b->writeArray(writeFunc, userData);
1111 }
1112 
1113 template <typename T>
1114 bool (*RkMatrix<T>::formatedAddPartsHook)(RkMatrix<T> *me, double epsilon, const T *alpha,
1115                                                const RkMatrix<T> *const *parts,
1116                                                const int n) = NULL;
1117 
// Explicit template instantiations of RkMatrix for the library's four scalar
// types (S_t, D_t, C_t, Z_t), so the member definitions in this file are
// emitted once here.
template class RkMatrix<S_t>;
template class RkMatrix<D_t>;
template class RkMatrix<C_t>;
template class RkMatrix<Z_t>;
1123 
1124 }  // end namespace hmat
1125