1 #ifndef __VMML__VMMLIB_BLAS_DGEMM__HPP__ 2 #define __VMML__VMMLIB_BLAS_DGEMM__HPP__ 3 4 5 #include <vmmlib/matrix.hpp> 6 #include <vmmlib/tensor3.hpp> 7 #include <vmmlib/exception.hpp> 8 #include <vmmlib/blas_includes.hpp> 9 #include <vmmlib/blas_types.hpp> 10 11 /** 12 * 13 * a wrapper for blas's DGEMM routine. 14 15 SUBROUTINE DGEMM(TRANSA,TRANSB,M,N,K,ALPHA,A,LDA,B,LDB,BETA,C,LDC) 16 * .. Scalar Arguments .. 17 DOUBLE PRECISION ALPHA,BETA 18 INTEGER K,LDA,LDB,LDC,M,N 19 CHARACTER TRANSA,TRANSB 20 * .. 21 * .. Array Arguments .. 22 DOUBLE PRECISION A(LDA,*),B(LDB,*),C(LDC,*) 23 * .. 24 * 25 * Purpose 26 * ======= 27 * 28 * DGEMM performs one of the matrix-matrix operations 29 * 30 * C := alpha*op( A )*op( B ) + beta*C, 31 * 32 * where op( X ) is one of 33 * 34 * op( X ) = X or op( X ) = X**T, 35 * 36 * alpha and beta are scalars, and A, B and C are matrices, with op( A ) 37 * an m by k matrix, op( B ) a k by n matrix and C an m by n matrix. 38 * 39 * 40 * more information in: http://www.netlib.org/blas/dgemm.f 41 * or http://www.netlib.org/clapack/cblas/dgemm.c 42 ** 43 */ 44 45 46 namespace vmml 47 { 48 49 namespace blas 50 { 51 52 53 #if 0 54 /* Subroutine */ 55 void cblas_dgemm(enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANSPOSE TransB, 56 blasint M, blasint N, blasint K, 57 double alpha, double *A, blasint lda, double *B, blasint ldb, double beta, double *C, blasint ldc); 58 59 #endif 60 61 template< typename float_t > 62 struct dgemm_params 63 { 64 CBLAS_ORDER order; 65 CBLAS_TRANSPOSE trans_a; 66 CBLAS_TRANSPOSE trans_b; 67 blas_int m; 68 blas_int n; 69 blas_int k; 70 float_t alpha; 71 float_t* a; 72 blas_int lda; //leading dimension of input array matrix left 73 float_t* b; 74 blas_int ldb; //leading dimension of input array matrix right 75 float_t beta; 76 float_t* c; 77 blas_int ldc; //leading dimension of output array matrix right 78 operator <<(std::ostream & os,const dgemm_params<float_t> & p)79 friend std::ostream& operator << ( std::ostream& os, 80 const dgemm_params< float_t >& p ) 81 { 82 os 83 << " (1)\torder " << p.order << std::endl 84 << " (2)\ttrans_a " << p.trans_a << std::endl 85 << " (3)\ttrans_b " << p.trans_b << std::endl 86 << " (4)\tm " << p.m << std::endl 87 << " (6)\tn " << p.n << std::endl 88 << " (5)\tk " << p.k << std::endl 89 << " (7)\talpha " << p.alpha << std::endl 90 << " (8)\ta " << p.a << std::endl 91 << " (9)\tlda " << p.lda << std::endl 92 << " (10)\tb " << p.b << std::endl 93 << " (11)\tldb " << p.ldb << std::endl 94 << " (12)\tbeta " << p.beta << std::endl 95 << " (13)\tc " << p.c << std::endl 96 << " (14)\tldc " << p.ldc << std::endl 97 << std::endl; 98 return os; 99 } 100 101 }; 102 103 104 105 template< typename float_t > 106 inline void dgemm_call(dgemm_params<float_t> & p)107 dgemm_call( dgemm_params< float_t >& p ) 108 { 109 VMMLIB_ERROR( "not implemented for this type.", VMMLIB_HERE ); 110 } 111 112 113 template<> 114 inline void dgemm_call(dgemm_params<float> & p)115 dgemm_call( dgemm_params< float >& p ) 116 { 117 //std::cout << "calling blas sgemm (single precision) " << std::endl; 118 cblas_sgemm( 119 p.order, 120 p.trans_a, 121 p.trans_b, 122 p.m, 123 p.n, 124 p.k, 125 p.alpha, 126 p.a, 127 p.lda, 128 p.b, 129 p.ldb, 130 p.beta, 131 p.c, 132 p.ldc 133 ); 134 135 } 136 137 template<> 138 inline void dgemm_call(dgemm_params<double> & p)139 dgemm_call( dgemm_params< double >& p ) 140 { 141 //std::cout << "calling blas dgemm (double precision) " << std::endl; 142 cblas_dgemm( 143 p.order, 144 p.trans_a, 145 p.trans_b, 146 p.m, 147 p.n, 148 p.k, 149 p.alpha, 150 p.a, 151 p.lda, 152 p.b, 153 p.ldb, 154 p.beta, 155 p.c, 156 p.ldc 157 ); 158 } 159 160 } // namespace blas 161 162 163 164 template< size_t M, size_t K, size_t N, typename float_t > 165 struct blas_dgemm 166 { 167 168 typedef matrix< M, K, float_t > matrix_left_t; 169 typedef matrix< K, M, float_t > matrix_left_t_t; 170 typedef matrix< K, N, float_t > matrix_right_t; 171 typedef matrix< N, K, float_t > matrix_right_t_t; 172 typedef matrix< M, N, float_t > matrix_out_t; 173 typedef vector< M, float_t > vector_left_t; 174 typedef vector< N, float_t > vector_right_t; 175 176 blas_dgemm(); ~blas_dgemmvmml::blas_dgemm177 ~blas_dgemm() {}; 178 179 bool compute( const matrix_left_t& A_, const matrix_right_t& B_, matrix_out_t& C_ ); 180 bool compute( const matrix_left_t& A_, matrix_out_t& C_ ); 181 182 // dgemms with tensor3 input works for frontal tensor unfolding 183 //I2*I3 = K; 184 template< size_t I2, size_t I3 > 185 bool compute( const tensor3< M, I2, I3, float_t >& A_, const matrix_right_t& B_, matrix_out_t& C_ ); 186 //I2*I3 = K; 187 template< size_t I2, size_t I3 > 188 bool compute( const tensor3< M, I2, I3, float_t >& A_, matrix_out_t& C_ ); 189 190 bool compute_t( const matrix_right_t& B_, matrix_out_t& C_ ); 191 bool compute_bt( const matrix_left_t& A_, const matrix_right_t_t& Bt_, matrix_out_t& C_ ); 192 bool compute_t( const matrix_left_t_t& A_, const matrix_right_t_t& B_, matrix_out_t& C_ ); 193 bool compute_vv_outer( const vector_left_t& A_, const vector_right_t& B_, matrix_out_t& C_ ); 194 195 196 blas::dgemm_params< float_t > p; 197 get_paramsvmml::blas_dgemm198 const blas::dgemm_params< float_t >& get_params(){ return p; }; 199 200 201 }; // struct blas_dgemm 202 203 204 template< size_t M, size_t K, size_t N, typename float_t > blas_dgemm()205 blas_dgemm< M, K, N, float_t >::blas_dgemm() 206 { 207 p.order = CblasColMajor; // 208 p.trans_a = CblasNoTrans; 209 p.trans_b = CblasNoTrans; 210 p.m = M; 211 p.n = N; 212 p.k = K; 213 p.alpha = 1; 214 p.a = 0; 215 p.lda = M; 216 p.b = 0; 217 p.ldb = K; //no transpose 218 p.beta = 0; 219 p.c = 0; 220 p.ldc = M; 221 } 222 223 224 225 template< size_t M, size_t K, size_t N, typename float_t > 226 bool compute(const matrix_left_t & A_,const matrix_right_t & B_,matrix_out_t & C_)227 blas_dgemm< M, K, N, float_t >::compute( 228 const matrix_left_t& A_, 229 const matrix_right_t& B_, 230 matrix_out_t& C_ 231 ) 232 { 233 // blas needs non-const data 234 matrix_left_t* AA = new matrix_left_t( A_ ); 235 matrix_right_t* BB = new matrix_right_t( B_ ); 236 C_.zero(); 237 238 p.a = AA->array; 239 p.b = BB->array; 240 p.c = C_.array; 241 242 blas::dgemm_call< float_t >( p ); 243 244 //std::cout << p << std::endl; //debug 245 246 delete AA; 247 delete BB; 248 249 return true; 250 } 251 252 template< size_t M, size_t K, size_t N, typename float_t > 253 template< size_t I2, size_t I3 > 254 bool compute(const tensor3<M,I2,I3,float_t> & A_,const matrix_right_t & B_,matrix_out_t & C_)255 blas_dgemm< M, K, N, float_t >::compute( 256 const tensor3< M, I2, I3, float_t >& A_, 257 const matrix_right_t& B_, 258 matrix_out_t& C_ 259 ) 260 { 261 // blas needs non-const data 262 tensor3< M, I2, I3, float_t > AA( A_ ); 263 matrix_right_t* BB = new matrix_right_t( B_ ); 264 C_.zero(); 265 266 p.a = AA.get_array_ptr(); 267 p.b = BB->array; 268 p.c = C_.array; 269 270 blas::dgemm_call< float_t >( p ); 271 272 //std::cout << p << std::endl; //debug 273 274 delete BB; 275 276 return true; 277 } 278 279 280 template< size_t M, size_t K, size_t N, typename float_t > 281 bool compute(const matrix_left_t & A_,matrix_out_t & C_)282 blas_dgemm< M, K, N, float_t >::compute( const matrix_left_t& A_, matrix_out_t& C_ ) 283 { 284 // blas needs non-const data 285 matrix_left_t* AA = new matrix_left_t( A_ ); 286 C_.zero(); 287 288 p.trans_b = CblasTrans; 289 p.a = AA->array; 290 p.b = AA->array; 291 p.ldb = N; 292 p.c = C_.array; 293 294 blas::dgemm_call< float_t >( p ); 295 296 //std::cout << p << std::endl; //debug 297 298 delete AA; 299 300 return true; 301 } 302 303 template< size_t M, size_t K, size_t N, typename float_t > 304 template< size_t I2, size_t I3 > 305 bool compute(const tensor3<M,I2,I3,float_t> & A_,matrix_out_t & C_)306 blas_dgemm< M, K, N, float_t >::compute( const tensor3< M, I2, I3, float_t >& A_, matrix_out_t& C_ ) 307 { 308 // blas needs non-const data 309 tensor3< M, I2, I3, float_t > AA( A_ ) ; 310 C_.zero(); 311 312 p.trans_b = CblasTrans; 313 p.a = AA.get_array_ptr(); 314 p.b = AA.get_array_ptr(); 315 p.ldb = N; 316 p.c = C_.array; 317 318 blas::dgemm_call< float_t >( p ); 319 320 //std::cout << p << std::endl; //debug 321 322 return true; 323 } 324 325 template< size_t M, size_t K, size_t N, typename float_t > 326 bool compute_t(const matrix_right_t & B_,matrix_out_t & C_)327 blas_dgemm< M, K, N, float_t >::compute_t( const matrix_right_t& B_, matrix_out_t& C_ ) 328 { 329 // blas needs non-const data 330 matrix_right_t* BB = new matrix_right_t( B_ ); 331 C_.zero(); 332 333 p.trans_a = CblasTrans; 334 p.a = BB->array; 335 p.b = BB->array; 336 p.lda = K; 337 p.c = C_.array; 338 339 blas::dgemm_call< float_t >( p ); 340 341 //std::cout << p << std::endl; //debug 342 343 delete BB; 344 345 return true; 346 } 347 348 template< size_t M, size_t K, size_t N, typename float_t > 349 bool compute_bt(const matrix_left_t & A_,const matrix_right_t_t & Bt_,matrix_out_t & C_)350 blas_dgemm< M, K, N, float_t >::compute_bt( 351 const matrix_left_t& A_, 352 const matrix_right_t_t& Bt_, 353 matrix_out_t& C_ ) 354 { 355 // blas needs non-const data 356 matrix_left_t* AA = new matrix_left_t( A_ ); 357 matrix_right_t_t* BB = new matrix_right_t_t( Bt_ ); 358 C_.zero(); 359 360 p.trans_b = CblasTrans; 361 p.a = AA->array; 362 p.b = BB->array; 363 p.c = C_.array; 364 p.ldb = N; 365 366 blas::dgemm_call< float_t >( p ); 367 368 //std::cout << p << std::endl; //debug 369 370 delete AA; 371 delete BB; 372 373 return true; 374 } 375 376 template< size_t M, size_t K, size_t N, typename float_t > 377 bool compute_t(const matrix_left_t_t & At_,const matrix_right_t_t & Bt_,matrix_out_t & C_)378 blas_dgemm< M, K, N, float_t >::compute_t( 379 const matrix_left_t_t& At_, 380 const matrix_right_t_t& Bt_, 381 matrix_out_t& C_ ) 382 { 383 // blas needs non-const data 384 matrix_left_t_t* AA = new matrix_left_t_t( At_ ); 385 matrix_right_t_t* BB = new matrix_right_t_t( Bt_ ); 386 C_.zero(); 387 388 p.trans_a = CblasTrans; 389 p.trans_b = CblasTrans; 390 p.a = AA->array; 391 p.b = BB->array; 392 p.c = C_.array; 393 p.ldb = N; 394 p.lda = K; 395 396 blas::dgemm_call< float_t >( p ); 397 398 //std::cout << p << std::endl; //debug 399 400 delete AA; 401 delete BB; 402 403 return true; 404 } 405 406 template< size_t M, size_t K, size_t N, typename float_t > 407 bool compute_vv_outer(const vector_left_t & A_,const vector_right_t & B_,matrix_out_t & C_)408 blas_dgemm< M, K, N, float_t >::compute_vv_outer( 409 const vector_left_t& A_, 410 const vector_right_t& B_, 411 matrix_out_t& C_ ) 412 { 413 // blas needs non-const data 414 vector_left_t* AA = new vector_left_t( A_ ); 415 vector_right_t* BB = new vector_right_t( B_ ); 416 C_.zero(); 417 418 p.trans_a = CblasTrans; 419 p.a = AA->array; 420 p.b = BB->array; 421 p.c = C_.array; 422 p.lda = K; 423 424 blas::dgemm_call< float_t >( p ); 425 426 //std::cout << p << std::endl; //debug 427 428 delete AA; 429 delete BB; 430 431 return true; 432 } 433 434 435 } // namespace vmml 436 437 #endif 438 439