1 /*
2   HMat-OSS (HMatrix library, open source software)
3 
4   Copyright (C) 2014-2015 Airbus Group SAS
5 
6   This program is free software; you can redistribute it and/or
7   modify it under the terms of the GNU General Public License
8   as published by the Free Software Foundation; either version 2
9   of the License, or (at your option) any later version.
10 
11   This program is distributed in the hope that it will be useful,
12   but WITHOUT ANY WARRANTY; without even the implied warranty of
13   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14   GNU General Public License for more details.
15 
16   You should have received a copy of the GNU General Public License
17   along with this program; if not, write to the Free Software
18   Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
19 
20   http://github.com/jeromerobert/hmat-oss
21 */
22 
23 #ifndef _C_WRAPPING_HPP
24 #define _C_WRAPPING_HPP
25 
26 #include <string>
27 #include <cstring>
28 
29 #include "common/context.hpp"
30 #include "common/my_assert.h"
31 #include "full_matrix.hpp"
32 #include "h_matrix.hpp"
33 #include "uncompressed_values.hpp"
34 #include "serialization.hpp"
35 #include "hmat_cpp_interface.hpp"
36 #include "disable_threading.hpp"
37 
38 namespace
39 {
40 template<typename T, template <typename> class E>
create_empty_hmatrix_admissibility(const hmat_cluster_tree_t * rows_tree,const hmat_cluster_tree_t * cols_tree,int lower_sym,hmat_admissibility_t * condition)41 hmat_matrix_t * create_empty_hmatrix_admissibility(
42   const hmat_cluster_tree_t* rows_tree,
43   const hmat_cluster_tree_t* cols_tree, int lower_sym,
44   hmat_admissibility_t* condition)
45 {
46   DECLARE_CONTEXT;
47     hmat::SymmetryFlag sym = lower_sym ? hmat::kLowerSymmetric : hmat::kNotSymmetric;
48     hmat::IEngine<T>* engine = new E<T>();
49     return (hmat_matrix_t*) new hmat::HMatInterface<T>(
50             engine,
51             reinterpret_cast<const hmat::ClusterTree*>(rows_tree),
52             reinterpret_cast<const hmat::ClusterTree*>(cols_tree),
53             sym, (hmat::AdmissibilityCondition*)condition);
54 }
55 
56 template<typename T, template <typename> class E>
assemble_generic(hmat_matrix_t * matrix,hmat_assemble_context_t * ctx)57 int assemble_generic(hmat_matrix_t* matrix, hmat_assemble_context_t * ctx) {
58     DECLARE_CONTEXT;
59     hmat::HMatInterface<T>* hmat = (hmat::HMatInterface<T>*)matrix;
60     bool assembleOnly = ctx->factorization == hmat_factorization_none;
61     hmat::SymmetryFlag sf = ctx->lower_symmetric ? hmat::kLowerSymmetric : hmat::kNotSymmetric;
62     try {
63         if (ctx->lower_symmetric) {
64           HMAT_ASSERT(hmat->engine().hmat->rowsTree() == hmat->engine().hmat->colsTree());
65         }
66         HMAT_ASSERT_MSG(ctx->compression, "No compression algorithm defined in hmat_assemble_context_t");
67         hmat::CompressionAlgorithm* compression = (hmat::CompressionAlgorithm*)ctx->compression;
68         if(ctx->assembly != NULL) {
69             HMAT_ASSERT(ctx->block_compute == NULL && ctx->advanced_compute == NULL && ctx->simple_compute == NULL);
70             hmat::Assembly<T> * cppAssembly = (hmat::Assembly<T> *)ctx->assembly;
71             hmat->assemble(*cppAssembly, sf, ctx->progress);
72         } else if(ctx->block_compute != NULL || ctx->advanced_compute != NULL) {
73             HMAT_ASSERT(ctx->simple_compute == NULL && ctx->assembly == NULL);
74             HMAT_ASSERT(ctx->prepare != NULL);
75             hmat::BlockFunction<T> blockFunction(hmat->rows(), hmat->cols(),
76                 ctx->user_context, ctx->prepare, ctx->block_compute, ctx->advanced_compute);
77             hmat::AssemblyFunction<T, hmat::BlockFunction> * f =
78                 new hmat::AssemblyFunction<T, hmat::BlockFunction>(blockFunction, compression);
79             hmat->assemble(*f, sf, true, ctx->progress, true);
80         } else if(ctx->simple_compute != NULL) {
81             HMAT_ASSERT(ctx->block_compute == NULL && ctx->advanced_compute == NULL && ctx->assembly == NULL);
82             hmat::AssemblyFunction<T, hmat::SimpleFunction> * f =
83                 new hmat::AssemblyFunction<T, hmat::SimpleFunction>(
84                 hmat::SimpleFunction<T>(ctx->simple_compute, ctx->user_context), compression);
85             hmat->assemble(*f, sf, true, ctx->progress, true);
86         } else
87           HMAT_ASSERT_MSG(0, "No valid assembly method in assemble_generic()");
88 
89         if(!assembleOnly)
90             hmat->factorize(hmat::convert_int_to_factorization(ctx->factorization), ctx->progress);
91     } catch (const std::exception& e) {
92         fprintf(stderr, "%s\n", e.what());
93         return 1;
94     }
95     return 0;
96 }
97 
98 template<typename T, template <typename> class E>
copy(hmat_matrix_t * holder)99 hmat_matrix_t* copy(hmat_matrix_t* holder) {
100   DECLARE_CONTEXT;
101   return (hmat_matrix_t*) ((hmat::HMatInterface<T>*) holder)->copy();
102 }
103 
104 template<typename T, template <typename> class E>
copy_struct(hmat_matrix_t * holder)105 hmat_matrix_t* copy_struct(hmat_matrix_t* holder) {
106   DECLARE_CONTEXT;
107   return (hmat_matrix_t*) ((hmat::HMatInterface<T>*) holder)->copy(true);
108 }
109 
110 template<typename T, template <typename> class E>
destroy(hmat_matrix_t * holder)111 int destroy(hmat_matrix_t* holder) {
112   DECLARE_CONTEXT;
113   delete (hmat::HMatInterface<T>*)(holder);
114   return 0;
115 }
116 
117 template<typename T, template <typename> class E>
destroy_child(hmat_matrix_t * holder)118 int destroy_child(hmat_matrix_t* holder) {
119   DECLARE_CONTEXT;
120   hmat::HMatInterface<T>* hmat = (hmat::HMatInterface<T>*) holder;
121   hmat->setHMatrix();
122   delete hmat;
123   return 0;
124 }
125 
126 template<typename T, template <typename> class E>
inverse(hmat_matrix_t * holder)127 int inverse(hmat_matrix_t* holder) {
128   DECLARE_CONTEXT;
129   try {
130       ((hmat::HMatInterface<T>*) holder)->inverse();
131   } catch (const std::exception& e) {
132       fprintf(stderr, "%s\n", e.what());
133       return 1;
134   }
135   return 0;
136 }
137 
138 template<typename T, template <typename> class E>
factorize_generic(hmat_matrix_t * holder,hmat_factorization_context_t * ctx)139 int factorize_generic(hmat_matrix_t* holder, hmat_factorization_context_t * ctx) {
140     DECLARE_CONTEXT;
141     hmat::HMatInterface<T>* hmat = (hmat::HMatInterface<T>*) holder;
142     try {
143         hmat->factorize(hmat::convert_int_to_factorization(ctx->factorization), ctx->progress);
144     } catch (const std::exception& e) {
145         fprintf(stderr, "%s\n", e.what());
146         return 1;
147     }
148     return 0;
149 }
150 
151 template<typename T, template <typename> class E>
factor(hmat_matrix_t * holder,hmat_factorization_t t)152 int factor(hmat_matrix_t* holder, hmat_factorization_t t) {
153   DECLARE_CONTEXT;
154     hmat_factorization_context_t ctx;
155     hmat_factorization_context_init(&ctx);
156     ctx.factorization = t;
157     return factorize_generic<T, E>(holder, &ctx);
158 }
159 
160 template<typename T, template <typename> class E>
finalize()161 int finalize() {
162   DECLARE_CONTEXT;
163   E<T>::finalize();
164   return 0;
165 }
166 
167 template<typename T, template <typename> class E>
full_gemm(char transA,char transB,int mc,int nc,void * c,void * alpha,void * a,hmat_matrix_t * holder,void * beta)168 int full_gemm(char transA, char transB, int mc, int nc, void* c,
169                              void* alpha, void* a, hmat_matrix_t * holder, void* beta) {
170   DECLARE_CONTEXT;
171 
172   try {
173       const hmat::HMatInterface<T>* b = (hmat::HMatInterface<T>*)holder;
174       hmat::ScalarArray<T> matC((T*)c, mc, nc);
175       hmat::ScalarArray<T>* matA = NULL;
176       const hmat::ClusterData* bDataRows = (transB == 'N' ? b->rows(): b->cols());
177       const hmat::ClusterData* bDataCols = (transB == 'N' ? b->cols(): b->rows());
178       hmat::reorderVector(&matC, bDataCols->indices(), 1);
179       if (transA == 'N') {
180         matA = new hmat::ScalarArray<T>((T*)a, mc, bDataRows->size());
181         hmat::reorderVector(matA, bDataRows->indices(), 1);
182       } else {
183         matA = new hmat::ScalarArray<T>((T*)a, bDataRows->size(), mc);
184         hmat::reorderVector(matA, bDataRows->indices(), 0);
185       }
186       hmat::HMatInterface<T>::gemm(matC, transA, transB, *((T*)alpha), *matA, *b, *((T*)beta));
187       hmat::restoreVectorOrder(&matC, bDataCols->indices(), 1);
188       if (transA == 'N') {
189           hmat::restoreVectorOrder(matA, bDataRows->indices(), 1);
190       } else {
191           hmat::restoreVectorOrder(matA, bDataRows->indices(), 0);
192       }
193       delete matA;
194   } catch (const std::exception& e) {
195       fprintf(stderr, "%s\n", e.what());
196       return 1;
197   }
198   return 0;
199 }
200 
201 template<typename T, template <typename> class E>
gemm(char trans_a,char trans_b,void * alpha,hmat_matrix_t * holder,hmat_matrix_t * holder_b,void * beta,hmat_matrix_t * holder_c)202 int gemm(char trans_a, char trans_b, void *alpha, hmat_matrix_t * holder,
203                    hmat_matrix_t * holder_b, void *beta, hmat_matrix_t * holder_c) {
204   DECLARE_CONTEXT;
205   hmat::HMatInterface<T>* hmat_a = (hmat::HMatInterface<T>*)holder;
206   hmat::HMatInterface<T>* hmat_b = (hmat::HMatInterface<T>*)holder_b;
207   hmat::HMatInterface<T>* hmat_c = (hmat::HMatInterface<T>*)holder_c;
208   try {
209       hmat_c->gemm(trans_a, trans_b, *((T*)alpha), hmat_a, hmat_b, *((T*)beta));
210   } catch (const std::exception& e) {
211       fprintf(stderr, "%s\n", e.what());
212       return 1;
213   }
214   return 0;
215 }
216 
217 template<typename T, template <typename> class E>
axpy(void * a,hmat_matrix_t * x,hmat_matrix_t * y)218 int axpy(void *a, hmat_matrix_t * x, hmat_matrix_t * y) {
219   DECLARE_CONTEXT;
220   DISABLE_THREADING_IN_BLOCK;
221   hmat::HMatInterface<T>* hmat_x = reinterpret_cast<hmat::HMatInterface<T>*>(x);
222   hmat::HMatInterface<T>* hmat_y = reinterpret_cast<hmat::HMatInterface<T>*>(y);
223   try {
224       hmat_y->engine().hmat->axpy(*((T*)a), hmat_x->engine().hmat);
225   } catch (const std::exception& e) {
226       fprintf(stderr, "%s\n", e.what());
227       return 1;
228   }
229   return 0;
230 }
231 
232 template<typename T, template <typename> class E>
gemv(char trans_a,void * alpha,hmat_matrix_t * holder,void * vec_b,void * beta,void * vec_c,int nrhs)233 int gemv(char trans_a, void* alpha, hmat_matrix_t * holder, void* vec_b,
234                    void* beta, void* vec_c, int nrhs) {
235   DECLARE_CONTEXT;
236   hmat::HMatInterface<T>* hmat = (hmat::HMatInterface<T>*)holder;
237   const hmat::ClusterData* bData = (trans_a == 'N' ? hmat->cols(): hmat->rows());
238   const hmat::ClusterData* cData = (trans_a == 'N' ? hmat->rows(): hmat->cols());
239   try {
240       hmat::ScalarArray<T> mb((T*) vec_b, bData->size(), nrhs);
241       hmat::ScalarArray<T> mc((T*) vec_c, cData->size(), nrhs);
242       hmat::reorderVector(&mb, bData->indices(), 0);
243       hmat::reorderVector(&mc, cData->indices(), 0);
244       hmat->gemv(trans_a, *((T*)alpha), mb, *((T*)beta), mc);
245       hmat::restoreVectorOrder(&mb, bData->indices(), 0);
246       hmat::restoreVectorOrder(&mc, cData->indices(), 0);
247   } catch (const std::exception& e) {
248       fprintf(stderr, "%s\n", e.what());
249       return 1;
250   }
251   return 0;
252 }
253 
254 template<typename T, template <typename> class E>
gemm_scalar(char trans_a,void * alpha,hmat_matrix_t * holder,void * vec_b,void * beta,void * vec_c,int nrhs)255 int gemm_scalar( char trans_a, void* alpha, hmat_matrix_t * holder, void* vec_b,
256 		 void* beta, void* vec_c, int nrhs ) {
257   DECLARE_CONTEXT;
258   hmat::HMatInterface<T>* hmat = (hmat::HMatInterface<T>*)holder;
259   const hmat::ClusterData* bData = (trans_a == 'N' ? hmat->cols(): hmat->rows());
260   const hmat::ClusterData* cData = (trans_a == 'N' ? hmat->rows(): hmat->cols());
261   try {
262       hmat::ScalarArray<T> mb((T*) vec_b, bData->size(), nrhs);
263       hmat::ScalarArray<T> mc((T*) vec_c, cData->size(), nrhs);
264 
265       hmat->gemm_scalar(trans_a, *((T*)alpha), mb, *((T*)beta), mc);
266   } catch (const std::exception& e) {
267       fprintf(stderr, "%s\n", e.what());
268       return 1;
269   }
270   return 0;
271 }
272 
/** True for the operation flags that involve a transposition: 'T' and 'C'. */
inline bool is_trans(char trans) {
    switch (trans) {
    case 'T':
    case 'C':
        return true;
    default:
        return false;
    }
}
276 
/** True for the operation flags that involve a conjugation: 'J' and 'C'. */
inline bool is_conj(char trans) {
    switch (trans) {
    case 'J':
    case 'C':
        return true;
    default:
        return false;
    }
}
280 
switch_flag_trans(char trans)281 inline char switch_flag_trans(char trans) {
282     switch (trans) {
283     case 'N': return 'T';
284     case 'T': return 'N';
285     case 'C': return 'J';
286     case 'J': return 'C';
287     default: HMAT_ASSERT(false);
288     }
289 }
290 
switch_flag_conj(char trans)291 inline char switch_flag_conj(char trans) {
292     switch (trans) {
293     case 'N': return 'J';
294     case 'T': return 'C';
295     case 'C': return 'T';
296     case 'J': return 'N';
297     default: HMAT_ASSERT(false);
298     }
299 }
300 
301 template<typename T, template <typename> class E>
gemm_dense(char trans_b,char trans_x,char side,void * alpha,hmat_matrix_t * holder,void * vec_x,void * beta,void * vec_y,int nrhs)302 int gemm_dense(char trans_b, char trans_x, char side, void* alpha, hmat_matrix_t* holder,
303                void* vec_x, void* beta, void* vec_y, int nrhs) {
304   char trans_y = 'N';
305   T alphaT = *((T*)alpha);
306   T betaT = *((T*)beta);
307   alpha = &alphaT;
308   beta = &betaT;
309   if (side == 'R') {
310       // Y <- alpha X * B + beta Y <=>  Y^t <- alpha B^t X^t + beta Y^t or Y^H <- bar(alpha) B^H * X^H + bar(beta) Y^H
311       if (trans_b == 'C') {
312           trans_b = 'N';
313           trans_x = switch_flag_conj(switch_flag_trans(trans_x));
314           trans_y = 'C';
315           alphaT = hmat::conj(alphaT);
316           betaT = hmat::conj(betaT);
317       } else {
318           trans_b = switch_flag_trans(trans_b);
319           trans_x = switch_flag_trans(trans_x);
320           trans_y = 'T';
321       }
322   }
323 
324   DECLARE_CONTEXT;
325   DISABLE_THREADING_IN_BLOCK;
326 
327   // Now side='L': op(Y) <- alpha op(B) op(X) + beta op(Y)
328   const hmat::HMatInterface<T>* b = (hmat::HMatInterface<T>*)holder;
329   const hmat::IndexSet* bDataRows = !is_trans(trans_b) ? b->rows(): b->cols();
330   const hmat::IndexSet* bDataCols = !is_trans(trans_b) ? b->cols(): b->rows();
331   try {
332       hmat::ScalarArray<T>* mx = !is_trans(trans_x) ?
333           new hmat::ScalarArray<T>((T*) vec_x, bDataCols->size(), nrhs) :
334           new hmat::ScalarArray<T>((T*) vec_x, nrhs, bDataCols->size());
335       hmat::ScalarArray<T>* my =  !is_trans(trans_y) ?
336           new hmat::ScalarArray<T>((T*) vec_y, bDataRows->size(), nrhs) :
337           new hmat::ScalarArray<T>((T*) vec_y, nrhs, bDataRows->size());
338 
339       // Apply transformations on x and y
340       if (is_trans(trans_x))
341           mx->transpose();
342       if (is_conj(trans_x))
343           mx->conjugate();
344       if (is_trans(trans_y))
345           my->transpose();
346       if (is_conj(trans_y))
347           my->conjugate();
348 
349       b->gemv(trans_b, *((T*)alpha), *mx, *((T*)beta), *my);
350 
351       // Apply inverse transformations on x and y
352       if (is_trans(trans_x))
353           mx->transpose();
354       if (is_trans(trans_y))
355           my->transpose();
356       if (is_conj(trans_y))
357           my->conjugate();
358 
359       delete mx;
360       delete my;
361   } catch (const std::exception& e) {
362       fprintf(stderr, "%s\n", e.what());
363       return 1;
364   }
365   return 0;
366 }
367 
368 template<typename T, template <typename> class E>
trsm(char side,char uplo,char transa,char diag,int m,int n,void * alpha,hmat_matrix_t * A,int is_b_hmat,void * B)369 int trsm( char side, char uplo, char transa, char diag, int m, int n,
370 	  void *alpha, hmat_matrix_t *A, int is_b_hmat, void *B )
371 {
372   DECLARE_CONTEXT;
373   hmat::HMatInterface<T>* hmatA = (hmat::HMatInterface<T>*)A;
374 
375   try {
376       if ( is_b_hmat ) {
377           hmat::HMatInterface<T>* hmatB = (hmat::HMatInterface<T>*)B;
378           hmatA->trsm( side, uplo, transa, diag, *((T*)alpha), hmatB );
379       }
380       else {
381           bool isleft = (side == 'l') || (side == 'L');
382           hmat::ScalarArray<T> mB( (T*)B, (isleft ? m : n), (isleft ? n : m ) );
383           hmatA->trsm( side, uplo, transa, diag, *((T*)alpha), mB );
384       }
385   } catch (const std::exception& e) {
386       fprintf(stderr, "%s\n", e.what());
387       return 1;
388   }
389   return 0;
390 }
391 
392 template<typename T, template <typename> class E>
add_identity(hmat_matrix_t * holder,void * alpha)393 int add_identity(hmat_matrix_t* holder, void *alpha) {
394   DECLARE_CONTEXT;
395   hmat::HMatInterface<T>* hmat = (hmat::HMatInterface<T>*)holder;
396   try {
397       hmat->addIdentity(*((T*)alpha));
398   } catch (const std::exception& e) {
399       fprintf(stderr, "%s\n", e.what());
400       return 1;
401   }
402   return 0;
403 }
404 
405 template<typename T, template <typename> class E>
init()406 int init() {
407   DECLARE_CONTEXT;
408   return E<T>::init();
409 }
410 
411 template<typename T, template <typename> class E>
norm(hmat_matrix_t * holder)412 double norm(hmat_matrix_t* holder) {
413   DECLARE_CONTEXT;
414   return ((hmat::HMatInterface<T>*)holder)->norm();
415 }
416 
417 template<typename T, template <typename> class E>
scale(void * alpha,hmat_matrix_t * holder)418 int scale(void *alpha, hmat_matrix_t* holder) {
419   DECLARE_CONTEXT;
420   try {
421       ((hmat::HMatInterface<T>*)holder)->scale(*((T*)alpha));
422   } catch (const std::exception& e) {
423       fprintf(stderr, "%s\n", e.what());
424       return 1;
425   }
426   return 0;
427 }
428 
429 template<typename T, template <typename> class E>
truncate(hmat_matrix_t * holder)430 int truncate(hmat_matrix_t* holder) {
431   DECLARE_CONTEXT;
432   try {
433       ((hmat::HMatInterface<T>*)holder)->truncate();
434   } catch (const std::exception& e) {
435       fprintf(stderr, "%s\n", e.what());
436       return 1;
437   }
438   return 0;
439 }
440 
441 template<typename T, template <typename> class E>
vector_reorder(void * vec_b,const hmat_cluster_tree_t * rows_ct,int rows,const hmat_cluster_tree_t * cols_ct,int cols)442 int vector_reorder(void* vec_b, const hmat_cluster_tree_t *rows_ct, int rows, const hmat_cluster_tree_t *cols_ct, int cols) {
443   DECLARE_CONTEXT;
444   try {
445       HMAT_ASSERT_MSG(rows_ct != NULL || rows != 0, "either row cluster tree or rows must be non null");
446       HMAT_ASSERT_MSG(cols_ct != NULL || cols != 0, "either col cluster tree or cols must be non null");
447       const hmat::ClusterTree *clusterTreeRows = reinterpret_cast<const hmat::ClusterTree*>(rows_ct);
448       const hmat::ClusterTree *clusterTreeCols = reinterpret_cast<const hmat::ClusterTree*>(cols_ct);
449       int nrows = clusterTreeRows == NULL ? rows : clusterTreeRows->data.size();
450       int ncols = clusterTreeCols == NULL ? cols : clusterTreeCols->data.size();
451       hmat::ScalarArray<T> mb((T*) vec_b, nrows, ncols);
452       if (clusterTreeRows) {
453         hmat::reorderVector(&mb, clusterTreeRows->data.indices(), 0);
454       }
455       if (clusterTreeCols) {
456         hmat::reorderVector(&mb, clusterTreeCols->data.indices(), 1);
457       }
458   } catch (const std::exception& e) {
459       fprintf(stderr, "%s\n", e.what());
460       return 1;
461   }
462   return 0;
463 }
464 
465 template<typename T, template <typename> class E>
vector_restore(void * vec_b,const hmat_cluster_tree_t * rows_ct,int rows,const hmat_cluster_tree_t * cols_ct,int cols)466 int vector_restore(void* vec_b, const hmat_cluster_tree_t *rows_ct, int rows, const hmat_cluster_tree_t *cols_ct, int cols) {
467   DECLARE_CONTEXT;
468   try {
469       HMAT_ASSERT_MSG(rows_ct != NULL || rows != 0, "either row cluster tree or rows must be non null");
470       HMAT_ASSERT_MSG(cols_ct != NULL || cols != 0, "either col cluster tree or cols must be non null");
471       const hmat::ClusterTree *clusterTreeRows = reinterpret_cast<const hmat::ClusterTree*>(rows_ct);
472       const hmat::ClusterTree *clusterTreeCols = reinterpret_cast<const hmat::ClusterTree*>(cols_ct);
473       int nrows = clusterTreeRows == NULL ? rows : clusterTreeRows->data.size();
474       int ncols = clusterTreeCols == NULL ? cols : clusterTreeCols->data.size();
475       hmat::ScalarArray<T> mb((T*) vec_b, nrows, ncols);
476       if (clusterTreeRows) {
477         hmat::restoreVectorOrder(&mb, clusterTreeRows->data.indices(), 0);
478       }
479       if (clusterTreeCols) {
480         hmat::restoreVectorOrder(&mb, clusterTreeCols->data.indices(), 1);
481       }
482   } catch (const std::exception& e) {
483       fprintf(stderr, "%s\n", e.what());
484       return 1;
485   }
486   return 0;
487 }
488 
489 template<typename T, template <typename> class E>
solve_mat(hmat_matrix_t * hmat,hmat_matrix_t * hmatB)490 int solve_mat(hmat_matrix_t* hmat, hmat_matrix_t* hmatB) {
491   DECLARE_CONTEXT;
492   try {
493       ((hmat::HMatInterface<T>*)hmat)->solve(*(hmat::HMatInterface<T>*)hmatB);
494   } catch (const std::exception& e) {
495       fprintf(stderr, "%s\n", e.what());
496       return 1;
497   }
498   return 0;
499 }
500 
501 template<typename T, template <typename> class E>
solve_systems(hmat_matrix_t * holder,void * b,int nrhs)502 int solve_systems(hmat_matrix_t* holder, void* b, int nrhs) {
503   DECLARE_CONTEXT;
504   hmat::HMatInterface<T>* hmat = (hmat::HMatInterface<T>*)holder;
505   try {
506       hmat::ScalarArray<T> mb((T*) b, hmat->cols()->size(), nrhs);
507       hmat::reorderVector<T>(&mb, hmat->cols()->indices(), 0);
508       hmat->solve(mb);
509       hmat::restoreVectorOrder<T>(&mb, hmat->cols()->indices(), 0);
510   } catch (const std::exception& e) {
511       fprintf(stderr, "%s\n", e.what());
512       return 1;
513   }
514   return 0;
515 }
516 
517 template<typename T, template <typename> class E>
solve_dense(hmat_matrix_t * holder,void * b,int nrhs)518 int solve_dense(hmat_matrix_t* holder, void* b, int nrhs) {
519   DECLARE_CONTEXT;
520   hmat::HMatInterface<T>* hmat = (hmat::HMatInterface<T>*)holder;
521   try {
522       hmat::ScalarArray<T> mb((T*) b, hmat->cols()->size(), nrhs);
523       hmat->solve(mb);
524   } catch (const std::exception& e) {
525       fprintf(stderr, "%s\n", e.what());
526       return 1;
527   }
528   return 0;
529 }
530 
531 template<typename T, template <typename> class E>
transpose(hmat_matrix_t * hmat)532 int transpose(hmat_matrix_t* hmat) {
533   DECLARE_CONTEXT;
534   try {
535       ((hmat::HMatInterface<T>*)hmat)->transpose();
536   } catch (const std::exception& e) {
537       fprintf(stderr, "%s\n", e.what());
538       return 1;
539   }
540   return 0;
541 }
542 
543 template<typename T, template <typename> class E>
hmat_get_info(hmat_matrix_t * holder,hmat_info_t * info)544 int hmat_get_info(hmat_matrix_t* holder, hmat_info_t* info) {
545   DECLARE_CONTEXT;
546   hmat::HMatInterface<T>* hmat = (hmat::HMatInterface<T>*) holder;
547   try {
548       hmat->info(*info);
549   } catch (const std::exception& e) {
550       fprintf(stderr, "%s\n", e.what());
551       return 1;
552   }
553   return 0;
554 }
555 
556 template<typename T, template <typename> class E>
hmat_dump_info(hmat_matrix_t * holder,char * prefix)557 int hmat_dump_info(hmat_matrix_t* holder, char* prefix) {
558   DECLARE_CONTEXT;
559   hmat::HMatInterface<T>* hmat = (hmat::HMatInterface<T>*) holder;
560   try {
561       std::string filejson(prefix);
562       filejson += ".json";
563       hmat->dumpTreeToFile( filejson );
564   } catch (const std::exception& e) {
565       fprintf(stderr, "%s\n", e.what());
566       return 1;
567   }
568   return 0;
569 }
570 
571 template<typename T, template <typename> class E>
get_cluster_trees(hmat_matrix_t * holder,const hmat_cluster_tree_t ** rows,const hmat_cluster_tree_t ** cols)572 int get_cluster_trees(hmat_matrix_t* holder, const hmat_cluster_tree_t ** rows, const hmat_cluster_tree_t ** cols) {
573   DECLARE_CONTEXT;
574   hmat::HMatInterface<T>* hmat = (hmat::HMatInterface<T>*) holder;
575   try {
576       if (rows)
577         *rows = static_cast<const hmat_cluster_tree_t*>(static_cast<const void*>(hmat->engine().hmat->rowsTree()));
578       if (cols)
579         *cols = static_cast<const hmat_cluster_tree_t*>(static_cast<const void*>(hmat->engine().hmat->colsTree()));
580   } catch (const std::exception& e) {
581       fprintf(stderr, "%s\n", e.what());
582       return 1;
583   }
584   return 0;
585 }
586 
587 template<typename T, template <typename> class E>
set_cluster_trees(hmat_matrix_t * holder,const hmat_cluster_tree_t * rows,const hmat_cluster_tree_t * cols)588 int set_cluster_trees(hmat_matrix_t* holder, const hmat_cluster_tree_t * rows, const hmat_cluster_tree_t * cols) {
589   DECLARE_CONTEXT;
590   hmat::HMatInterface<T>* hmat = (hmat::HMatInterface<T>*) holder;
591   try {
592       hmat->engine().hmat->setClusterTrees(
593         reinterpret_cast<const hmat::ClusterTree*>(rows),
594         reinterpret_cast<const hmat::ClusterTree*>(cols));
595   } catch (const std::exception& e) {
596       fprintf(stderr, "%s\n", e.what());
597       return 1;
598   }
599   return 0;
600 }
601 
602 template<typename T, template <typename> class E>
own_cluster_trees(hmat_matrix_t * holder,int owns_row,int owns_col)603 void own_cluster_trees(hmat_matrix_t* holder, int owns_row, int owns_col)
604 {
605   DECLARE_CONTEXT;
606   hmat::HMatInterface<T>* hmat = (hmat::HMatInterface<T>*) holder;
607   hmat->engine().hmat->ownClusterTrees(owns_row != 0, owns_col != 0);
608 }
609 
610 template<typename T, template <typename> class E>
set_low_rank_epsilon(hmat_matrix_t * holder,double epsilon)611 void set_low_rank_epsilon(hmat_matrix_t* holder, double epsilon)
612 {
613   DECLARE_CONTEXT;
614   hmat::HMatInterface<T>* hmat = (hmat::HMatInterface<T>*) holder;
615   hmat->engine().hmat->lowRankEpsilon(epsilon);
616 }
617 
618 template<typename T, template <typename> class E>
extract_diagonal(hmat_matrix_t * holder,void * diag,int size)619 int extract_diagonal(hmat_matrix_t* holder, void* diag, int size)
620 {
621   DECLARE_CONTEXT;
622   (void)size; //for API compatibility
623   hmat::HMatInterface<T>* hmat = (hmat::HMatInterface<T>*) holder;
624   try {
625       hmat->engine().hmat->extractDiagonal(static_cast<T*>(diag));
626       hmat::ScalarArray<T> permutedDiagonal(static_cast<T*>(diag), hmat->cols()->size(), 1);
627       hmat::restoreVectorOrder(&permutedDiagonal, hmat->cols()->indices(), 0);
628   } catch (const std::exception& e) {
629       fprintf(stderr, "%s\n", e.what());
630       return 1;
631   }
632   return 0;
633 }
634 
635 template<typename T, template <typename> class E>
solve_lower_triangular(hmat_matrix_t * holder,int transpose,void * b,int nrhs)636 int solve_lower_triangular(hmat_matrix_t* holder, int transpose, void* b, int nrhs)
637 {
638   DECLARE_CONTEXT;
639   hmat::HMatInterface<T>* hmat = (hmat::HMatInterface<T>*)holder;
640   hmat::ScalarArray<T> mb((T*) b, hmat->cols()->size(), nrhs);
641   try {
642       if (transpose)
643         hmat::reorderVector<T>(&mb, hmat->rows()->indices(), 0);
644       else
645         hmat::reorderVector<T>(&mb, hmat->cols()->indices(), 0);
646       hmat->solveLower(mb, transpose);
647       if (transpose)
648         hmat::restoreVectorOrder<T>(&mb, hmat->rows()->indices(), 0);
649       else
650         hmat::restoreVectorOrder<T>(&mb, hmat->cols()->indices(), 0);
651   } catch (const std::exception& e) {
652       fprintf(stderr, "%s\n", e.what());
653       return 1;
654   }
655   return 0;
656 }
657 
658 template<typename T, template <typename> class E>
solve_lower_triangular_dense(hmat_matrix_t * holder,int transpose,void * b,int nrhs)659 int solve_lower_triangular_dense(hmat_matrix_t* holder, int transpose, void* b, int nrhs)
660 {
661   DECLARE_CONTEXT;
662   hmat::HMatInterface<T>* hmat = (hmat::HMatInterface<T>*)holder;
663   hmat::ScalarArray<T> mb((T*) b, hmat->cols()->size(), nrhs);
664   try {
665       hmat->solveLower(mb, transpose);
666   } catch (const std::exception& e) {
667       fprintf(stderr, "%s\n", e.what());
668       return 1;
669   }
670   return 0;
671 }
672 
673 template <typename T, template <typename> class E>
get_child(hmat_matrix_t * hmatrix,int i,int j)674 hmat_matrix_t *get_child( hmat_matrix_t *hmatrix, int i, int j ) {
675     DECLARE_CONTEXT;
676     hmat::HMatInterface<T> *hmat = (hmat::HMatInterface<T> *)hmatrix;
677 
678     hmat::HMatrix<T> *m = hmat->get( i, j );
679     hmat::IEngine<T>* engine = new E<T>();
680     hmat::HMatInterface<T> *r = new hmat::HMatInterface<T>( engine, m, hmat->factorization() );
681     return (hmat_matrix_t*) r;
682 }
683 
684 template <typename T, template <typename> class E>
get_block(struct hmat_get_values_context_t * ctx)685 int get_block(struct hmat_get_values_context_t *ctx) {
686   DECLARE_CONTEXT;
687   DISABLE_THREADING_IN_BLOCK;
688     hmat::HMatInterface<T> *hmat = (hmat::HMatInterface<T> *)ctx->matrix;
689     try {
690         hmat::IndexSet rows(ctx->row_offset, ctx->row_size);
691         hmat::IndexSet cols(ctx->col_offset, ctx->col_size);
692         typename E<T>::UncompressedBlock view;
693         const E<T>& engine = dynamic_cast<const E<T>&>(hmat->engine());
694         view.uncompress(engine.getHandle(), rows, cols, (T*)ctx->values);
695         hmat::HMatrix<T>* compressed = hmat->engine().hmat;
696         // Symmetrize values when requesting a full symmetric matrix
697         if (compressed->isLower &&
698             ctx->row_offset == 0 && ctx->col_offset == 0 &&
699             ctx->row_size == compressed->rows()->size() && ctx->col_size == compressed->cols()->size())
700         {
701           T* ptr = static_cast<T*>(ctx->values);
702           for (int i = 0; i < ctx->row_size; i++) {
703             for (int j = i + 1; j < ctx->col_size; j++) {
704               ptr[j*ctx->row_size + i] = ptr[i*ctx->row_size + j];
705             }
706           }
707         }
708         if (ctx->renumber_rows)
709             view.renumberRows();
710         ctx->col_indices = view.colsNumbering();
711         ctx->row_indices= view.rowsNumbering();
712     } catch (const std::exception& e) {
713         fprintf(stderr, "%s\n", e.what());
714         return 1;
715     }
716     return 0;
717 }
718 
719 template <typename T, template <typename> class E>
get_values(struct hmat_get_values_context_t * ctx)720 int get_values(struct hmat_get_values_context_t *ctx) {
721   DECLARE_CONTEXT;
722     // No need to call DISABLE_THREADING_IN_BLOCK here, there is no BLAS call
723     hmat::HMatInterface<T> *hmat = (hmat::HMatInterface<T> *)ctx->matrix;
724     try {
725         typename E<T>::UncompressedValues view;
726         const E<T>& engine = reinterpret_cast<const E<T>&>(hmat->engine());
727         view.uncompress(engine.getHandle(),
728                         ctx->row_indices, ctx->row_size,
729                         ctx->col_indices, ctx->col_size,
730                         (T*)ctx->values);
731     } catch (const std::exception& e) {
732         fprintf(stderr, "%s\n", e.what());
733         return 1;
734     }
735     return 0;
736 }
737 
738 template <typename T, template <typename> class E>
walk(hmat_matrix_t * holder,hmat_procedure_t * proc)739 int walk(hmat_matrix_t* holder, hmat_procedure_t* proc) {
740   DECLARE_CONTEXT;
741     hmat::HMatInterface<T> *hmat = (hmat::HMatInterface<T> *) holder;
742     try {
743         hmat::TreeProcedure<hmat::HMatrix<T> > *functor = (hmat::TreeProcedure<hmat::HMatrix<T> > *) proc->internal;
744         hmat->walk(functor);
745     } catch (const std::exception& e) {
746         fprintf(stderr, "%s\n", e.what());
747         return 1;
748     }
749     return 0;
750 }
751 
752 template <typename T, template <typename> class E>
apply_on_leaf(hmat_matrix_t * holder,const hmat_leaf_procedure_t * proc)753 int apply_on_leaf(hmat_matrix_t* holder, const hmat_leaf_procedure_t* proc) {
754   DECLARE_CONTEXT;
755     hmat::HMatInterface<T> *hmat = (hmat::HMatInterface<T> *) holder;
756     try {
757         const hmat::LeafProcedure<hmat::HMatrix<T> > *functor = static_cast<const hmat::LeafProcedure<hmat::HMatrix<T> > *>(proc->internal);
758         hmat->apply_on_leaf(*functor);
759     } catch (const std::exception& e) {
760         fprintf(stderr, "%s\n", e.what());
761         return 1;
762     }
763     return 0;
764 }
765 
766 template <typename T, template <typename> class E>
read_struct(hmat_iostream readfunc,void * user_data)767 hmat_matrix_t * read_struct(hmat_iostream readfunc, void * user_data) {
768     hmat::MatrixStructUnmarshaller<T> unmarshaller(&hmat::HMatSettings::getInstance(), readfunc, user_data);
769     hmat::HMatrix<T> * m = unmarshaller.read();
770     E<T>* engine = new E<T>();
771     hmat::HMatInterface<T> * r = new hmat::HMatInterface<T>(engine, m, unmarshaller.factorization());
772     return (hmat_matrix_t*) r;
773 }
774 
775 template <typename T, template <typename> class E>
read_data(hmat_matrix_t * matrix,hmat_iostream readfunc,void * user_data)776 void read_data(hmat_matrix_t * matrix, hmat_iostream readfunc, void * user_data) {
777     hmat::HMatInterface<T> * hmi = (hmat::HMatInterface<T> *) matrix;
778     hmat::MatrixDataUnmarshaller<T>(readfunc, user_data).read(hmi->engine().hmat);
779 }
780 
781 template <typename T, template <typename> class E>
write_struct(hmat_matrix_t * matrix,hmat_iostream writefunc,void * user_data)782 void write_struct(hmat_matrix_t* matrix, hmat_iostream writefunc, void * user_data) {
783     hmat::HMatInterface<T> * hmi = (hmat::HMatInterface<T> *) matrix;
784     hmat::MatrixStructMarshaller<T>(writefunc, user_data).write(
785         hmi->engine().hmat, hmi->factorization());
786 }
787 
788 template <typename T, template <typename> class E>
write_data(hmat_matrix_t * matrix,hmat_iostream writefunc,void * user_data)789 void write_data(hmat_matrix_t* matrix, hmat_iostream writefunc, void * user_data) {
790     hmat::HMatInterface<T> * hmi = (hmat::HMatInterface<T> *) matrix;
791     hmat::MatrixDataMarshaller<T>(writefunc, user_data).write(hmi->engine().hmat);
792 }
793 
794 template <typename T>
set_progressbar(hmat_matrix_t * matrix,hmat_progress_t * progress)795 void set_progressbar(hmat_matrix_t * matrix, hmat_progress_t * progress) {
796     reinterpret_cast<hmat::HMatInterface<T> *>(matrix)->progress(progress);
797 }
798 
799 }  // end anonymous namespace
800 
801 namespace hmat {
802 
803 template<typename T, template <typename> class E>
createCInterface(hmat_interface_t * i)804 static void createCInterface(hmat_interface_t * i)
805 {
806   DECLARE_CONTEXT;
807     i->copy = copy<T, E>;
808     i->copy_struct = copy_struct<T, E>;
809     i->create_empty_hmatrix_admissibility = create_empty_hmatrix_admissibility<T, E>;
810     i->destroy = destroy<T, E>;
811     i->get_child = get_child<T, E>;
812     i->destroy_child = destroy_child<T, E>;
813     i->inverse = inverse<T, E>;
814     i->finalize = finalize<T, E>;
815     i->full_gemm = full_gemm<T, E>;
816     i->gemm = gemm<T, E>;
817     i->gemv = gemv<T, E>;
818     i->gemm_scalar = gemm_scalar<T, E>;
819     i->add_identity = add_identity<T, E>;
820     i->init = init<T, E>;
821     i->norm = norm<T, E>;
822     i->scale = scale<T, E>;
823     i->solve_mat = solve_mat<T, E>;
824     i->solve_systems = solve_systems<T, E>;
825     i->solve_dense = solve_dense<T, E>;
826     i->transpose = transpose<T, E>;
827     i->internal = NULL;
828     i->get_info  = hmat_get_info<T, E>;
829     i->dump_info = hmat_dump_info<T, E>;
830     i->get_cluster_trees = get_cluster_trees<T, E>;
831     i->set_cluster_trees = set_cluster_trees<T, E>;
832     i->own_cluster_trees = own_cluster_trees<T, E>;
833     i->set_low_rank_epsilon = set_low_rank_epsilon<T, E>;
834     i->extract_diagonal = extract_diagonal<T, E>;
835     i->solve_lower_triangular = solve_lower_triangular<T, E>;
836     i->solve_lower_triangular_dense = solve_lower_triangular_dense<T, E>;
837     i->assemble_generic = assemble_generic<T, E>;
838     i->factorize_generic = factorize_generic<T, E>;
839     i->get_values = get_values<T, E>;
840     i->get_block = get_block<T, E>;
841     i->walk = walk<T, E>;
842     i->read_struct = read_struct<T, E>;
843     i->write_struct = write_struct<T, E>;
844     i->write_data = write_data<T, E>;
845     i->read_data = read_data<T, E>;
846     i->apply_on_leaf = apply_on_leaf<T, E>;
847     i->axpy = axpy<T, E>;
848     i->trsm = trsm<T, E>;
849     i->truncate = truncate<T, E>;
850     i->set_progressbar = set_progressbar<T>;
851     i->gemm_dense = gemm_dense<T, E>;
852     i->vector_reorder = vector_reorder<T, E>;
853     i->vector_restore = vector_restore<T, E>;
854 }
855 
856 }  // end namespace hmat
857 
858 #endif  // _C_WRAPPING_HPP
859