1 /*
2   HMat-OSS (HMatrix library, open source software)
3 
4   Copyright (C) 2014-2015 Airbus Group SAS
5 
6   This program is free software; you can redistribute it and/or
7   modify it under the terms of the GNU General Public License
8   as published by the Free Software Foundation; either version 2
9   of the License, or (at your option) any later version.
10 
11   This program is distributed in the hope that it will be useful,
12   but WITHOUT ANY WARRANTY; without even the implied warranty of
13   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14   GNU General Public License for more details.
15 
16   You should have received a copy of the GNU General Public License
17   along with this program; if not, write to the Free Software
18   Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
19 
20   http://github.com/jeromerobert/hmat-oss
21 */
22 
23 #include "hmat_cpp_interface.hpp"
24 #include "h_matrix.hpp"
25 #include "admissibility.hpp"
26 #include "cluster_tree.hpp"
27 #include "common/context.hpp"
28 #include "disable_threading.hpp"
29 #include "json.hpp"
30 #include "iengine.hpp"
31 
32 #include <cstring>
33 #include <fstream>
34 
35 namespace hmat {
36 
37 // HMatInterface
38 template<typename T>
HMatInterface(IEngine<T> * engine,const ClusterTree * _rows,const ClusterTree * _cols,SymmetryFlag sym,AdmissibilityCondition * admissibilityCondition)39 HMatInterface<T>::HMatInterface(IEngine<T>* engine, const ClusterTree* _rows, const ClusterTree* _cols,
40                                 SymmetryFlag sym, AdmissibilityCondition * admissibilityCondition) :
41   engine_(engine),factorizationType(Factorization::NONE)
42 {
43   DECLARE_CONTEXT;
44   admissibilityCondition->prepare(*_rows, *_cols);
45   engine_->hmat = new HMatrix<T>(_rows, _cols, &HMatSettings::getInstance(), 0, sym, admissibilityCondition);
46   admissibilityCondition->clean(*_rows, *_cols);
47 }
48 
49 template<typename T>
~HMatInterface()50 HMatInterface<T>::~HMatInterface() {
51   engine_->destroy();
52   delete engine_->hmat;
53   delete engine_;
54 }
55 
56 template<typename T>
HMatInterface(IEngine<T> * engine,HMatrix<T> * h,Factorization factorization)57 HMatInterface<T>::HMatInterface(IEngine<T>* engine, HMatrix<T>* h, Factorization factorization):
58   engine_(engine)
59 {
60   engine_->setHMatrix(h);
61       factorizationType = factorization;
62 }
63 
64 template<typename T>
assemble(Assembly<T> & f,SymmetryFlag sym,bool,hmat_progress_t * progress,bool ownAssembly)65 void HMatInterface<T>::assemble(Assembly<T>& f, SymmetryFlag sym, bool,
66                                    hmat_progress_t * progress, bool ownAssembly) {
67   DISABLE_THREADING_IN_BLOCK;
68   DECLARE_CONTEXT;
69   engine_->progress(progress);
70   engine_->assembly(f, sym, ownAssembly);
71 }
72 
73 template<typename T>
factorize(Factorization t,hmat_progress_t * progress)74 void HMatInterface<T>::factorize(Factorization t, hmat_progress_t * progress) {
75   DISABLE_THREADING_IN_BLOCK;
76   DECLARE_CONTEXT;
77   engine_->progress(progress);
78   if(progress != NULL)
79     progress->max = engine_->hmat->rows()->size();
80   engine_->factorization(t);
81   factorizationType = t;
82   engine_->hmat->checkStructure();
83 }
84 
85 template<typename T>
inverse(hmat_progress_t * progress)86 void HMatInterface<T>::inverse(hmat_progress_t * progress) {
87   DISABLE_THREADING_IN_BLOCK;
88   DECLARE_CONTEXT;
89   engine_->progress(progress);
90   engine_->inverse();
91 }
92 
93 template<typename T>
gemv(char trans,T alpha,ScalarArray<T> & x,T beta,ScalarArray<T> & y) const94 void HMatInterface<T>::gemv(char trans, T alpha, ScalarArray<T>& x, T beta,
95                             ScalarArray<T>& y) const {
96   DISABLE_THREADING_IN_BLOCK;
97   DECLARE_CONTEXT;
98   engine_->gemv(trans, alpha, x, beta, y);
99 }
100 
101 template<typename T>
gemm_scalar(char trans,T alpha,ScalarArray<T> & x,T beta,ScalarArray<T> & y) const102 void HMatInterface<T>::gemm_scalar(char trans, T alpha, ScalarArray<T>& x,
103 				      T beta, ScalarArray<T>& y) const {
104   DISABLE_THREADING_IN_BLOCK;
105   DECLARE_CONTEXT;
106   engine_->gemv( trans, alpha, x, beta, y );
107 }
108 
109 template<typename T>
gemm(char transA,char transB,T alpha,const HMatInterface<T> * a,const HMatInterface<T> * b,T beta)110 void HMatInterface<T>::gemm(char transA, char transB, T alpha,
111                             const HMatInterface<T>* a,
112                             const HMatInterface<T>* b, T beta) {
113     DISABLE_THREADING_IN_BLOCK;
114     DECLARE_CONTEXT;
115     engine_->gemm(transA, transB, alpha, *a->engine_, *b->engine_, beta);
116     engine_->hmat->checkStructure();
117 }
118 
119 template<typename T>
trsm(char side,char uplo,char transa,char diag,T alpha,HMatInterface<T> * B)120 void HMatInterface<T>::trsm( char side, char uplo, char transa, char diag,
121 				T alpha, HMatInterface<T>* B ) {
122     DISABLE_THREADING_IN_BLOCK;
123     DECLARE_CONTEXT;
124     engine_->trsm( side, uplo, transa, diag, alpha, *B->engine_ );
125 }
126 
127 template<typename T>
trsm(char side,char uplo,char transa,char diag,T alpha,ScalarArray<T> & B)128 void HMatInterface<T>::trsm( char side, char uplo, char transa, char diag,
129 				T alpha, ScalarArray<T>& B ) {
130     DISABLE_THREADING_IN_BLOCK;
131     DECLARE_CONTEXT;
132     engine_->trsm( side, uplo, transa, diag, alpha, B );
133 }
134 
135 template<typename T>
gemm(ScalarArray<T> & c,char transA,char transB,T alpha,ScalarArray<T> & a,const HMatInterface<T> & b,T beta)136 void HMatInterface<T>::gemm(ScalarArray<T>& c, char transA, char transB, T alpha,
137                             ScalarArray<T>& a, const HMatInterface<T>& b,
138                             T beta) {
139   DECLARE_CONTEXT;
140   // C <- AB + C  <=>  C^t <- B^t A^t + C^t
141   // On fait les operations dans ce sens pour etre dans le bon sens
142   // pour la memoire, et pour reordonner correctement les "vecteurs" A
143   // et C.
144   if (transA == 'N') {
145     a.transpose();
146   }
147   if ((transA == 'C') != (transB == 'C')) {
148     a.conjugate();
149   }
150   c.transpose();
151   if (transB == 'N') {
152     b.gemv('T', alpha, a, beta, c);
153   } else if (transB == 'T') {
154     b.gemv('N', alpha, a, beta, c);
155   } else {
156     c.conjugate();
157     T alphaC = hmat::conj(alpha);
158     T betaC = hmat::conj(beta);
159     b.gemv('N', alphaC, a, betaC, c);
160     c.conjugate();
161   }
162   c.transpose();
163   if (transA == 'N') {
164     a.transpose();
165   }
166   if ((transA == 'C') != (transB == 'C')) {
167     a.conjugate();
168   }
169 }
170 
171 
172 template<typename T>
solve(ScalarArray<T> & b) const173 void HMatInterface<T>::solve(ScalarArray<T>& b) const {
174   DISABLE_THREADING_IN_BLOCK;
175   DECLARE_CONTEXT;
176   engine_->solve(b, factorizationType);
177 }
178 
179 template<typename T>
solve(HMatInterface<T> & b) const180 void HMatInterface<T>::solve(HMatInterface<T>& b) const {
181   DISABLE_THREADING_IN_BLOCK;
182   DECLARE_CONTEXT;
183   engine_->solve(*b.engine_, factorizationType);
184 }
185 
186 template<typename T>
solveLower(ScalarArray<T> & b,bool transpose) const187 void HMatInterface<T>::solveLower(ScalarArray<T>& b, bool transpose) const {
188   DISABLE_THREADING_IN_BLOCK;
189   DECLARE_CONTEXT;
190   engine_->solveLower(b, factorizationType, transpose);
191 }
192 
193 template<typename T>
copy(bool structOnly) const194 HMatInterface<T>* HMatInterface<T>::copy(bool structOnly) const {
195   DECLARE_CONTEXT;
196   HMatInterface<T>* result = new HMatInterface<T>(engine_->clone(), NULL);
197   engine_->copy(*(result->engine_), structOnly);
198   assert(result->engine_->hmat);
199   result->engine_->hmat->checkStructure();
200   return result;
201 }
202 
203 template<typename T>
transpose()204 void HMatInterface<T>::transpose() {
205   DECLARE_CONTEXT;
206   engine_->transpose();
207   engine_->hmat->checkStructure();
208 }
209 
210 
211 template<typename T>
norm() const212 double HMatInterface<T>::norm() const {
213   DISABLE_THREADING_IN_BLOCK;
214   DECLARE_CONTEXT;
215   return engine_->hmat->norm();
216 }
217 
218 template<typename T>
scale(T alpha)219 void HMatInterface<T>::scale(T alpha) {
220   DISABLE_THREADING_IN_BLOCK;
221   DECLARE_CONTEXT;
222   engine_->scale(alpha);
223 }
224 
225 template<typename T>
truncate()226 void HMatInterface<T>::truncate() {
227   DISABLE_THREADING_IN_BLOCK;
228   DECLARE_CONTEXT;
229   engine_->hmat->truncate();
230 }
231 
232 template<typename T>
addIdentity(T alpha)233 void HMatInterface<T>::addIdentity(T alpha) {
234   DECLARE_CONTEXT;
235   engine_->addIdentity(alpha);
236 }
237 
238 template<typename T>
addRand(double epsilon)239 void HMatInterface<T>::addRand(double epsilon) {
240   DECLARE_CONTEXT;
241   engine_->addRand(epsilon);
242 }
243 
244 template<typename T>
info(hmat_info_t & result) const245 void HMatInterface<T>::info(hmat_info_t & result) const {
246   DECLARE_CONTEXT;
247     memset(&result, 0, sizeof(hmat_info_t));
248     engine_->info(result);
249 }
250 
251 template<typename T>
dumpTreeToFile(const std::string & filename) const252 void HMatInterface<T>::dumpTreeToFile(const std::string& filename) const {
253   DECLARE_CONTEXT;
254   std::ofstream out(filename.c_str());
255   HMatrixJSONDumper<T>(engine_->hmat, out).dump();
256 }
257 
258 template<typename T>
nodesCount() const259 int HMatInterface<T>::nodesCount() const {
260   DISABLE_THREADING_IN_BLOCK;
261   DECLARE_CONTEXT;
262   return engine_->hmat->nodesCount();
263 }
264 
265 template<typename T>
walk(TreeProcedure<HMatrix<T>> * proc)266 void HMatInterface<T>::walk(TreeProcedure<HMatrix<T> > *proc){
267   DISABLE_THREADING_IN_BLOCK;
268   DECLARE_CONTEXT;
269   return engine_->hmat->walk(proc);
270 }
271 
272 template<typename T>
apply_on_leaf(const LeafProcedure<HMatrix<T>> & proc)273 void HMatInterface<T>::apply_on_leaf(const LeafProcedure<HMatrix<T> >& proc){
274   DISABLE_THREADING_IN_BLOCK;
275   DECLARE_CONTEXT;
276   engine_->applyOnLeaf(proc);
277 }
278 
279 template<typename T>
get(int i,int j) const280 HMatrix<T>* HMatInterface<T>::get( int i, int j ) const {
281     DISABLE_THREADING_IN_BLOCK;
282     DECLARE_CONTEXT;
283     return engine_->hmat->get(i, j);
284 }
285 
286 // Explicit template instantiation
287 template class HMatInterface<S_t>;
288 template class HMatInterface<D_t>;
289 template class HMatInterface<C_t>;
290 template class HMatInterface<Z_t>;
291 
292 
293 } // end namespace hmat
294 
295 
296