1 /* =========================================================================
2 Copyright (c) 2010-2014, Institute for Microelectronics,
3 Institute for Analysis and Scientific Computing,
4 TU Wien.
5 Portions of this software are copyright by UChicago Argonne, LLC.
6
7 -----------------
8 ViennaCL - The Vienna Computing Library
9 -----------------
10
11 Project Head: Karl Rupp rupp@iue.tuwien.ac.at
12
13 (A list of authors and contributors can be found in the PDF manual)
14
15 License: MIT (X11), see file LICENSE in the base directory
16 ============================================================================= */
17
18 // include necessary system headers
19 #include <iostream>
20
21 #include "viennacl.hpp"
22 #include "viennacl_private.hpp"
23
24 #include "init_matrix.hpp"
25
26 //include basic scalar and vector types of ViennaCL
27 #include "viennacl/scalar.hpp"
28 #include "viennacl/vector.hpp"
29 #include "viennacl/matrix.hpp"
30 #include "viennacl/linalg/direct_solve.hpp"
31 #include "viennacl/linalg/prod.hpp"
32
33 // xGEMM
34
ViennaCLgemm(ViennaCLHostScalar alpha,ViennaCLMatrix A,ViennaCLMatrix B,ViennaCLHostScalar beta,ViennaCLMatrix C)35 VIENNACL_EXPORTED_FUNCTION ViennaCLStatus ViennaCLgemm(ViennaCLHostScalar alpha, ViennaCLMatrix A, ViennaCLMatrix B, ViennaCLHostScalar beta, ViennaCLMatrix C)
36 {
37 viennacl::backend::mem_handle A_handle;
38 viennacl::backend::mem_handle B_handle;
39 viennacl::backend::mem_handle C_handle;
40
41 if (init_matrix(A_handle, A) != ViennaCLSuccess)
42 return ViennaCLGenericFailure;
43
44 if (init_matrix(B_handle, B) != ViennaCLSuccess)
45 return ViennaCLGenericFailure;
46
47 if (init_matrix(C_handle, C) != ViennaCLSuccess)
48 return ViennaCLGenericFailure;
49
50 switch (A->precision)
51 {
52 case ViennaCLFloat:
53 {
54 typedef viennacl::matrix_base<float>::size_type size_type;
55 typedef viennacl::matrix_base<float>::size_type difference_type;
56
57 viennacl::matrix_base<float> mat_A(A_handle,
58 size_type(A->size1), size_type(A->start1), difference_type(A->stride1), size_type(A->internal_size1),
59 size_type(A->size2), size_type(A->start2), difference_type(A->stride2), size_type(A->internal_size2), A->order == ViennaCLRowMajor);
60 viennacl::matrix_base<float> mat_B(B_handle,
61 size_type(B->size1), size_type(B->start1), difference_type(B->stride1), size_type(B->internal_size1),
62 size_type(B->size2), size_type(B->start2), difference_type(B->stride2), size_type(B->internal_size2), B->order == ViennaCLRowMajor);
63 viennacl::matrix_base<float> mat_C(C_handle,
64 size_type(C->size1), size_type(C->start1), difference_type(C->stride1), size_type(C->internal_size1),
65 size_type(C->size2), size_type(C->start2), difference_type(C->stride2), size_type(C->internal_size2), C->order == ViennaCLRowMajor);
66
67 if (A->trans == ViennaCLTrans && B->trans == ViennaCLTrans)
68 viennacl::linalg::prod_impl(viennacl::trans(mat_A), viennacl::trans(mat_B), mat_C, alpha->value_float, beta->value_float);
69 else if (A->trans == ViennaCLTrans && B->trans == ViennaCLNoTrans)
70 viennacl::linalg::prod_impl(viennacl::trans(mat_A), mat_B, mat_C, alpha->value_float, beta->value_float);
71 else if (A->trans == ViennaCLNoTrans && B->trans == ViennaCLTrans)
72 viennacl::linalg::prod_impl(mat_A, viennacl::trans(mat_B), mat_C, alpha->value_float, beta->value_float);
73 else if (A->trans == ViennaCLNoTrans && B->trans == ViennaCLNoTrans)
74 viennacl::linalg::prod_impl(mat_A, mat_B, mat_C, alpha->value_float, beta->value_float);
75 else
76 return ViennaCLGenericFailure;
77
78 return ViennaCLSuccess;
79 }
80
81 case ViennaCLDouble:
82 {
83 typedef viennacl::matrix_base<double>::size_type size_type;
84 typedef viennacl::matrix_base<double>::size_type difference_type;
85
86 viennacl::matrix_base<double> mat_A(A_handle,
87 size_type(A->size1), size_type(A->start1), difference_type(A->stride1), size_type(A->internal_size1),
88 size_type(A->size2), size_type(A->start2), difference_type(A->stride2), size_type(A->internal_size2), A->order == ViennaCLRowMajor);
89 viennacl::matrix_base<double> mat_B(B_handle,
90 size_type(B->size1), size_type(B->start1), difference_type(B->stride1), size_type(B->internal_size1),
91 size_type(B->size2), size_type(B->start2), difference_type(B->stride2), size_type(B->internal_size2), B->order == ViennaCLRowMajor);
92 viennacl::matrix_base<double> mat_C(C_handle,
93 size_type(C->size1), size_type(C->start1), difference_type(C->stride1), size_type(C->internal_size1),
94 size_type(C->size2), size_type(C->start2), difference_type(C->stride2), size_type(C->internal_size2), C->order == ViennaCLRowMajor);
95
96 if (A->trans == ViennaCLTrans && B->trans == ViennaCLTrans)
97 viennacl::linalg::prod_impl(viennacl::trans(mat_A), viennacl::trans(mat_B), mat_C, alpha->value_double, beta->value_double);
98 else if (A->trans == ViennaCLTrans && B->trans == ViennaCLNoTrans)
99 viennacl::linalg::prod_impl(viennacl::trans(mat_A), mat_B, mat_C, alpha->value_double, beta->value_double);
100 else if (A->trans == ViennaCLNoTrans && B->trans == ViennaCLTrans)
101 viennacl::linalg::prod_impl(mat_A, viennacl::trans(mat_B), mat_C, alpha->value_double, beta->value_double);
102 else if (A->trans == ViennaCLNoTrans && B->trans == ViennaCLNoTrans)
103 viennacl::linalg::prod_impl(mat_A, mat_B, mat_C, alpha->value_double, beta->value_double);
104 else
105 return ViennaCLGenericFailure;
106
107 return ViennaCLSuccess;
108 }
109
110 default:
111 return ViennaCLGenericFailure;
112 }
113 }
114
115
116 // xTRSM
117
ViennaCLtrsm(ViennaCLMatrix A,ViennaCLUplo uplo,ViennaCLDiag diag,ViennaCLMatrix B)118 VIENNACL_EXPORTED_FUNCTION ViennaCLStatus ViennaCLtrsm(ViennaCLMatrix A, ViennaCLUplo uplo, ViennaCLDiag diag, ViennaCLMatrix B)
119 {
120 viennacl::backend::mem_handle A_handle;
121 viennacl::backend::mem_handle B_handle;
122
123 if (init_matrix(A_handle, A) != ViennaCLSuccess)
124 return ViennaCLGenericFailure;
125
126 if (init_matrix(B_handle, B) != ViennaCLSuccess)
127 return ViennaCLGenericFailure;
128
129 switch (A->precision)
130 {
131 case ViennaCLFloat:
132 {
133 typedef viennacl::matrix_base<float>::size_type size_type;
134 typedef viennacl::matrix_base<float>::size_type difference_type;
135
136 viennacl::matrix_base<float> mat_A(A_handle,
137 size_type(A->size1), size_type(A->start1), difference_type(A->stride1), size_type(A->internal_size1),
138 size_type(A->size2), size_type(A->start2), difference_type(A->stride2), size_type(A->internal_size2), A->order == ViennaCLRowMajor);
139 viennacl::matrix_base<float> mat_B(B_handle,
140 size_type(B->size1), size_type(B->start1), difference_type(B->stride1), size_type(B->internal_size1),
141 size_type(B->size2), size_type(B->start2), difference_type(B->stride2), size_type(B->internal_size2), B->order == ViennaCLRowMajor);
142
143 if (A->trans == ViennaCLTrans && B->trans == ViennaCLTrans)
144 {
145 if (uplo == ViennaCLUpper && diag == ViennaCLNonUnit)
146 viennacl::linalg::inplace_solve(viennacl::trans(mat_A), viennacl::trans(mat_B), viennacl::linalg::upper_tag());
147 else if (uplo == ViennaCLUpper && diag == ViennaCLUnit)
148 viennacl::linalg::inplace_solve(viennacl::trans(mat_A), viennacl::trans(mat_B), viennacl::linalg::unit_upper_tag());
149 else if (uplo == ViennaCLLower && diag == ViennaCLNonUnit)
150 viennacl::linalg::inplace_solve(viennacl::trans(mat_A), viennacl::trans(mat_B), viennacl::linalg::lower_tag());
151 else if (uplo == ViennaCLLower && diag == ViennaCLUnit)
152 viennacl::linalg::inplace_solve(viennacl::trans(mat_A), viennacl::trans(mat_B), viennacl::linalg::unit_lower_tag());
153 else
154 return ViennaCLGenericFailure;
155 }
156 else if (A->trans == ViennaCLTrans && B->trans == ViennaCLNoTrans)
157 {
158 if (uplo == ViennaCLUpper && diag == ViennaCLNonUnit)
159 viennacl::linalg::inplace_solve(viennacl::trans(mat_A), mat_B, viennacl::linalg::upper_tag());
160 else if (uplo == ViennaCLUpper && diag == ViennaCLUnit)
161 viennacl::linalg::inplace_solve(viennacl::trans(mat_A), mat_B, viennacl::linalg::unit_upper_tag());
162 else if (uplo == ViennaCLLower && diag == ViennaCLNonUnit)
163 viennacl::linalg::inplace_solve(viennacl::trans(mat_A), mat_B, viennacl::linalg::lower_tag());
164 else if (uplo == ViennaCLLower && diag == ViennaCLUnit)
165 viennacl::linalg::inplace_solve(viennacl::trans(mat_A), mat_B, viennacl::linalg::unit_lower_tag());
166 else
167 return ViennaCLGenericFailure;
168 }
169 else if (A->trans == ViennaCLNoTrans && B->trans == ViennaCLTrans)
170 {
171 if (uplo == ViennaCLUpper && diag == ViennaCLNonUnit)
172 viennacl::linalg::inplace_solve(viennacl::trans(mat_A), viennacl::trans(mat_B), viennacl::linalg::upper_tag());
173 else if (uplo == ViennaCLUpper && diag == ViennaCLUnit)
174 viennacl::linalg::inplace_solve(viennacl::trans(mat_A), viennacl::trans(mat_B), viennacl::linalg::unit_upper_tag());
175 else if (uplo == ViennaCLLower && diag == ViennaCLNonUnit)
176 viennacl::linalg::inplace_solve(viennacl::trans(mat_A), viennacl::trans(mat_B), viennacl::linalg::lower_tag());
177 else if (uplo == ViennaCLLower && diag == ViennaCLUnit)
178 viennacl::linalg::inplace_solve(viennacl::trans(mat_A), viennacl::trans(mat_B), viennacl::linalg::unit_lower_tag());
179 else
180 return ViennaCLGenericFailure;
181 }
182 else if (A->trans == ViennaCLNoTrans && B->trans == ViennaCLNoTrans)
183 {
184 if (uplo == ViennaCLUpper && diag == ViennaCLNonUnit)
185 viennacl::linalg::inplace_solve(mat_A, mat_B, viennacl::linalg::upper_tag());
186 else if (uplo == ViennaCLUpper && diag == ViennaCLUnit)
187 viennacl::linalg::inplace_solve(mat_A, mat_B, viennacl::linalg::unit_upper_tag());
188 else if (uplo == ViennaCLLower && diag == ViennaCLNonUnit)
189 viennacl::linalg::inplace_solve(mat_A, mat_B, viennacl::linalg::lower_tag());
190 else if (uplo == ViennaCLLower && diag == ViennaCLUnit)
191 viennacl::linalg::inplace_solve(mat_A, mat_B, viennacl::linalg::unit_lower_tag());
192 else
193 return ViennaCLGenericFailure;
194 }
195
196 return ViennaCLSuccess;
197 }
198 case ViennaCLDouble:
199 {
200 typedef viennacl::matrix_base<double>::size_type size_type;
201 typedef viennacl::matrix_base<double>::size_type difference_type;
202
203 viennacl::matrix_base<double> mat_A(A_handle,
204 size_type(A->size1), size_type(A->start1), difference_type(A->stride1), size_type(A->internal_size1),
205 size_type(A->size2), size_type(A->start2), difference_type(A->stride2), size_type(A->internal_size2), A->order == ViennaCLRowMajor);
206 viennacl::matrix_base<double> mat_B(B_handle,
207 size_type(B->size1), size_type(B->start1), difference_type(B->stride1), size_type(B->internal_size1),
208 size_type(B->size2), size_type(B->start2), difference_type(B->stride2), size_type(B->internal_size2), B->order == ViennaCLRowMajor);
209
210 if (A->trans == ViennaCLTrans && B->trans == ViennaCLTrans)
211 {
212 if (uplo == ViennaCLUpper && diag == ViennaCLNonUnit)
213 viennacl::linalg::inplace_solve(viennacl::trans(mat_A), viennacl::trans(mat_B), viennacl::linalg::upper_tag());
214 else if (uplo == ViennaCLUpper && diag == ViennaCLUnit)
215 viennacl::linalg::inplace_solve(viennacl::trans(mat_A), viennacl::trans(mat_B), viennacl::linalg::unit_upper_tag());
216 else if (uplo == ViennaCLLower && diag == ViennaCLNonUnit)
217 viennacl::linalg::inplace_solve(viennacl::trans(mat_A), viennacl::trans(mat_B), viennacl::linalg::lower_tag());
218 else if (uplo == ViennaCLLower && diag == ViennaCLUnit)
219 viennacl::linalg::inplace_solve(viennacl::trans(mat_A), viennacl::trans(mat_B), viennacl::linalg::unit_lower_tag());
220 else
221 return ViennaCLGenericFailure;
222 }
223 else if (A->trans == ViennaCLTrans && B->trans == ViennaCLNoTrans)
224 {
225 if (uplo == ViennaCLUpper && diag == ViennaCLNonUnit)
226 viennacl::linalg::inplace_solve(viennacl::trans(mat_A), mat_B, viennacl::linalg::upper_tag());
227 else if (uplo == ViennaCLUpper && diag == ViennaCLUnit)
228 viennacl::linalg::inplace_solve(viennacl::trans(mat_A), mat_B, viennacl::linalg::unit_upper_tag());
229 else if (uplo == ViennaCLLower && diag == ViennaCLNonUnit)
230 viennacl::linalg::inplace_solve(viennacl::trans(mat_A), mat_B, viennacl::linalg::lower_tag());
231 else if (uplo == ViennaCLLower && diag == ViennaCLUnit)
232 viennacl::linalg::inplace_solve(viennacl::trans(mat_A), mat_B, viennacl::linalg::unit_lower_tag());
233 else
234 return ViennaCLGenericFailure;
235 }
236 else if (A->trans == ViennaCLNoTrans && B->trans == ViennaCLTrans)
237 {
238 if (uplo == ViennaCLUpper && diag == ViennaCLNonUnit)
239 viennacl::linalg::inplace_solve(viennacl::trans(mat_A), viennacl::trans(mat_B), viennacl::linalg::upper_tag());
240 else if (uplo == ViennaCLUpper && diag == ViennaCLUnit)
241 viennacl::linalg::inplace_solve(viennacl::trans(mat_A), viennacl::trans(mat_B), viennacl::linalg::unit_upper_tag());
242 else if (uplo == ViennaCLLower && diag == ViennaCLNonUnit)
243 viennacl::linalg::inplace_solve(viennacl::trans(mat_A), viennacl::trans(mat_B), viennacl::linalg::lower_tag());
244 else if (uplo == ViennaCLLower && diag == ViennaCLUnit)
245 viennacl::linalg::inplace_solve(viennacl::trans(mat_A), viennacl::trans(mat_B), viennacl::linalg::unit_lower_tag());
246 else
247 return ViennaCLGenericFailure;
248 }
249 else if (A->trans == ViennaCLNoTrans && B->trans == ViennaCLNoTrans)
250 {
251 if (uplo == ViennaCLUpper && diag == ViennaCLNonUnit)
252 viennacl::linalg::inplace_solve(mat_A, mat_B, viennacl::linalg::upper_tag());
253 else if (uplo == ViennaCLUpper && diag == ViennaCLUnit)
254 viennacl::linalg::inplace_solve(mat_A, mat_B, viennacl::linalg::unit_upper_tag());
255 else if (uplo == ViennaCLLower && diag == ViennaCLNonUnit)
256 viennacl::linalg::inplace_solve(mat_A, mat_B, viennacl::linalg::lower_tag());
257 else if (uplo == ViennaCLLower && diag == ViennaCLUnit)
258 viennacl::linalg::inplace_solve(mat_A, mat_B, viennacl::linalg::unit_lower_tag());
259 else
260 return ViennaCLGenericFailure;
261 }
262
263 return ViennaCLSuccess;
264 }
265
266 default:
267 return ViennaCLGenericFailure;
268 }
269 }
270
271
272
273