1 /* =========================================================================
2    Copyright (c) 2010-2014, Institute for Microelectronics,
3                             Institute for Analysis and Scientific Computing,
4                             TU Wien.
5    Portions of this software are copyright by UChicago Argonne, LLC.
6 
7                             -----------------
8                   ViennaCL - The Vienna Computing Library
9                             -----------------
10 
11    Project Head:    Karl Rupp                   rupp@iue.tuwien.ac.at
12 
13    (A list of authors and contributors can be found in the PDF manual)
14 
15    License:         MIT (X11), see file LICENSE in the base directory
16 ============================================================================= */
17 
18 // include necessary system headers
19 #include <iostream>
20 
21 #include "viennacl.hpp"
22 #include "viennacl_private.hpp"
23 
24 #include "init_matrix.hpp"
25 
26 //include basic scalar and vector types of ViennaCL
27 #include "viennacl/scalar.hpp"
28 #include "viennacl/vector.hpp"
29 #include "viennacl/matrix.hpp"
30 #include "viennacl/linalg/direct_solve.hpp"
31 #include "viennacl/linalg/prod.hpp"
32 
// xGEMM: general matrix-matrix product
34 
ViennaCLgemm(ViennaCLHostScalar alpha,ViennaCLMatrix A,ViennaCLMatrix B,ViennaCLHostScalar beta,ViennaCLMatrix C)35 VIENNACL_EXPORTED_FUNCTION ViennaCLStatus ViennaCLgemm(ViennaCLHostScalar alpha, ViennaCLMatrix A, ViennaCLMatrix B, ViennaCLHostScalar beta, ViennaCLMatrix C)
36 {
37   viennacl::backend::mem_handle A_handle;
38   viennacl::backend::mem_handle B_handle;
39   viennacl::backend::mem_handle C_handle;
40 
41   if (init_matrix(A_handle, A) != ViennaCLSuccess)
42     return ViennaCLGenericFailure;
43 
44   if (init_matrix(B_handle, B) != ViennaCLSuccess)
45     return ViennaCLGenericFailure;
46 
47   if (init_matrix(C_handle, C) != ViennaCLSuccess)
48     return ViennaCLGenericFailure;
49 
50   switch (A->precision)
51   {
52     case ViennaCLFloat:
53     {
54       typedef viennacl::matrix_base<float>::size_type           size_type;
55       typedef viennacl::matrix_base<float>::size_type           difference_type;
56 
57       viennacl::matrix_base<float> mat_A(A_handle,
58                                          size_type(A->size1), size_type(A->start1), difference_type(A->stride1), size_type(A->internal_size1),
59                                          size_type(A->size2), size_type(A->start2), difference_type(A->stride2), size_type(A->internal_size2), A->order == ViennaCLRowMajor);
60       viennacl::matrix_base<float> mat_B(B_handle,
61                                          size_type(B->size1), size_type(B->start1), difference_type(B->stride1), size_type(B->internal_size1),
62                                          size_type(B->size2), size_type(B->start2), difference_type(B->stride2), size_type(B->internal_size2), B->order == ViennaCLRowMajor);
63       viennacl::matrix_base<float> mat_C(C_handle,
64                                          size_type(C->size1), size_type(C->start1), difference_type(C->stride1), size_type(C->internal_size1),
65                                          size_type(C->size2), size_type(C->start2), difference_type(C->stride2), size_type(C->internal_size2), C->order == ViennaCLRowMajor);
66 
67       if (A->trans == ViennaCLTrans && B->trans == ViennaCLTrans)
68         viennacl::linalg::prod_impl(viennacl::trans(mat_A), viennacl::trans(mat_B), mat_C, alpha->value_float, beta->value_float);
69       else if (A->trans == ViennaCLTrans && B->trans == ViennaCLNoTrans)
70         viennacl::linalg::prod_impl(viennacl::trans(mat_A), mat_B, mat_C, alpha->value_float, beta->value_float);
71       else if (A->trans == ViennaCLNoTrans && B->trans == ViennaCLTrans)
72         viennacl::linalg::prod_impl(mat_A, viennacl::trans(mat_B), mat_C, alpha->value_float, beta->value_float);
73       else if (A->trans == ViennaCLNoTrans && B->trans == ViennaCLNoTrans)
74         viennacl::linalg::prod_impl(mat_A, mat_B, mat_C, alpha->value_float, beta->value_float);
75       else
76         return ViennaCLGenericFailure;
77 
78       return ViennaCLSuccess;
79     }
80 
81     case ViennaCLDouble:
82     {
83       typedef viennacl::matrix_base<double>::size_type           size_type;
84       typedef viennacl::matrix_base<double>::size_type           difference_type;
85 
86       viennacl::matrix_base<double> mat_A(A_handle,
87                                           size_type(A->size1), size_type(A->start1), difference_type(A->stride1), size_type(A->internal_size1),
88                                           size_type(A->size2), size_type(A->start2), difference_type(A->stride2), size_type(A->internal_size2), A->order == ViennaCLRowMajor);
89       viennacl::matrix_base<double> mat_B(B_handle,
90                                           size_type(B->size1), size_type(B->start1), difference_type(B->stride1), size_type(B->internal_size1),
91                                           size_type(B->size2), size_type(B->start2), difference_type(B->stride2), size_type(B->internal_size2), B->order == ViennaCLRowMajor);
92       viennacl::matrix_base<double> mat_C(C_handle,
93                                           size_type(C->size1), size_type(C->start1), difference_type(C->stride1), size_type(C->internal_size1),
94                                           size_type(C->size2), size_type(C->start2), difference_type(C->stride2), size_type(C->internal_size2), C->order == ViennaCLRowMajor);
95 
96       if (A->trans == ViennaCLTrans && B->trans == ViennaCLTrans)
97         viennacl::linalg::prod_impl(viennacl::trans(mat_A), viennacl::trans(mat_B), mat_C, alpha->value_double, beta->value_double);
98       else if (A->trans == ViennaCLTrans && B->trans == ViennaCLNoTrans)
99         viennacl::linalg::prod_impl(viennacl::trans(mat_A), mat_B, mat_C, alpha->value_double, beta->value_double);
100       else if (A->trans == ViennaCLNoTrans && B->trans == ViennaCLTrans)
101         viennacl::linalg::prod_impl(mat_A, viennacl::trans(mat_B), mat_C, alpha->value_double, beta->value_double);
102       else if (A->trans == ViennaCLNoTrans && B->trans == ViennaCLNoTrans)
103         viennacl::linalg::prod_impl(mat_A, mat_B, mat_C, alpha->value_double, beta->value_double);
104       else
105         return ViennaCLGenericFailure;
106 
107       return ViennaCLSuccess;
108     }
109 
110     default:
111       return ViennaCLGenericFailure;
112   }
113 }
114 
115 
// xTRSM: triangular solve with multiple right-hand sides
117 
ViennaCLtrsm(ViennaCLMatrix A,ViennaCLUplo uplo,ViennaCLDiag diag,ViennaCLMatrix B)118 VIENNACL_EXPORTED_FUNCTION ViennaCLStatus ViennaCLtrsm(ViennaCLMatrix A, ViennaCLUplo uplo, ViennaCLDiag diag, ViennaCLMatrix B)
119 {
120   viennacl::backend::mem_handle A_handle;
121   viennacl::backend::mem_handle B_handle;
122 
123   if (init_matrix(A_handle, A) != ViennaCLSuccess)
124     return ViennaCLGenericFailure;
125 
126   if (init_matrix(B_handle, B) != ViennaCLSuccess)
127     return ViennaCLGenericFailure;
128 
129   switch (A->precision)
130   {
131     case ViennaCLFloat:
132     {
133       typedef viennacl::matrix_base<float>::size_type           size_type;
134       typedef viennacl::matrix_base<float>::size_type           difference_type;
135 
136       viennacl::matrix_base<float> mat_A(A_handle,
137                                          size_type(A->size1), size_type(A->start1), difference_type(A->stride1), size_type(A->internal_size1),
138                                          size_type(A->size2), size_type(A->start2), difference_type(A->stride2), size_type(A->internal_size2), A->order == ViennaCLRowMajor);
139       viennacl::matrix_base<float> mat_B(B_handle,
140                                          size_type(B->size1), size_type(B->start1), difference_type(B->stride1), size_type(B->internal_size1),
141                                          size_type(B->size2), size_type(B->start2), difference_type(B->stride2), size_type(B->internal_size2), B->order == ViennaCLRowMajor);
142 
143       if (A->trans == ViennaCLTrans && B->trans == ViennaCLTrans)
144       {
145         if (uplo == ViennaCLUpper && diag == ViennaCLNonUnit)
146           viennacl::linalg::inplace_solve(viennacl::trans(mat_A), viennacl::trans(mat_B), viennacl::linalg::upper_tag());
147         else if (uplo == ViennaCLUpper && diag == ViennaCLUnit)
148           viennacl::linalg::inplace_solve(viennacl::trans(mat_A), viennacl::trans(mat_B), viennacl::linalg::unit_upper_tag());
149         else if (uplo == ViennaCLLower && diag == ViennaCLNonUnit)
150           viennacl::linalg::inplace_solve(viennacl::trans(mat_A), viennacl::trans(mat_B), viennacl::linalg::lower_tag());
151         else if (uplo == ViennaCLLower && diag == ViennaCLUnit)
152           viennacl::linalg::inplace_solve(viennacl::trans(mat_A), viennacl::trans(mat_B), viennacl::linalg::unit_lower_tag());
153         else
154           return ViennaCLGenericFailure;
155       }
156       else if (A->trans == ViennaCLTrans && B->trans == ViennaCLNoTrans)
157       {
158         if (uplo == ViennaCLUpper && diag == ViennaCLNonUnit)
159           viennacl::linalg::inplace_solve(viennacl::trans(mat_A), mat_B, viennacl::linalg::upper_tag());
160         else if (uplo == ViennaCLUpper && diag == ViennaCLUnit)
161           viennacl::linalg::inplace_solve(viennacl::trans(mat_A), mat_B, viennacl::linalg::unit_upper_tag());
162         else if (uplo == ViennaCLLower && diag == ViennaCLNonUnit)
163           viennacl::linalg::inplace_solve(viennacl::trans(mat_A), mat_B, viennacl::linalg::lower_tag());
164         else if (uplo == ViennaCLLower && diag == ViennaCLUnit)
165           viennacl::linalg::inplace_solve(viennacl::trans(mat_A), mat_B, viennacl::linalg::unit_lower_tag());
166         else
167           return ViennaCLGenericFailure;
168       }
169       else if (A->trans == ViennaCLNoTrans && B->trans == ViennaCLTrans)
170       {
171         if (uplo == ViennaCLUpper && diag == ViennaCLNonUnit)
172           viennacl::linalg::inplace_solve(viennacl::trans(mat_A), viennacl::trans(mat_B), viennacl::linalg::upper_tag());
173         else if (uplo == ViennaCLUpper && diag == ViennaCLUnit)
174           viennacl::linalg::inplace_solve(viennacl::trans(mat_A), viennacl::trans(mat_B), viennacl::linalg::unit_upper_tag());
175         else if (uplo == ViennaCLLower && diag == ViennaCLNonUnit)
176           viennacl::linalg::inplace_solve(viennacl::trans(mat_A), viennacl::trans(mat_B), viennacl::linalg::lower_tag());
177         else if (uplo == ViennaCLLower && diag == ViennaCLUnit)
178           viennacl::linalg::inplace_solve(viennacl::trans(mat_A), viennacl::trans(mat_B), viennacl::linalg::unit_lower_tag());
179         else
180           return ViennaCLGenericFailure;
181       }
182       else if (A->trans == ViennaCLNoTrans && B->trans == ViennaCLNoTrans)
183       {
184         if (uplo == ViennaCLUpper && diag == ViennaCLNonUnit)
185           viennacl::linalg::inplace_solve(mat_A, mat_B, viennacl::linalg::upper_tag());
186         else if (uplo == ViennaCLUpper && diag == ViennaCLUnit)
187           viennacl::linalg::inplace_solve(mat_A, mat_B, viennacl::linalg::unit_upper_tag());
188         else if (uplo == ViennaCLLower && diag == ViennaCLNonUnit)
189           viennacl::linalg::inplace_solve(mat_A, mat_B, viennacl::linalg::lower_tag());
190         else if (uplo == ViennaCLLower && diag == ViennaCLUnit)
191           viennacl::linalg::inplace_solve(mat_A, mat_B, viennacl::linalg::unit_lower_tag());
192         else
193           return ViennaCLGenericFailure;
194       }
195 
196       return ViennaCLSuccess;
197     }
198     case ViennaCLDouble:
199     {
200       typedef viennacl::matrix_base<double>::size_type           size_type;
201       typedef viennacl::matrix_base<double>::size_type           difference_type;
202 
203       viennacl::matrix_base<double> mat_A(A_handle,
204                                           size_type(A->size1), size_type(A->start1), difference_type(A->stride1), size_type(A->internal_size1),
205                                           size_type(A->size2), size_type(A->start2), difference_type(A->stride2), size_type(A->internal_size2), A->order == ViennaCLRowMajor);
206       viennacl::matrix_base<double> mat_B(B_handle,
207                                           size_type(B->size1), size_type(B->start1), difference_type(B->stride1), size_type(B->internal_size1),
208                                           size_type(B->size2), size_type(B->start2), difference_type(B->stride2), size_type(B->internal_size2), B->order == ViennaCLRowMajor);
209 
210       if (A->trans == ViennaCLTrans && B->trans == ViennaCLTrans)
211       {
212         if (uplo == ViennaCLUpper && diag == ViennaCLNonUnit)
213           viennacl::linalg::inplace_solve(viennacl::trans(mat_A), viennacl::trans(mat_B), viennacl::linalg::upper_tag());
214         else if (uplo == ViennaCLUpper && diag == ViennaCLUnit)
215           viennacl::linalg::inplace_solve(viennacl::trans(mat_A), viennacl::trans(mat_B), viennacl::linalg::unit_upper_tag());
216         else if (uplo == ViennaCLLower && diag == ViennaCLNonUnit)
217           viennacl::linalg::inplace_solve(viennacl::trans(mat_A), viennacl::trans(mat_B), viennacl::linalg::lower_tag());
218         else if (uplo == ViennaCLLower && diag == ViennaCLUnit)
219           viennacl::linalg::inplace_solve(viennacl::trans(mat_A), viennacl::trans(mat_B), viennacl::linalg::unit_lower_tag());
220         else
221           return ViennaCLGenericFailure;
222       }
223       else if (A->trans == ViennaCLTrans && B->trans == ViennaCLNoTrans)
224       {
225         if (uplo == ViennaCLUpper && diag == ViennaCLNonUnit)
226           viennacl::linalg::inplace_solve(viennacl::trans(mat_A), mat_B, viennacl::linalg::upper_tag());
227         else if (uplo == ViennaCLUpper && diag == ViennaCLUnit)
228           viennacl::linalg::inplace_solve(viennacl::trans(mat_A), mat_B, viennacl::linalg::unit_upper_tag());
229         else if (uplo == ViennaCLLower && diag == ViennaCLNonUnit)
230           viennacl::linalg::inplace_solve(viennacl::trans(mat_A), mat_B, viennacl::linalg::lower_tag());
231         else if (uplo == ViennaCLLower && diag == ViennaCLUnit)
232           viennacl::linalg::inplace_solve(viennacl::trans(mat_A), mat_B, viennacl::linalg::unit_lower_tag());
233         else
234           return ViennaCLGenericFailure;
235       }
236       else if (A->trans == ViennaCLNoTrans && B->trans == ViennaCLTrans)
237       {
238         if (uplo == ViennaCLUpper && diag == ViennaCLNonUnit)
239           viennacl::linalg::inplace_solve(viennacl::trans(mat_A), viennacl::trans(mat_B), viennacl::linalg::upper_tag());
240         else if (uplo == ViennaCLUpper && diag == ViennaCLUnit)
241           viennacl::linalg::inplace_solve(viennacl::trans(mat_A), viennacl::trans(mat_B), viennacl::linalg::unit_upper_tag());
242         else if (uplo == ViennaCLLower && diag == ViennaCLNonUnit)
243           viennacl::linalg::inplace_solve(viennacl::trans(mat_A), viennacl::trans(mat_B), viennacl::linalg::lower_tag());
244         else if (uplo == ViennaCLLower && diag == ViennaCLUnit)
245           viennacl::linalg::inplace_solve(viennacl::trans(mat_A), viennacl::trans(mat_B), viennacl::linalg::unit_lower_tag());
246         else
247           return ViennaCLGenericFailure;
248       }
249       else if (A->trans == ViennaCLNoTrans && B->trans == ViennaCLNoTrans)
250       {
251         if (uplo == ViennaCLUpper && diag == ViennaCLNonUnit)
252           viennacl::linalg::inplace_solve(mat_A, mat_B, viennacl::linalg::upper_tag());
253         else if (uplo == ViennaCLUpper && diag == ViennaCLUnit)
254           viennacl::linalg::inplace_solve(mat_A, mat_B, viennacl::linalg::unit_upper_tag());
255         else if (uplo == ViennaCLLower && diag == ViennaCLNonUnit)
256           viennacl::linalg::inplace_solve(mat_A, mat_B, viennacl::linalg::lower_tag());
257         else if (uplo == ViennaCLLower && diag == ViennaCLUnit)
258           viennacl::linalg::inplace_solve(mat_A, mat_B, viennacl::linalg::unit_lower_tag());
259         else
260           return ViennaCLGenericFailure;
261       }
262 
263       return ViennaCLSuccess;
264     }
265 
266     default:
267       return  ViennaCLGenericFailure;
268   }
269 }
270 
271 
272 
273