1 /* =========================================================================
2    Copyright (c) 2010-2014, Institute for Microelectronics,
3                             Institute for Analysis and Scientific Computing,
4                             TU Wien.
5    Portions of this software are copyright by UChicago Argonne, LLC.
6 
7                             -----------------
8                   ViennaCL - The Vienna Computing Library
9                             -----------------
10 
11    Project Head:    Karl Rupp                   rupp@iue.tuwien.ac.at
12 
13    (A list of authors and contributors can be found in the PDF manual)
14 
15    License:         MIT (X11), see file LICENSE in the base directory
16 ============================================================================= */
17 
18 // include necessary system headers
19 #include <iostream>
20 
21 #include "viennacl.hpp"
22 #include "viennacl_private.hpp"
23 
24 #include "blas3.hpp"
25 
26 //include basic scalar and vector types of ViennaCL
27 #include "viennacl/scalar.hpp"
28 #include "viennacl/vector.hpp"
29 #include "viennacl/matrix.hpp"
30 #include "viennacl/linalg/direct_solve.hpp"
31 #include "viennacl/linalg/prod.hpp"
32 
33 
34 #ifdef VIENNACL_WITH_CUDA
35 
36 
37 
38 //
// xGEMM (general matrix-matrix product)
40 //
41 
42 namespace detail
43 {
44   template <typename NumericT>
ViennaCLCUDAgemm_impl(ViennaCLBackend,ViennaCLOrder orderA,ViennaCLTranspose transA,ViennaCLOrder orderB,ViennaCLTranspose transB,ViennaCLOrder orderC,ViennaCLInt m,ViennaCLInt n,ViennaCLInt k,NumericT alpha,NumericT * A,ViennaCLInt offA_row,ViennaCLInt offA_col,ViennaCLInt incA_row,ViennaCLInt incA_col,ViennaCLInt lda,NumericT * B,ViennaCLInt offB_row,ViennaCLInt offB_col,ViennaCLInt incB_row,ViennaCLInt incB_col,ViennaCLInt ldb,NumericT beta,NumericT * C,ViennaCLInt offC_row,ViennaCLInt offC_col,ViennaCLInt incC_row,ViennaCLInt incC_col,ViennaCLInt ldc)45   ViennaCLStatus ViennaCLCUDAgemm_impl(ViennaCLBackend /*backend*/,
46                                        ViennaCLOrder orderA, ViennaCLTranspose transA,
47                                        ViennaCLOrder orderB, ViennaCLTranspose transB,
48                                        ViennaCLOrder orderC,
49                                        ViennaCLInt m, ViennaCLInt n, ViennaCLInt k,
50                                        NumericT alpha,
51                                        NumericT *A, ViennaCLInt offA_row, ViennaCLInt offA_col, ViennaCLInt incA_row, ViennaCLInt incA_col, ViennaCLInt lda,
52                                        NumericT *B, ViennaCLInt offB_row, ViennaCLInt offB_col, ViennaCLInt incB_row, ViennaCLInt incB_col, ViennaCLInt ldb,
53                                        NumericT beta,
54                                        NumericT *C, ViennaCLInt offC_row, ViennaCLInt offC_col, ViennaCLInt incC_row, ViennaCLInt incC_col, ViennaCLInt ldc)
55   {
56     ViennaCLInt A_size1 = (transA == ViennaCLTrans) ? k : m;
57     ViennaCLInt A_size2 = (transA == ViennaCLTrans) ? m : k;
58 
59     ViennaCLInt B_size1 = (transB == ViennaCLTrans) ? n : k;
60     ViennaCLInt B_size2 = (transB == ViennaCLTrans) ? k : n;
61 
62     bool A_row_major = (orderA == ViennaCLRowMajor);
63     bool B_row_major = (orderB == ViennaCLRowMajor);
64     bool C_row_major = (orderC == ViennaCLRowMajor);
65 
66     viennacl::matrix_base<NumericT> matA(A, viennacl::CUDA_MEMORY,
67                                          A_size1, offA_row, incA_row, A_row_major ? m : lda,
68                                          A_size2, offA_col, incA_col, A_row_major ? lda : k, A_row_major);
69 
70     viennacl::matrix_base<NumericT> matB(B, viennacl::CUDA_MEMORY,
71                                          B_size1, offB_row, incB_row, B_row_major ? k : ldb,
72                                          B_size2, offB_col, incB_col, B_row_major ? ldb : n, B_row_major);
73 
74     viennacl::matrix_base<NumericT> matC(C, viennacl::CUDA_MEMORY,
75                                          m, offC_row, incC_row, C_row_major ? m : ldc,
76                                          n, offC_col, incC_col, C_row_major ? ldc : n, C_row_major);
77 
78     detail::gemm_dispatch(alpha, matA, transA, matB, transB, beta, matC);
79 
80     return ViennaCLSuccess;
81   }
82 
83 }
84 
85 
ViennaCLCUDASgemm(ViennaCLBackend backend,ViennaCLOrder orderA,ViennaCLTranspose transA,ViennaCLOrder orderB,ViennaCLTranspose transB,ViennaCLOrder orderC,ViennaCLInt m,ViennaCLInt n,ViennaCLInt k,float alpha,float * A,ViennaCLInt offA_row,ViennaCLInt offA_col,ViennaCLInt incA_row,ViennaCLInt incA_col,ViennaCLInt lda,float * B,ViennaCLInt offB_row,ViennaCLInt offB_col,ViennaCLInt incB_row,ViennaCLInt incB_col,ViennaCLInt ldb,float beta,float * C,ViennaCLInt offC_row,ViennaCLInt offC_col,ViennaCLInt incC_row,ViennaCLInt incC_col,ViennaCLInt ldc)86 VIENNACL_EXPORTED_FUNCTION ViennaCLStatus ViennaCLCUDASgemm(ViennaCLBackend backend,
87                                                             ViennaCLOrder orderA, ViennaCLTranspose transA,
88                                                             ViennaCLOrder orderB, ViennaCLTranspose transB,
89                                                             ViennaCLOrder orderC,
90                                                             ViennaCLInt m, ViennaCLInt n, ViennaCLInt k,
91                                                             float alpha,
92                                                             float *A, ViennaCLInt offA_row, ViennaCLInt offA_col, ViennaCLInt incA_row, ViennaCLInt incA_col, ViennaCLInt lda,
93                                                             float *B, ViennaCLInt offB_row, ViennaCLInt offB_col, ViennaCLInt incB_row, ViennaCLInt incB_col, ViennaCLInt ldb,
94                                                             float beta,
95                                                             float *C, ViennaCLInt offC_row, ViennaCLInt offC_col, ViennaCLInt incC_row, ViennaCLInt incC_col, ViennaCLInt ldc)
96 {
97   return detail::ViennaCLCUDAgemm_impl<float>(backend,
98                                               orderA, transA,
99                                               orderB, transB,
100                                               orderC,
101                                               m, n, k,
102                                               alpha,
103                                               A, offA_row, offA_col, incA_row, incA_col, lda,
104                                               B, offB_row, offB_col, incB_row, incB_col, ldb,
105                                               beta,
106                                               C, offC_row, offC_col, incC_row, incC_col, ldc);
107 }
108 
ViennaCLCUDADgemm(ViennaCLBackend backend,ViennaCLOrder orderA,ViennaCLTranspose transA,ViennaCLOrder orderB,ViennaCLTranspose transB,ViennaCLOrder orderC,ViennaCLInt m,ViennaCLInt n,ViennaCLInt k,double alpha,double * A,ViennaCLInt offA_row,ViennaCLInt offA_col,ViennaCLInt incA_row,ViennaCLInt incA_col,ViennaCLInt lda,double * B,ViennaCLInt offB_row,ViennaCLInt offB_col,ViennaCLInt incB_row,ViennaCLInt incB_col,ViennaCLInt ldb,double beta,double * C,ViennaCLInt offC_row,ViennaCLInt offC_col,ViennaCLInt incC_row,ViennaCLInt incC_col,ViennaCLInt ldc)109 VIENNACL_EXPORTED_FUNCTION ViennaCLStatus ViennaCLCUDADgemm(ViennaCLBackend backend,
110                                                             ViennaCLOrder orderA, ViennaCLTranspose transA,
111                                                             ViennaCLOrder orderB, ViennaCLTranspose transB,
112                                                             ViennaCLOrder orderC,
113                                                             ViennaCLInt m, ViennaCLInt n, ViennaCLInt k,
114                                                             double alpha,
115                                                             double *A, ViennaCLInt offA_row, ViennaCLInt offA_col, ViennaCLInt incA_row, ViennaCLInt incA_col, ViennaCLInt lda,
116                                                             double *B, ViennaCLInt offB_row, ViennaCLInt offB_col, ViennaCLInt incB_row, ViennaCLInt incB_col, ViennaCLInt ldb,
117                                                             double beta,
118                                                             double *C, ViennaCLInt offC_row, ViennaCLInt offC_col, ViennaCLInt incC_row, ViennaCLInt incC_col, ViennaCLInt ldc)
119 {
120   return detail::ViennaCLCUDAgemm_impl<double>(backend,
121                                                orderA, transA,
122                                                orderB, transB,
123                                                orderC,
124                                                m, n, k,
125                                                alpha,
126                                                A, offA_row, offA_col, incA_row, incA_col, lda,
127                                                B, offB_row, offB_col, incB_row, incB_col, ldb,
128                                                beta,
129                                                C, offC_row, offC_col, incC_row, incC_col, ldc);
130 }
131 
132 
133 #endif
134