1 /* =========================================================================
2 Copyright (c) 2010-2014, Institute for Microelectronics,
3 Institute for Analysis and Scientific Computing,
4 TU Wien.
5 Portions of this software are copyright by UChicago Argonne, LLC.
6
7 -----------------
8 ViennaCL - The Vienna Computing Library
9 -----------------
10
11 Project Head: Karl Rupp rupp@iue.tuwien.ac.at
12
13 (A list of authors and contributors can be found in the PDF manual)
14
15 License: MIT (X11), see file LICENSE in the base directory
16 ============================================================================= */
17
18 // include necessary system headers
19 #include <iostream>
20
21 #include "viennacl.hpp"
22 #include "viennacl_private.hpp"
23
24 #include "blas3.hpp"
25
26 //include basic scalar and vector types of ViennaCL
27 #include "viennacl/scalar.hpp"
28 #include "viennacl/vector.hpp"
29 #include "viennacl/matrix.hpp"
30 #include "viennacl/linalg/direct_solve.hpp"
31 #include "viennacl/linalg/prod.hpp"
32
33
34 #ifdef VIENNACL_WITH_CUDA
35
36
37
//
// xGEMM
//
41
42 namespace detail
43 {
44 template <typename NumericT>
ViennaCLCUDAgemm_impl(ViennaCLBackend,ViennaCLOrder orderA,ViennaCLTranspose transA,ViennaCLOrder orderB,ViennaCLTranspose transB,ViennaCLOrder orderC,ViennaCLInt m,ViennaCLInt n,ViennaCLInt k,NumericT alpha,NumericT * A,ViennaCLInt offA_row,ViennaCLInt offA_col,ViennaCLInt incA_row,ViennaCLInt incA_col,ViennaCLInt lda,NumericT * B,ViennaCLInt offB_row,ViennaCLInt offB_col,ViennaCLInt incB_row,ViennaCLInt incB_col,ViennaCLInt ldb,NumericT beta,NumericT * C,ViennaCLInt offC_row,ViennaCLInt offC_col,ViennaCLInt incC_row,ViennaCLInt incC_col,ViennaCLInt ldc)45 ViennaCLStatus ViennaCLCUDAgemm_impl(ViennaCLBackend /*backend*/,
46 ViennaCLOrder orderA, ViennaCLTranspose transA,
47 ViennaCLOrder orderB, ViennaCLTranspose transB,
48 ViennaCLOrder orderC,
49 ViennaCLInt m, ViennaCLInt n, ViennaCLInt k,
50 NumericT alpha,
51 NumericT *A, ViennaCLInt offA_row, ViennaCLInt offA_col, ViennaCLInt incA_row, ViennaCLInt incA_col, ViennaCLInt lda,
52 NumericT *B, ViennaCLInt offB_row, ViennaCLInt offB_col, ViennaCLInt incB_row, ViennaCLInt incB_col, ViennaCLInt ldb,
53 NumericT beta,
54 NumericT *C, ViennaCLInt offC_row, ViennaCLInt offC_col, ViennaCLInt incC_row, ViennaCLInt incC_col, ViennaCLInt ldc)
55 {
56 ViennaCLInt A_size1 = (transA == ViennaCLTrans) ? k : m;
57 ViennaCLInt A_size2 = (transA == ViennaCLTrans) ? m : k;
58
59 ViennaCLInt B_size1 = (transB == ViennaCLTrans) ? n : k;
60 ViennaCLInt B_size2 = (transB == ViennaCLTrans) ? k : n;
61
62 bool A_row_major = (orderA == ViennaCLRowMajor);
63 bool B_row_major = (orderB == ViennaCLRowMajor);
64 bool C_row_major = (orderC == ViennaCLRowMajor);
65
66 viennacl::matrix_base<NumericT> matA(A, viennacl::CUDA_MEMORY,
67 A_size1, offA_row, incA_row, A_row_major ? m : lda,
68 A_size2, offA_col, incA_col, A_row_major ? lda : k, A_row_major);
69
70 viennacl::matrix_base<NumericT> matB(B, viennacl::CUDA_MEMORY,
71 B_size1, offB_row, incB_row, B_row_major ? k : ldb,
72 B_size2, offB_col, incB_col, B_row_major ? ldb : n, B_row_major);
73
74 viennacl::matrix_base<NumericT> matC(C, viennacl::CUDA_MEMORY,
75 m, offC_row, incC_row, C_row_major ? m : ldc,
76 n, offC_col, incC_col, C_row_major ? ldc : n, C_row_major);
77
78 detail::gemm_dispatch(alpha, matA, transA, matB, transB, beta, matC);
79
80 return ViennaCLSuccess;
81 }
82
83 }
84
85
ViennaCLCUDASgemm(ViennaCLBackend backend,ViennaCLOrder orderA,ViennaCLTranspose transA,ViennaCLOrder orderB,ViennaCLTranspose transB,ViennaCLOrder orderC,ViennaCLInt m,ViennaCLInt n,ViennaCLInt k,float alpha,float * A,ViennaCLInt offA_row,ViennaCLInt offA_col,ViennaCLInt incA_row,ViennaCLInt incA_col,ViennaCLInt lda,float * B,ViennaCLInt offB_row,ViennaCLInt offB_col,ViennaCLInt incB_row,ViennaCLInt incB_col,ViennaCLInt ldb,float beta,float * C,ViennaCLInt offC_row,ViennaCLInt offC_col,ViennaCLInt incC_row,ViennaCLInt incC_col,ViennaCLInt ldc)86 VIENNACL_EXPORTED_FUNCTION ViennaCLStatus ViennaCLCUDASgemm(ViennaCLBackend backend,
87 ViennaCLOrder orderA, ViennaCLTranspose transA,
88 ViennaCLOrder orderB, ViennaCLTranspose transB,
89 ViennaCLOrder orderC,
90 ViennaCLInt m, ViennaCLInt n, ViennaCLInt k,
91 float alpha,
92 float *A, ViennaCLInt offA_row, ViennaCLInt offA_col, ViennaCLInt incA_row, ViennaCLInt incA_col, ViennaCLInt lda,
93 float *B, ViennaCLInt offB_row, ViennaCLInt offB_col, ViennaCLInt incB_row, ViennaCLInt incB_col, ViennaCLInt ldb,
94 float beta,
95 float *C, ViennaCLInt offC_row, ViennaCLInt offC_col, ViennaCLInt incC_row, ViennaCLInt incC_col, ViennaCLInt ldc)
96 {
97 return detail::ViennaCLCUDAgemm_impl<float>(backend,
98 orderA, transA,
99 orderB, transB,
100 orderC,
101 m, n, k,
102 alpha,
103 A, offA_row, offA_col, incA_row, incA_col, lda,
104 B, offB_row, offB_col, incB_row, incB_col, ldb,
105 beta,
106 C, offC_row, offC_col, incC_row, incC_col, ldc);
107 }
108
ViennaCLCUDADgemm(ViennaCLBackend backend,ViennaCLOrder orderA,ViennaCLTranspose transA,ViennaCLOrder orderB,ViennaCLTranspose transB,ViennaCLOrder orderC,ViennaCLInt m,ViennaCLInt n,ViennaCLInt k,double alpha,double * A,ViennaCLInt offA_row,ViennaCLInt offA_col,ViennaCLInt incA_row,ViennaCLInt incA_col,ViennaCLInt lda,double * B,ViennaCLInt offB_row,ViennaCLInt offB_col,ViennaCLInt incB_row,ViennaCLInt incB_col,ViennaCLInt ldb,double beta,double * C,ViennaCLInt offC_row,ViennaCLInt offC_col,ViennaCLInt incC_row,ViennaCLInt incC_col,ViennaCLInt ldc)109 VIENNACL_EXPORTED_FUNCTION ViennaCLStatus ViennaCLCUDADgemm(ViennaCLBackend backend,
110 ViennaCLOrder orderA, ViennaCLTranspose transA,
111 ViennaCLOrder orderB, ViennaCLTranspose transB,
112 ViennaCLOrder orderC,
113 ViennaCLInt m, ViennaCLInt n, ViennaCLInt k,
114 double alpha,
115 double *A, ViennaCLInt offA_row, ViennaCLInt offA_col, ViennaCLInt incA_row, ViennaCLInt incA_col, ViennaCLInt lda,
116 double *B, ViennaCLInt offB_row, ViennaCLInt offB_col, ViennaCLInt incB_row, ViennaCLInt incB_col, ViennaCLInt ldb,
117 double beta,
118 double *C, ViennaCLInt offC_row, ViennaCLInt offC_col, ViennaCLInt incC_row, ViennaCLInt incC_col, ViennaCLInt ldc)
119 {
120 return detail::ViennaCLCUDAgemm_impl<double>(backend,
121 orderA, transA,
122 orderB, transB,
123 orderC,
124 m, n, k,
125 alpha,
126 A, offA_row, offA_col, incA_row, incA_col, lda,
127 B, offB_row, offB_col, incB_row, incB_col, ldb,
128 beta,
129 C, offC_row, offC_col, incC_row, incC_col, ldc);
130 }
131
132
133 #endif
134