1 /*
2 HMat-OSS (HMatrix library, open source software)
3
4 Copyright (C) 2014-2015 Airbus Group SAS
5
6 This program is free software; you can redistribute it and/or
7 modify it under the terms of the GNU General Public License
8 as published by the Free Software Foundation; either version 2
9 of the License, or (at your option) any later version.
10
11 This program is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with this program; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
19
20 http://github.com/jeromerobert/hmat-oss
21 */
22
23 #ifndef _C_WRAPPING_HPP
24 #define _C_WRAPPING_HPP
25
26 #include <string>
27 #include <cstring>
28
29 #include "common/context.hpp"
30 #include "common/my_assert.h"
31 #include "full_matrix.hpp"
32 #include "h_matrix.hpp"
33 #include "uncompressed_values.hpp"
34 #include "serialization.hpp"
35 #include "hmat_cpp_interface.hpp"
36 #include "disable_threading.hpp"
37
38 namespace
39 {
40 template<typename T, template <typename> class E>
create_empty_hmatrix_admissibility(const hmat_cluster_tree_t * rows_tree,const hmat_cluster_tree_t * cols_tree,int lower_sym,hmat_admissibility_t * condition)41 hmat_matrix_t * create_empty_hmatrix_admissibility(
42 const hmat_cluster_tree_t* rows_tree,
43 const hmat_cluster_tree_t* cols_tree, int lower_sym,
44 hmat_admissibility_t* condition)
45 {
46 DECLARE_CONTEXT;
47 hmat::SymmetryFlag sym = lower_sym ? hmat::kLowerSymmetric : hmat::kNotSymmetric;
48 hmat::IEngine<T>* engine = new E<T>();
49 return (hmat_matrix_t*) new hmat::HMatInterface<T>(
50 engine,
51 reinterpret_cast<const hmat::ClusterTree*>(rows_tree),
52 reinterpret_cast<const hmat::ClusterTree*>(cols_tree),
53 sym, (hmat::AdmissibilityCondition*)condition);
54 }
55
56 template<typename T, template <typename> class E>
assemble_generic(hmat_matrix_t * matrix,hmat_assemble_context_t * ctx)57 int assemble_generic(hmat_matrix_t* matrix, hmat_assemble_context_t * ctx) {
58 DECLARE_CONTEXT;
59 hmat::HMatInterface<T>* hmat = (hmat::HMatInterface<T>*)matrix;
60 bool assembleOnly = ctx->factorization == hmat_factorization_none;
61 hmat::SymmetryFlag sf = ctx->lower_symmetric ? hmat::kLowerSymmetric : hmat::kNotSymmetric;
62 try {
63 if (ctx->lower_symmetric) {
64 HMAT_ASSERT(hmat->engine().hmat->rowsTree() == hmat->engine().hmat->colsTree());
65 }
66 HMAT_ASSERT_MSG(ctx->compression, "No compression algorithm defined in hmat_assemble_context_t");
67 hmat::CompressionAlgorithm* compression = (hmat::CompressionAlgorithm*)ctx->compression;
68 if(ctx->assembly != NULL) {
69 HMAT_ASSERT(ctx->block_compute == NULL && ctx->advanced_compute == NULL && ctx->simple_compute == NULL);
70 hmat::Assembly<T> * cppAssembly = (hmat::Assembly<T> *)ctx->assembly;
71 hmat->assemble(*cppAssembly, sf, ctx->progress);
72 } else if(ctx->block_compute != NULL || ctx->advanced_compute != NULL) {
73 HMAT_ASSERT(ctx->simple_compute == NULL && ctx->assembly == NULL);
74 HMAT_ASSERT(ctx->prepare != NULL);
75 hmat::BlockFunction<T> blockFunction(hmat->rows(), hmat->cols(),
76 ctx->user_context, ctx->prepare, ctx->block_compute, ctx->advanced_compute);
77 hmat::AssemblyFunction<T, hmat::BlockFunction> * f =
78 new hmat::AssemblyFunction<T, hmat::BlockFunction>(blockFunction, compression);
79 hmat->assemble(*f, sf, true, ctx->progress, true);
80 } else if(ctx->simple_compute != NULL) {
81 HMAT_ASSERT(ctx->block_compute == NULL && ctx->advanced_compute == NULL && ctx->assembly == NULL);
82 hmat::AssemblyFunction<T, hmat::SimpleFunction> * f =
83 new hmat::AssemblyFunction<T, hmat::SimpleFunction>(
84 hmat::SimpleFunction<T>(ctx->simple_compute, ctx->user_context), compression);
85 hmat->assemble(*f, sf, true, ctx->progress, true);
86 } else
87 HMAT_ASSERT_MSG(0, "No valid assembly method in assemble_generic()");
88
89 if(!assembleOnly)
90 hmat->factorize(hmat::convert_int_to_factorization(ctx->factorization), ctx->progress);
91 } catch (const std::exception& e) {
92 fprintf(stderr, "%s\n", e.what());
93 return 1;
94 }
95 return 0;
96 }
97
98 template<typename T, template <typename> class E>
copy(hmat_matrix_t * holder)99 hmat_matrix_t* copy(hmat_matrix_t* holder) {
100 DECLARE_CONTEXT;
101 return (hmat_matrix_t*) ((hmat::HMatInterface<T>*) holder)->copy();
102 }
103
104 template<typename T, template <typename> class E>
copy_struct(hmat_matrix_t * holder)105 hmat_matrix_t* copy_struct(hmat_matrix_t* holder) {
106 DECLARE_CONTEXT;
107 return (hmat_matrix_t*) ((hmat::HMatInterface<T>*) holder)->copy(true);
108 }
109
110 template<typename T, template <typename> class E>
destroy(hmat_matrix_t * holder)111 int destroy(hmat_matrix_t* holder) {
112 DECLARE_CONTEXT;
113 delete (hmat::HMatInterface<T>*)(holder);
114 return 0;
115 }
116
117 template<typename T, template <typename> class E>
destroy_child(hmat_matrix_t * holder)118 int destroy_child(hmat_matrix_t* holder) {
119 DECLARE_CONTEXT;
120 hmat::HMatInterface<T>* hmat = (hmat::HMatInterface<T>*) holder;
121 hmat->setHMatrix();
122 delete hmat;
123 return 0;
124 }
125
126 template<typename T, template <typename> class E>
inverse(hmat_matrix_t * holder)127 int inverse(hmat_matrix_t* holder) {
128 DECLARE_CONTEXT;
129 try {
130 ((hmat::HMatInterface<T>*) holder)->inverse();
131 } catch (const std::exception& e) {
132 fprintf(stderr, "%s\n", e.what());
133 return 1;
134 }
135 return 0;
136 }
137
138 template<typename T, template <typename> class E>
factorize_generic(hmat_matrix_t * holder,hmat_factorization_context_t * ctx)139 int factorize_generic(hmat_matrix_t* holder, hmat_factorization_context_t * ctx) {
140 DECLARE_CONTEXT;
141 hmat::HMatInterface<T>* hmat = (hmat::HMatInterface<T>*) holder;
142 try {
143 hmat->factorize(hmat::convert_int_to_factorization(ctx->factorization), ctx->progress);
144 } catch (const std::exception& e) {
145 fprintf(stderr, "%s\n", e.what());
146 return 1;
147 }
148 return 0;
149 }
150
151 template<typename T, template <typename> class E>
factor(hmat_matrix_t * holder,hmat_factorization_t t)152 int factor(hmat_matrix_t* holder, hmat_factorization_t t) {
153 DECLARE_CONTEXT;
154 hmat_factorization_context_t ctx;
155 hmat_factorization_context_init(&ctx);
156 ctx.factorization = t;
157 return factorize_generic<T, E>(holder, &ctx);
158 }
159
160 template<typename T, template <typename> class E>
finalize()161 int finalize() {
162 DECLARE_CONTEXT;
163 E<T>::finalize();
164 return 0;
165 }
166
167 template<typename T, template <typename> class E>
full_gemm(char transA,char transB,int mc,int nc,void * c,void * alpha,void * a,hmat_matrix_t * holder,void * beta)168 int full_gemm(char transA, char transB, int mc, int nc, void* c,
169 void* alpha, void* a, hmat_matrix_t * holder, void* beta) {
170 DECLARE_CONTEXT;
171
172 try {
173 const hmat::HMatInterface<T>* b = (hmat::HMatInterface<T>*)holder;
174 hmat::ScalarArray<T> matC((T*)c, mc, nc);
175 hmat::ScalarArray<T>* matA = NULL;
176 const hmat::ClusterData* bDataRows = (transB == 'N' ? b->rows(): b->cols());
177 const hmat::ClusterData* bDataCols = (transB == 'N' ? b->cols(): b->rows());
178 hmat::reorderVector(&matC, bDataCols->indices(), 1);
179 if (transA == 'N') {
180 matA = new hmat::ScalarArray<T>((T*)a, mc, bDataRows->size());
181 hmat::reorderVector(matA, bDataRows->indices(), 1);
182 } else {
183 matA = new hmat::ScalarArray<T>((T*)a, bDataRows->size(), mc);
184 hmat::reorderVector(matA, bDataRows->indices(), 0);
185 }
186 hmat::HMatInterface<T>::gemm(matC, transA, transB, *((T*)alpha), *matA, *b, *((T*)beta));
187 hmat::restoreVectorOrder(&matC, bDataCols->indices(), 1);
188 if (transA == 'N') {
189 hmat::restoreVectorOrder(matA, bDataRows->indices(), 1);
190 } else {
191 hmat::restoreVectorOrder(matA, bDataRows->indices(), 0);
192 }
193 delete matA;
194 } catch (const std::exception& e) {
195 fprintf(stderr, "%s\n", e.what());
196 return 1;
197 }
198 return 0;
199 }
200
201 template<typename T, template <typename> class E>
gemm(char trans_a,char trans_b,void * alpha,hmat_matrix_t * holder,hmat_matrix_t * holder_b,void * beta,hmat_matrix_t * holder_c)202 int gemm(char trans_a, char trans_b, void *alpha, hmat_matrix_t * holder,
203 hmat_matrix_t * holder_b, void *beta, hmat_matrix_t * holder_c) {
204 DECLARE_CONTEXT;
205 hmat::HMatInterface<T>* hmat_a = (hmat::HMatInterface<T>*)holder;
206 hmat::HMatInterface<T>* hmat_b = (hmat::HMatInterface<T>*)holder_b;
207 hmat::HMatInterface<T>* hmat_c = (hmat::HMatInterface<T>*)holder_c;
208 try {
209 hmat_c->gemm(trans_a, trans_b, *((T*)alpha), hmat_a, hmat_b, *((T*)beta));
210 } catch (const std::exception& e) {
211 fprintf(stderr, "%s\n", e.what());
212 return 1;
213 }
214 return 0;
215 }
216
217 template<typename T, template <typename> class E>
axpy(void * a,hmat_matrix_t * x,hmat_matrix_t * y)218 int axpy(void *a, hmat_matrix_t * x, hmat_matrix_t * y) {
219 DECLARE_CONTEXT;
220 DISABLE_THREADING_IN_BLOCK;
221 hmat::HMatInterface<T>* hmat_x = reinterpret_cast<hmat::HMatInterface<T>*>(x);
222 hmat::HMatInterface<T>* hmat_y = reinterpret_cast<hmat::HMatInterface<T>*>(y);
223 try {
224 hmat_y->engine().hmat->axpy(*((T*)a), hmat_x->engine().hmat);
225 } catch (const std::exception& e) {
226 fprintf(stderr, "%s\n", e.what());
227 return 1;
228 }
229 return 0;
230 }
231
232 template<typename T, template <typename> class E>
gemv(char trans_a,void * alpha,hmat_matrix_t * holder,void * vec_b,void * beta,void * vec_c,int nrhs)233 int gemv(char trans_a, void* alpha, hmat_matrix_t * holder, void* vec_b,
234 void* beta, void* vec_c, int nrhs) {
235 DECLARE_CONTEXT;
236 hmat::HMatInterface<T>* hmat = (hmat::HMatInterface<T>*)holder;
237 const hmat::ClusterData* bData = (trans_a == 'N' ? hmat->cols(): hmat->rows());
238 const hmat::ClusterData* cData = (trans_a == 'N' ? hmat->rows(): hmat->cols());
239 try {
240 hmat::ScalarArray<T> mb((T*) vec_b, bData->size(), nrhs);
241 hmat::ScalarArray<T> mc((T*) vec_c, cData->size(), nrhs);
242 hmat::reorderVector(&mb, bData->indices(), 0);
243 hmat::reorderVector(&mc, cData->indices(), 0);
244 hmat->gemv(trans_a, *((T*)alpha), mb, *((T*)beta), mc);
245 hmat::restoreVectorOrder(&mb, bData->indices(), 0);
246 hmat::restoreVectorOrder(&mc, cData->indices(), 0);
247 } catch (const std::exception& e) {
248 fprintf(stderr, "%s\n", e.what());
249 return 1;
250 }
251 return 0;
252 }
253
254 template<typename T, template <typename> class E>
gemm_scalar(char trans_a,void * alpha,hmat_matrix_t * holder,void * vec_b,void * beta,void * vec_c,int nrhs)255 int gemm_scalar( char trans_a, void* alpha, hmat_matrix_t * holder, void* vec_b,
256 void* beta, void* vec_c, int nrhs ) {
257 DECLARE_CONTEXT;
258 hmat::HMatInterface<T>* hmat = (hmat::HMatInterface<T>*)holder;
259 const hmat::ClusterData* bData = (trans_a == 'N' ? hmat->cols(): hmat->rows());
260 const hmat::ClusterData* cData = (trans_a == 'N' ? hmat->rows(): hmat->cols());
261 try {
262 hmat::ScalarArray<T> mb((T*) vec_b, bData->size(), nrhs);
263 hmat::ScalarArray<T> mc((T*) vec_c, cData->size(), nrhs);
264
265 hmat->gemm_scalar(trans_a, *((T*)alpha), mb, *((T*)beta), mc);
266 } catch (const std::exception& e) {
267 fprintf(stderr, "%s\n", e.what());
268 return 1;
269 }
270 return 0;
271 }
272
is_trans(char trans)273 inline bool is_trans(char trans) {
274 return trans == 'T' || trans == 'C';
275 }
276
is_conj(char trans)277 inline bool is_conj(char trans) {
278 return trans == 'J' || trans == 'C';
279 }
280
switch_flag_trans(char trans)281 inline char switch_flag_trans(char trans) {
282 switch (trans) {
283 case 'N': return 'T';
284 case 'T': return 'N';
285 case 'C': return 'J';
286 case 'J': return 'C';
287 default: HMAT_ASSERT(false);
288 }
289 }
290
switch_flag_conj(char trans)291 inline char switch_flag_conj(char trans) {
292 switch (trans) {
293 case 'N': return 'J';
294 case 'T': return 'C';
295 case 'C': return 'T';
296 case 'J': return 'N';
297 default: HMAT_ASSERT(false);
298 }
299 }
300
301 template<typename T, template <typename> class E>
gemm_dense(char trans_b,char trans_x,char side,void * alpha,hmat_matrix_t * holder,void * vec_x,void * beta,void * vec_y,int nrhs)302 int gemm_dense(char trans_b, char trans_x, char side, void* alpha, hmat_matrix_t* holder,
303 void* vec_x, void* beta, void* vec_y, int nrhs) {
304 char trans_y = 'N';
305 T alphaT = *((T*)alpha);
306 T betaT = *((T*)beta);
307 alpha = &alphaT;
308 beta = &betaT;
309 if (side == 'R') {
310 // Y <- alpha X * B + beta Y <=> Y^t <- alpha B^t X^t + beta Y^t or Y^H <- bar(alpha) B^H * X^H + bar(beta) Y^H
311 if (trans_b == 'C') {
312 trans_b = 'N';
313 trans_x = switch_flag_conj(switch_flag_trans(trans_x));
314 trans_y = 'C';
315 alphaT = hmat::conj(alphaT);
316 betaT = hmat::conj(betaT);
317 } else {
318 trans_b = switch_flag_trans(trans_b);
319 trans_x = switch_flag_trans(trans_x);
320 trans_y = 'T';
321 }
322 }
323
324 DECLARE_CONTEXT;
325 DISABLE_THREADING_IN_BLOCK;
326
327 // Now side='L': op(Y) <- alpha op(B) op(X) + beta op(Y)
328 const hmat::HMatInterface<T>* b = (hmat::HMatInterface<T>*)holder;
329 const hmat::IndexSet* bDataRows = !is_trans(trans_b) ? b->rows(): b->cols();
330 const hmat::IndexSet* bDataCols = !is_trans(trans_b) ? b->cols(): b->rows();
331 try {
332 hmat::ScalarArray<T>* mx = !is_trans(trans_x) ?
333 new hmat::ScalarArray<T>((T*) vec_x, bDataCols->size(), nrhs) :
334 new hmat::ScalarArray<T>((T*) vec_x, nrhs, bDataCols->size());
335 hmat::ScalarArray<T>* my = !is_trans(trans_y) ?
336 new hmat::ScalarArray<T>((T*) vec_y, bDataRows->size(), nrhs) :
337 new hmat::ScalarArray<T>((T*) vec_y, nrhs, bDataRows->size());
338
339 // Apply transformations on x and y
340 if (is_trans(trans_x))
341 mx->transpose();
342 if (is_conj(trans_x))
343 mx->conjugate();
344 if (is_trans(trans_y))
345 my->transpose();
346 if (is_conj(trans_y))
347 my->conjugate();
348
349 b->gemv(trans_b, *((T*)alpha), *mx, *((T*)beta), *my);
350
351 // Apply inverse transformations on x and y
352 if (is_trans(trans_x))
353 mx->transpose();
354 if (is_trans(trans_y))
355 my->transpose();
356 if (is_conj(trans_y))
357 my->conjugate();
358
359 delete mx;
360 delete my;
361 } catch (const std::exception& e) {
362 fprintf(stderr, "%s\n", e.what());
363 return 1;
364 }
365 return 0;
366 }
367
368 template<typename T, template <typename> class E>
trsm(char side,char uplo,char transa,char diag,int m,int n,void * alpha,hmat_matrix_t * A,int is_b_hmat,void * B)369 int trsm( char side, char uplo, char transa, char diag, int m, int n,
370 void *alpha, hmat_matrix_t *A, int is_b_hmat, void *B )
371 {
372 DECLARE_CONTEXT;
373 hmat::HMatInterface<T>* hmatA = (hmat::HMatInterface<T>*)A;
374
375 try {
376 if ( is_b_hmat ) {
377 hmat::HMatInterface<T>* hmatB = (hmat::HMatInterface<T>*)B;
378 hmatA->trsm( side, uplo, transa, diag, *((T*)alpha), hmatB );
379 }
380 else {
381 bool isleft = (side == 'l') || (side == 'L');
382 hmat::ScalarArray<T> mB( (T*)B, (isleft ? m : n), (isleft ? n : m ) );
383 hmatA->trsm( side, uplo, transa, diag, *((T*)alpha), mB );
384 }
385 } catch (const std::exception& e) {
386 fprintf(stderr, "%s\n", e.what());
387 return 1;
388 }
389 return 0;
390 }
391
392 template<typename T, template <typename> class E>
add_identity(hmat_matrix_t * holder,void * alpha)393 int add_identity(hmat_matrix_t* holder, void *alpha) {
394 DECLARE_CONTEXT;
395 hmat::HMatInterface<T>* hmat = (hmat::HMatInterface<T>*)holder;
396 try {
397 hmat->addIdentity(*((T*)alpha));
398 } catch (const std::exception& e) {
399 fprintf(stderr, "%s\n", e.what());
400 return 1;
401 }
402 return 0;
403 }
404
405 template<typename T, template <typename> class E>
init()406 int init() {
407 DECLARE_CONTEXT;
408 return E<T>::init();
409 }
410
411 template<typename T, template <typename> class E>
norm(hmat_matrix_t * holder)412 double norm(hmat_matrix_t* holder) {
413 DECLARE_CONTEXT;
414 return ((hmat::HMatInterface<T>*)holder)->norm();
415 }
416
417 template<typename T, template <typename> class E>
scale(void * alpha,hmat_matrix_t * holder)418 int scale(void *alpha, hmat_matrix_t* holder) {
419 DECLARE_CONTEXT;
420 try {
421 ((hmat::HMatInterface<T>*)holder)->scale(*((T*)alpha));
422 } catch (const std::exception& e) {
423 fprintf(stderr, "%s\n", e.what());
424 return 1;
425 }
426 return 0;
427 }
428
429 template<typename T, template <typename> class E>
truncate(hmat_matrix_t * holder)430 int truncate(hmat_matrix_t* holder) {
431 DECLARE_CONTEXT;
432 try {
433 ((hmat::HMatInterface<T>*)holder)->truncate();
434 } catch (const std::exception& e) {
435 fprintf(stderr, "%s\n", e.what());
436 return 1;
437 }
438 return 0;
439 }
440
441 template<typename T, template <typename> class E>
vector_reorder(void * vec_b,const hmat_cluster_tree_t * rows_ct,int rows,const hmat_cluster_tree_t * cols_ct,int cols)442 int vector_reorder(void* vec_b, const hmat_cluster_tree_t *rows_ct, int rows, const hmat_cluster_tree_t *cols_ct, int cols) {
443 DECLARE_CONTEXT;
444 try {
445 HMAT_ASSERT_MSG(rows_ct != NULL || rows != 0, "either row cluster tree or rows must be non null");
446 HMAT_ASSERT_MSG(cols_ct != NULL || cols != 0, "either col cluster tree or cols must be non null");
447 const hmat::ClusterTree *clusterTreeRows = reinterpret_cast<const hmat::ClusterTree*>(rows_ct);
448 const hmat::ClusterTree *clusterTreeCols = reinterpret_cast<const hmat::ClusterTree*>(cols_ct);
449 int nrows = clusterTreeRows == NULL ? rows : clusterTreeRows->data.size();
450 int ncols = clusterTreeCols == NULL ? cols : clusterTreeCols->data.size();
451 hmat::ScalarArray<T> mb((T*) vec_b, nrows, ncols);
452 if (clusterTreeRows) {
453 hmat::reorderVector(&mb, clusterTreeRows->data.indices(), 0);
454 }
455 if (clusterTreeCols) {
456 hmat::reorderVector(&mb, clusterTreeCols->data.indices(), 1);
457 }
458 } catch (const std::exception& e) {
459 fprintf(stderr, "%s\n", e.what());
460 return 1;
461 }
462 return 0;
463 }
464
465 template<typename T, template <typename> class E>
vector_restore(void * vec_b,const hmat_cluster_tree_t * rows_ct,int rows,const hmat_cluster_tree_t * cols_ct,int cols)466 int vector_restore(void* vec_b, const hmat_cluster_tree_t *rows_ct, int rows, const hmat_cluster_tree_t *cols_ct, int cols) {
467 DECLARE_CONTEXT;
468 try {
469 HMAT_ASSERT_MSG(rows_ct != NULL || rows != 0, "either row cluster tree or rows must be non null");
470 HMAT_ASSERT_MSG(cols_ct != NULL || cols != 0, "either col cluster tree or cols must be non null");
471 const hmat::ClusterTree *clusterTreeRows = reinterpret_cast<const hmat::ClusterTree*>(rows_ct);
472 const hmat::ClusterTree *clusterTreeCols = reinterpret_cast<const hmat::ClusterTree*>(cols_ct);
473 int nrows = clusterTreeRows == NULL ? rows : clusterTreeRows->data.size();
474 int ncols = clusterTreeCols == NULL ? cols : clusterTreeCols->data.size();
475 hmat::ScalarArray<T> mb((T*) vec_b, nrows, ncols);
476 if (clusterTreeRows) {
477 hmat::restoreVectorOrder(&mb, clusterTreeRows->data.indices(), 0);
478 }
479 if (clusterTreeCols) {
480 hmat::restoreVectorOrder(&mb, clusterTreeCols->data.indices(), 1);
481 }
482 } catch (const std::exception& e) {
483 fprintf(stderr, "%s\n", e.what());
484 return 1;
485 }
486 return 0;
487 }
488
489 template<typename T, template <typename> class E>
solve_mat(hmat_matrix_t * hmat,hmat_matrix_t * hmatB)490 int solve_mat(hmat_matrix_t* hmat, hmat_matrix_t* hmatB) {
491 DECLARE_CONTEXT;
492 try {
493 ((hmat::HMatInterface<T>*)hmat)->solve(*(hmat::HMatInterface<T>*)hmatB);
494 } catch (const std::exception& e) {
495 fprintf(stderr, "%s\n", e.what());
496 return 1;
497 }
498 return 0;
499 }
500
501 template<typename T, template <typename> class E>
solve_systems(hmat_matrix_t * holder,void * b,int nrhs)502 int solve_systems(hmat_matrix_t* holder, void* b, int nrhs) {
503 DECLARE_CONTEXT;
504 hmat::HMatInterface<T>* hmat = (hmat::HMatInterface<T>*)holder;
505 try {
506 hmat::ScalarArray<T> mb((T*) b, hmat->cols()->size(), nrhs);
507 hmat::reorderVector<T>(&mb, hmat->cols()->indices(), 0);
508 hmat->solve(mb);
509 hmat::restoreVectorOrder<T>(&mb, hmat->cols()->indices(), 0);
510 } catch (const std::exception& e) {
511 fprintf(stderr, "%s\n", e.what());
512 return 1;
513 }
514 return 0;
515 }
516
517 template<typename T, template <typename> class E>
solve_dense(hmat_matrix_t * holder,void * b,int nrhs)518 int solve_dense(hmat_matrix_t* holder, void* b, int nrhs) {
519 DECLARE_CONTEXT;
520 hmat::HMatInterface<T>* hmat = (hmat::HMatInterface<T>*)holder;
521 try {
522 hmat::ScalarArray<T> mb((T*) b, hmat->cols()->size(), nrhs);
523 hmat->solve(mb);
524 } catch (const std::exception& e) {
525 fprintf(stderr, "%s\n", e.what());
526 return 1;
527 }
528 return 0;
529 }
530
531 template<typename T, template <typename> class E>
transpose(hmat_matrix_t * hmat)532 int transpose(hmat_matrix_t* hmat) {
533 DECLARE_CONTEXT;
534 try {
535 ((hmat::HMatInterface<T>*)hmat)->transpose();
536 } catch (const std::exception& e) {
537 fprintf(stderr, "%s\n", e.what());
538 return 1;
539 }
540 return 0;
541 }
542
543 template<typename T, template <typename> class E>
hmat_get_info(hmat_matrix_t * holder,hmat_info_t * info)544 int hmat_get_info(hmat_matrix_t* holder, hmat_info_t* info) {
545 DECLARE_CONTEXT;
546 hmat::HMatInterface<T>* hmat = (hmat::HMatInterface<T>*) holder;
547 try {
548 hmat->info(*info);
549 } catch (const std::exception& e) {
550 fprintf(stderr, "%s\n", e.what());
551 return 1;
552 }
553 return 0;
554 }
555
556 template<typename T, template <typename> class E>
hmat_dump_info(hmat_matrix_t * holder,char * prefix)557 int hmat_dump_info(hmat_matrix_t* holder, char* prefix) {
558 DECLARE_CONTEXT;
559 hmat::HMatInterface<T>* hmat = (hmat::HMatInterface<T>*) holder;
560 try {
561 std::string filejson(prefix);
562 filejson += ".json";
563 hmat->dumpTreeToFile( filejson );
564 } catch (const std::exception& e) {
565 fprintf(stderr, "%s\n", e.what());
566 return 1;
567 }
568 return 0;
569 }
570
571 template<typename T, template <typename> class E>
get_cluster_trees(hmat_matrix_t * holder,const hmat_cluster_tree_t ** rows,const hmat_cluster_tree_t ** cols)572 int get_cluster_trees(hmat_matrix_t* holder, const hmat_cluster_tree_t ** rows, const hmat_cluster_tree_t ** cols) {
573 DECLARE_CONTEXT;
574 hmat::HMatInterface<T>* hmat = (hmat::HMatInterface<T>*) holder;
575 try {
576 if (rows)
577 *rows = static_cast<const hmat_cluster_tree_t*>(static_cast<const void*>(hmat->engine().hmat->rowsTree()));
578 if (cols)
579 *cols = static_cast<const hmat_cluster_tree_t*>(static_cast<const void*>(hmat->engine().hmat->colsTree()));
580 } catch (const std::exception& e) {
581 fprintf(stderr, "%s\n", e.what());
582 return 1;
583 }
584 return 0;
585 }
586
587 template<typename T, template <typename> class E>
set_cluster_trees(hmat_matrix_t * holder,const hmat_cluster_tree_t * rows,const hmat_cluster_tree_t * cols)588 int set_cluster_trees(hmat_matrix_t* holder, const hmat_cluster_tree_t * rows, const hmat_cluster_tree_t * cols) {
589 DECLARE_CONTEXT;
590 hmat::HMatInterface<T>* hmat = (hmat::HMatInterface<T>*) holder;
591 try {
592 hmat->engine().hmat->setClusterTrees(
593 reinterpret_cast<const hmat::ClusterTree*>(rows),
594 reinterpret_cast<const hmat::ClusterTree*>(cols));
595 } catch (const std::exception& e) {
596 fprintf(stderr, "%s\n", e.what());
597 return 1;
598 }
599 return 0;
600 }
601
602 template<typename T, template <typename> class E>
own_cluster_trees(hmat_matrix_t * holder,int owns_row,int owns_col)603 void own_cluster_trees(hmat_matrix_t* holder, int owns_row, int owns_col)
604 {
605 DECLARE_CONTEXT;
606 hmat::HMatInterface<T>* hmat = (hmat::HMatInterface<T>*) holder;
607 hmat->engine().hmat->ownClusterTrees(owns_row != 0, owns_col != 0);
608 }
609
610 template<typename T, template <typename> class E>
set_low_rank_epsilon(hmat_matrix_t * holder,double epsilon)611 void set_low_rank_epsilon(hmat_matrix_t* holder, double epsilon)
612 {
613 DECLARE_CONTEXT;
614 hmat::HMatInterface<T>* hmat = (hmat::HMatInterface<T>*) holder;
615 hmat->engine().hmat->lowRankEpsilon(epsilon);
616 }
617
618 template<typename T, template <typename> class E>
extract_diagonal(hmat_matrix_t * holder,void * diag,int size)619 int extract_diagonal(hmat_matrix_t* holder, void* diag, int size)
620 {
621 DECLARE_CONTEXT;
622 (void)size; //for API compatibility
623 hmat::HMatInterface<T>* hmat = (hmat::HMatInterface<T>*) holder;
624 try {
625 hmat->engine().hmat->extractDiagonal(static_cast<T*>(diag));
626 hmat::ScalarArray<T> permutedDiagonal(static_cast<T*>(diag), hmat->cols()->size(), 1);
627 hmat::restoreVectorOrder(&permutedDiagonal, hmat->cols()->indices(), 0);
628 } catch (const std::exception& e) {
629 fprintf(stderr, "%s\n", e.what());
630 return 1;
631 }
632 return 0;
633 }
634
635 template<typename T, template <typename> class E>
solve_lower_triangular(hmat_matrix_t * holder,int transpose,void * b,int nrhs)636 int solve_lower_triangular(hmat_matrix_t* holder, int transpose, void* b, int nrhs)
637 {
638 DECLARE_CONTEXT;
639 hmat::HMatInterface<T>* hmat = (hmat::HMatInterface<T>*)holder;
640 hmat::ScalarArray<T> mb((T*) b, hmat->cols()->size(), nrhs);
641 try {
642 if (transpose)
643 hmat::reorderVector<T>(&mb, hmat->rows()->indices(), 0);
644 else
645 hmat::reorderVector<T>(&mb, hmat->cols()->indices(), 0);
646 hmat->solveLower(mb, transpose);
647 if (transpose)
648 hmat::restoreVectorOrder<T>(&mb, hmat->rows()->indices(), 0);
649 else
650 hmat::restoreVectorOrder<T>(&mb, hmat->cols()->indices(), 0);
651 } catch (const std::exception& e) {
652 fprintf(stderr, "%s\n", e.what());
653 return 1;
654 }
655 return 0;
656 }
657
658 template<typename T, template <typename> class E>
solve_lower_triangular_dense(hmat_matrix_t * holder,int transpose,void * b,int nrhs)659 int solve_lower_triangular_dense(hmat_matrix_t* holder, int transpose, void* b, int nrhs)
660 {
661 DECLARE_CONTEXT;
662 hmat::HMatInterface<T>* hmat = (hmat::HMatInterface<T>*)holder;
663 hmat::ScalarArray<T> mb((T*) b, hmat->cols()->size(), nrhs);
664 try {
665 hmat->solveLower(mb, transpose);
666 } catch (const std::exception& e) {
667 fprintf(stderr, "%s\n", e.what());
668 return 1;
669 }
670 return 0;
671 }
672
673 template <typename T, template <typename> class E>
get_child(hmat_matrix_t * hmatrix,int i,int j)674 hmat_matrix_t *get_child( hmat_matrix_t *hmatrix, int i, int j ) {
675 DECLARE_CONTEXT;
676 hmat::HMatInterface<T> *hmat = (hmat::HMatInterface<T> *)hmatrix;
677
678 hmat::HMatrix<T> *m = hmat->get( i, j );
679 hmat::IEngine<T>* engine = new E<T>();
680 hmat::HMatInterface<T> *r = new hmat::HMatInterface<T>( engine, m, hmat->factorization() );
681 return (hmat_matrix_t*) r;
682 }
683
684 template <typename T, template <typename> class E>
get_block(struct hmat_get_values_context_t * ctx)685 int get_block(struct hmat_get_values_context_t *ctx) {
686 DECLARE_CONTEXT;
687 DISABLE_THREADING_IN_BLOCK;
688 hmat::HMatInterface<T> *hmat = (hmat::HMatInterface<T> *)ctx->matrix;
689 try {
690 hmat::IndexSet rows(ctx->row_offset, ctx->row_size);
691 hmat::IndexSet cols(ctx->col_offset, ctx->col_size);
692 typename E<T>::UncompressedBlock view;
693 const E<T>& engine = dynamic_cast<const E<T>&>(hmat->engine());
694 view.uncompress(engine.getHandle(), rows, cols, (T*)ctx->values);
695 hmat::HMatrix<T>* compressed = hmat->engine().hmat;
696 // Symmetrize values when requesting a full symmetric matrix
697 if (compressed->isLower &&
698 ctx->row_offset == 0 && ctx->col_offset == 0 &&
699 ctx->row_size == compressed->rows()->size() && ctx->col_size == compressed->cols()->size())
700 {
701 T* ptr = static_cast<T*>(ctx->values);
702 for (int i = 0; i < ctx->row_size; i++) {
703 for (int j = i + 1; j < ctx->col_size; j++) {
704 ptr[j*ctx->row_size + i] = ptr[i*ctx->row_size + j];
705 }
706 }
707 }
708 if (ctx->renumber_rows)
709 view.renumberRows();
710 ctx->col_indices = view.colsNumbering();
711 ctx->row_indices= view.rowsNumbering();
712 } catch (const std::exception& e) {
713 fprintf(stderr, "%s\n", e.what());
714 return 1;
715 }
716 return 0;
717 }
718
719 template <typename T, template <typename> class E>
get_values(struct hmat_get_values_context_t * ctx)720 int get_values(struct hmat_get_values_context_t *ctx) {
721 DECLARE_CONTEXT;
722 // No need to call DISABLE_THREADING_IN_BLOCK here, there is no BLAS call
723 hmat::HMatInterface<T> *hmat = (hmat::HMatInterface<T> *)ctx->matrix;
724 try {
725 typename E<T>::UncompressedValues view;
726 const E<T>& engine = reinterpret_cast<const E<T>&>(hmat->engine());
727 view.uncompress(engine.getHandle(),
728 ctx->row_indices, ctx->row_size,
729 ctx->col_indices, ctx->col_size,
730 (T*)ctx->values);
731 } catch (const std::exception& e) {
732 fprintf(stderr, "%s\n", e.what());
733 return 1;
734 }
735 return 0;
736 }
737
738 template <typename T, template <typename> class E>
walk(hmat_matrix_t * holder,hmat_procedure_t * proc)739 int walk(hmat_matrix_t* holder, hmat_procedure_t* proc) {
740 DECLARE_CONTEXT;
741 hmat::HMatInterface<T> *hmat = (hmat::HMatInterface<T> *) holder;
742 try {
743 hmat::TreeProcedure<hmat::HMatrix<T> > *functor = (hmat::TreeProcedure<hmat::HMatrix<T> > *) proc->internal;
744 hmat->walk(functor);
745 } catch (const std::exception& e) {
746 fprintf(stderr, "%s\n", e.what());
747 return 1;
748 }
749 return 0;
750 }
751
752 template <typename T, template <typename> class E>
apply_on_leaf(hmat_matrix_t * holder,const hmat_leaf_procedure_t * proc)753 int apply_on_leaf(hmat_matrix_t* holder, const hmat_leaf_procedure_t* proc) {
754 DECLARE_CONTEXT;
755 hmat::HMatInterface<T> *hmat = (hmat::HMatInterface<T> *) holder;
756 try {
757 const hmat::LeafProcedure<hmat::HMatrix<T> > *functor = static_cast<const hmat::LeafProcedure<hmat::HMatrix<T> > *>(proc->internal);
758 hmat->apply_on_leaf(*functor);
759 } catch (const std::exception& e) {
760 fprintf(stderr, "%s\n", e.what());
761 return 1;
762 }
763 return 0;
764 }
765
766 template <typename T, template <typename> class E>
read_struct(hmat_iostream readfunc,void * user_data)767 hmat_matrix_t * read_struct(hmat_iostream readfunc, void * user_data) {
768 hmat::MatrixStructUnmarshaller<T> unmarshaller(&hmat::HMatSettings::getInstance(), readfunc, user_data);
769 hmat::HMatrix<T> * m = unmarshaller.read();
770 E<T>* engine = new E<T>();
771 hmat::HMatInterface<T> * r = new hmat::HMatInterface<T>(engine, m, unmarshaller.factorization());
772 return (hmat_matrix_t*) r;
773 }
774
775 template <typename T, template <typename> class E>
read_data(hmat_matrix_t * matrix,hmat_iostream readfunc,void * user_data)776 void read_data(hmat_matrix_t * matrix, hmat_iostream readfunc, void * user_data) {
777 hmat::HMatInterface<T> * hmi = (hmat::HMatInterface<T> *) matrix;
778 hmat::MatrixDataUnmarshaller<T>(readfunc, user_data).read(hmi->engine().hmat);
779 }
780
781 template <typename T, template <typename> class E>
write_struct(hmat_matrix_t * matrix,hmat_iostream writefunc,void * user_data)782 void write_struct(hmat_matrix_t* matrix, hmat_iostream writefunc, void * user_data) {
783 hmat::HMatInterface<T> * hmi = (hmat::HMatInterface<T> *) matrix;
784 hmat::MatrixStructMarshaller<T>(writefunc, user_data).write(
785 hmi->engine().hmat, hmi->factorization());
786 }
787
788 template <typename T, template <typename> class E>
write_data(hmat_matrix_t * matrix,hmat_iostream writefunc,void * user_data)789 void write_data(hmat_matrix_t* matrix, hmat_iostream writefunc, void * user_data) {
790 hmat::HMatInterface<T> * hmi = (hmat::HMatInterface<T> *) matrix;
791 hmat::MatrixDataMarshaller<T>(writefunc, user_data).write(hmi->engine().hmat);
792 }
793
794 template <typename T>
set_progressbar(hmat_matrix_t * matrix,hmat_progress_t * progress)795 void set_progressbar(hmat_matrix_t * matrix, hmat_progress_t * progress) {
796 reinterpret_cast<hmat::HMatInterface<T> *>(matrix)->progress(progress);
797 }
798
799 } // end anonymous namespace
800
801 namespace hmat {
802
803 template<typename T, template <typename> class E>
createCInterface(hmat_interface_t * i)804 static void createCInterface(hmat_interface_t * i)
805 {
806 DECLARE_CONTEXT;
807 i->copy = copy<T, E>;
808 i->copy_struct = copy_struct<T, E>;
809 i->create_empty_hmatrix_admissibility = create_empty_hmatrix_admissibility<T, E>;
810 i->destroy = destroy<T, E>;
811 i->get_child = get_child<T, E>;
812 i->destroy_child = destroy_child<T, E>;
813 i->inverse = inverse<T, E>;
814 i->finalize = finalize<T, E>;
815 i->full_gemm = full_gemm<T, E>;
816 i->gemm = gemm<T, E>;
817 i->gemv = gemv<T, E>;
818 i->gemm_scalar = gemm_scalar<T, E>;
819 i->add_identity = add_identity<T, E>;
820 i->init = init<T, E>;
821 i->norm = norm<T, E>;
822 i->scale = scale<T, E>;
823 i->solve_mat = solve_mat<T, E>;
824 i->solve_systems = solve_systems<T, E>;
825 i->solve_dense = solve_dense<T, E>;
826 i->transpose = transpose<T, E>;
827 i->internal = NULL;
828 i->get_info = hmat_get_info<T, E>;
829 i->dump_info = hmat_dump_info<T, E>;
830 i->get_cluster_trees = get_cluster_trees<T, E>;
831 i->set_cluster_trees = set_cluster_trees<T, E>;
832 i->own_cluster_trees = own_cluster_trees<T, E>;
833 i->set_low_rank_epsilon = set_low_rank_epsilon<T, E>;
834 i->extract_diagonal = extract_diagonal<T, E>;
835 i->solve_lower_triangular = solve_lower_triangular<T, E>;
836 i->solve_lower_triangular_dense = solve_lower_triangular_dense<T, E>;
837 i->assemble_generic = assemble_generic<T, E>;
838 i->factorize_generic = factorize_generic<T, E>;
839 i->get_values = get_values<T, E>;
840 i->get_block = get_block<T, E>;
841 i->walk = walk<T, E>;
842 i->read_struct = read_struct<T, E>;
843 i->write_struct = write_struct<T, E>;
844 i->write_data = write_data<T, E>;
845 i->read_data = read_data<T, E>;
846 i->apply_on_leaf = apply_on_leaf<T, E>;
847 i->axpy = axpy<T, E>;
848 i->trsm = trsm<T, E>;
849 i->truncate = truncate<T, E>;
850 i->set_progressbar = set_progressbar<T>;
851 i->gemm_dense = gemm_dense<T, E>;
852 i->vector_reorder = vector_reorder<T, E>;
853 i->vector_restore = vector_restore<T, E>;
854 }
855
856 } // end namespace hmat
857
858 #endif // _C_WRAPPING_HPP
859