1 #ifndef __VMML__VMMLIB_BLAS_DGEMM__HPP__
2 #define __VMML__VMMLIB_BLAS_DGEMM__HPP__
3 
4 
5 #include <vmmlib/matrix.hpp>
6 #include <vmmlib/tensor3.hpp>
7 #include <vmmlib/exception.hpp>
8 #include <vmmlib/blas_includes.hpp>
9 #include <vmmlib/blas_types.hpp>
10 
11 /**
12  *
13  *   a wrapper for blas's DGEMM routine.
14 
15  SUBROUTINE DGEMM(TRANSA,TRANSB,M,N,K,ALPHA,A,LDA,B,LDB,BETA,C,LDC)
16  *     .. Scalar Arguments ..
17  DOUBLE PRECISION ALPHA,BETA
18  INTEGER K,LDA,LDB,LDC,M,N
19  CHARACTER TRANSA,TRANSB
20  *     ..
21  *     .. Array Arguments ..
22  DOUBLE PRECISION A(LDA,*),B(LDB,*),C(LDC,*)
23  *     ..
24  *
25  *  Purpose
26  *  =======
27  *
28  *  DGEMM  performs one of the matrix-matrix operations
29  *
30  *     C := alpha*op( A )*op( B ) + beta*C,
31  *
32  *  where  op( X ) is one of
33  *
34  *     op( X ) = X   or   op( X ) = X**T,
35  *
36  *  alpha and beta are scalars, and A, B and C are matrices, with op( A )
37  *  an m by k matrix,  op( B )  a  k by n matrix and  C an m by n matrix.
38  *
39  *
40  *   more information in: http://www.netlib.org/blas/dgemm.f
41  *   or http://www.netlib.org/clapack/cblas/dgemm.c
42  **
43  */
44 
45 
46 namespace vmml
47 {
48 
49 	namespace blas
50 	{
51 
52 
53 #if 0
54 		/* Subroutine */
55 		void cblas_dgemm(enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANSPOSE TransB,
56 						 blasint M, blasint N, blasint K,
57 						 double alpha, double *A, blasint lda, double *B, blasint ldb, double beta, double *C, blasint ldc);
58 
59 #endif
60 
61 		template< typename float_t >
62 		struct dgemm_params
63 		{
64 			CBLAS_ORDER     order;
65 			CBLAS_TRANSPOSE trans_a;
66 			CBLAS_TRANSPOSE trans_b;
67 			blas_int 		m;
68 			blas_int		n;
69 			blas_int		k;
70 			float_t			alpha;
71 			float_t*        a;
72 			blas_int        lda; //leading dimension of input array matrix left
73 			float_t*        b;
74 			blas_int        ldb; //leading dimension of input array matrix right
75 			float_t			beta;
76 			float_t*        c;
77 			blas_int        ldc; //leading dimension of output array matrix right
78 
operator <<(std::ostream & os,const dgemm_params<float_t> & p)79 			friend std::ostream& operator << ( std::ostream& os,
80 											  const dgemm_params< float_t >& p )
81 			{
82 				os
83 				<< " (1)\torder "     << p.order << std::endl
84 				<< " (2)\ttrans_a "    << p.trans_a << std::endl
85 				<< " (3)\ttrans_b "     << p.trans_b << std::endl
86 				<< " (4)\tm "        << p.m << std::endl
87 				<< " (6)\tn "      << p.n << std::endl
88 				<< " (5)\tk "        << p.k << std::endl
89 				<< " (7)\talpha "       << p.alpha << std::endl
90 				<< " (8)\ta "       << p.a << std::endl
91 				<< " (9)\tlda "       << p.lda << std::endl
92 				<< " (10)\tb "       << p.b << std::endl
93 				<< " (11)\tldb "   << p.ldb << std::endl
94 				<< " (12)\tbeta "        << p.beta << std::endl
95 				<< " (13)\tc "        << p.c << std::endl
96 				<< " (14)\tldc "        << p.ldc << std::endl
97 				<< std::endl;
98 				return os;
99 			}
100 
101 		};
102 
103 
104 
105 		template< typename float_t >
106 		inline void
dgemm_call(dgemm_params<float_t> & p)107 		dgemm_call( dgemm_params< float_t >& p )
108 		{
109 			VMMLIB_ERROR( "not implemented for this type.", VMMLIB_HERE );
110 		}
111 
112 
113 		template<>
114 		inline void
dgemm_call(dgemm_params<float> & p)115 		dgemm_call( dgemm_params< float >& p )
116 		{
117 			//std::cout << "calling blas sgemm (single precision) " << std::endl;
118 			cblas_sgemm(
119 					p.order,
120 					p.trans_a,
121 					p.trans_b,
122 					p.m,
123 					p.n,
124 					p.k,
125 					p.alpha,
126 					p.a,
127 					p.lda,
128 					p.b,
129 					p.ldb,
130 				    p.beta,
131 					p.c,
132 					p.ldc
133 					);
134 
135 		}
136 
137 		template<>
138 		inline void
dgemm_call(dgemm_params<double> & p)139 		dgemm_call( dgemm_params< double >& p )
140 		{
141 			//std::cout << "calling blas dgemm (double precision) " << std::endl;
142 			cblas_dgemm(
143 				   p.order,
144 				   p.trans_a,
145 				   p.trans_b,
146 				   p.m,
147 				   p.n,
148 				   p.k,
149 				   p.alpha,
150 				   p.a,
151 				   p.lda,
152 				   p.b,
153 				   p.ldb,
154 				   p.beta,
155 				   p.c,
156 				   p.ldc
157 				   );
158 		}
159 
160 	} // namespace blas
161 
162 
163 
164 	template< size_t M, size_t K, size_t N, typename float_t >
165 	struct blas_dgemm
166 	{
167 
168 		typedef matrix< M, K, float_t > matrix_left_t;
169 		typedef matrix< K, M, float_t > matrix_left_t_t;
170 		typedef matrix< K, N, float_t > matrix_right_t;
171 		typedef matrix< N, K, float_t > matrix_right_t_t;
172 		typedef matrix< M, N, float_t > matrix_out_t;
173 		typedef vector< M, float_t > vector_left_t;
174 		typedef vector< N, float_t > vector_right_t;
175 
176 		blas_dgemm();
~blas_dgemmvmml::blas_dgemm177 		~blas_dgemm() {};
178 
179 		bool compute( const matrix_left_t& A_, const matrix_right_t& B_, matrix_out_t& C_ );
180 		bool compute( const matrix_left_t& A_, matrix_out_t& C_ );
181 
182 		// dgemms with tensor3 input works for frontal tensor unfolding
183 		//I2*I3 = K;
184 		template< size_t I2, size_t I3 >
185 		bool compute( const tensor3< M, I2, I3, float_t >& A_, const matrix_right_t& B_, matrix_out_t& C_ );
186 		//I2*I3 = K;
187 		template< size_t I2, size_t I3 >
188 		bool compute( const tensor3< M, I2, I3, float_t >& A_, matrix_out_t& C_ );
189 
190 		bool compute_t( const matrix_right_t& B_, matrix_out_t& C_ );
191 		bool compute_bt( const matrix_left_t& A_, const matrix_right_t_t& Bt_, matrix_out_t& C_ );
192 		bool compute_t( const matrix_left_t_t& A_, const matrix_right_t_t& B_, matrix_out_t& C_ );
193 		bool compute_vv_outer( const vector_left_t& A_, const vector_right_t& B_, matrix_out_t& C_ );
194 
195 
196 		blas::dgemm_params< float_t > p;
197 
get_paramsvmml::blas_dgemm198 		const blas::dgemm_params< float_t >& get_params(){ return p; };
199 
200 
201 	}; // struct blas_dgemm
202 
203 
204 	template< size_t M, size_t K, size_t N, typename float_t >
blas_dgemm()205 	blas_dgemm< M, K, N, float_t >::blas_dgemm()
206 	{
207 		p.order      = CblasColMajor; //
208 		p.trans_a    = CblasNoTrans;
209 		p.trans_b    = CblasNoTrans;
210 		p.m          = M;
211 		p.n          = N;
212 		p.k          = K;
213 		p.alpha      = 1;
214 		p.a          = 0;
215 		p.lda        = M;
216 		p.b          = 0;
217 		p.ldb        = K; //no transpose
218 		p.beta       = 0;
219 		p.c          = 0;
220 		p.ldc        = M;
221 	}
222 
223 
224 
225 	template< size_t M, size_t K, size_t N, typename float_t >
226 	bool
compute(const matrix_left_t & A_,const matrix_right_t & B_,matrix_out_t & C_)227 	blas_dgemm< M, K, N, float_t >::compute(
228 												const matrix_left_t& A_,
229 												const matrix_right_t& B_,
230 												matrix_out_t& C_
231 											)
232 	{
233 		// blas needs non-const data
234 		matrix_left_t* AA = new matrix_left_t( A_ );
235 		matrix_right_t* BB = new matrix_right_t( B_ );
236 		C_.zero();
237 
238 		p.a         = AA->array;
239 		p.b         = BB->array;
240 		p.c         = C_.array;
241 
242 		blas::dgemm_call< float_t >( p );
243 
244 		//std::cout << p << std::endl; //debug
245 
246 		delete AA;
247 		delete BB;
248 
249 		return true;
250 	}
251 
252 	template< size_t M, size_t K, size_t N, typename float_t >
253 	template< size_t I2, size_t I3 >
254 	bool
compute(const tensor3<M,I2,I3,float_t> & A_,const matrix_right_t & B_,matrix_out_t & C_)255 	blas_dgemm< M, K, N, float_t >::compute(
256 											const tensor3< M, I2, I3, float_t >& A_,
257 											const matrix_right_t& B_,
258 											matrix_out_t& C_
259 											)
260 	{
261 		// blas needs non-const data
262 		tensor3< M, I2, I3, float_t > AA( A_ );
263 		matrix_right_t* BB = new matrix_right_t( B_ );
264 		C_.zero();
265 
266 		p.a         = AA.get_array_ptr();
267 		p.b         = BB->array;
268 		p.c         = C_.array;
269 
270 		blas::dgemm_call< float_t >( p );
271 
272 		//std::cout << p << std::endl; //debug
273 
274 		delete BB;
275 
276 		return true;
277 	}
278 
279 
280 	template< size_t M, size_t K, size_t N, typename float_t >
281 	bool
compute(const matrix_left_t & A_,matrix_out_t & C_)282 	blas_dgemm< M, K, N, float_t >::compute( const matrix_left_t& A_, matrix_out_t& C_ )
283 	{
284 		// blas needs non-const data
285 		matrix_left_t* AA = new matrix_left_t( A_ );
286 		C_.zero();
287 
288 		p.trans_b   = CblasTrans;
289 		p.a         = AA->array;
290 		p.b         = AA->array;
291 		p.ldb       = N;
292 		p.c         = C_.array;
293 
294 		blas::dgemm_call< float_t >( p );
295 
296 		//std::cout << p << std::endl; //debug
297 
298 		delete AA;
299 
300 		return true;
301 	}
302 
303 	template< size_t M, size_t K, size_t N, typename float_t >
304 	template< size_t I2, size_t I3 >
305 	bool
compute(const tensor3<M,I2,I3,float_t> & A_,matrix_out_t & C_)306 	blas_dgemm< M, K, N, float_t >::compute( const tensor3< M, I2, I3, float_t >& A_, matrix_out_t& C_ )
307 	{
308 		// blas needs non-const data
309 		tensor3< M, I2, I3, float_t > AA( A_ ) ;
310 		C_.zero();
311 
312 		p.trans_b   = CblasTrans;
313 		p.a         = AA.get_array_ptr();
314 		p.b         = AA.get_array_ptr();
315 		p.ldb       = N;
316 		p.c         = C_.array;
317 
318 		blas::dgemm_call< float_t >( p );
319 
320 		//std::cout << p << std::endl; //debug
321 
322 		return true;
323 	}
324 
325 	template< size_t M, size_t K, size_t N, typename float_t >
326 	bool
compute_t(const matrix_right_t & B_,matrix_out_t & C_)327 	blas_dgemm< M, K, N, float_t >::compute_t( const matrix_right_t& B_, matrix_out_t& C_ )
328 	{
329 		// blas needs non-const data
330 		matrix_right_t* BB = new matrix_right_t( B_ );
331 		C_.zero();
332 
333 		p.trans_a   = CblasTrans;
334 		p.a         = BB->array;
335 		p.b         = BB->array;
336 		p.lda       = K;
337 		p.c         = C_.array;
338 
339 		blas::dgemm_call< float_t >( p );
340 
341 		//std::cout << p << std::endl; //debug
342 
343 		delete BB;
344 
345 		return true;
346 	}
347 
348 	template< size_t M, size_t K, size_t N, typename float_t >
349 	bool
compute_bt(const matrix_left_t & A_,const matrix_right_t_t & Bt_,matrix_out_t & C_)350 	blas_dgemm< M, K, N, float_t >::compute_bt(
351 											const matrix_left_t& A_,
352 											const matrix_right_t_t& Bt_,
353 											matrix_out_t& C_ )
354 	{
355 		// blas needs non-const data
356 		matrix_left_t* AA = new matrix_left_t( A_ );
357 		matrix_right_t_t* BB = new matrix_right_t_t( Bt_ );
358 		C_.zero();
359 
360 		p.trans_b   = CblasTrans;
361 		p.a         = AA->array;
362 		p.b         = BB->array;
363 		p.c         = C_.array;
364 		p.ldb       = N;
365 
366 		blas::dgemm_call< float_t >( p );
367 
368 		//std::cout << p << std::endl; //debug
369 
370 		delete AA;
371 		delete BB;
372 
373 		return true;
374 	}
375 
376 	template< size_t M, size_t K, size_t N, typename float_t >
377 	bool
compute_t(const matrix_left_t_t & At_,const matrix_right_t_t & Bt_,matrix_out_t & C_)378 	blas_dgemm< M, K, N, float_t >::compute_t(
379 											   const matrix_left_t_t& At_,
380 											   const matrix_right_t_t& Bt_,
381 											   matrix_out_t& C_ )
382 	{
383 		// blas needs non-const data
384 		matrix_left_t_t* AA = new matrix_left_t_t( At_ );
385 		matrix_right_t_t* BB = new matrix_right_t_t( Bt_ );
386 		C_.zero();
387 
388 		p.trans_a   = CblasTrans;
389 		p.trans_b   = CblasTrans;
390 		p.a         = AA->array;
391 		p.b         = BB->array;
392 		p.c         = C_.array;
393 		p.ldb       = N;
394 		p.lda       = K;
395 
396 		blas::dgemm_call< float_t >( p );
397 
398 		//std::cout << p << std::endl; //debug
399 
400 		delete AA;
401 		delete BB;
402 
403 		return true;
404 	}
405 
406 	template< size_t M, size_t K, size_t N, typename float_t >
407 	bool
compute_vv_outer(const vector_left_t & A_,const vector_right_t & B_,matrix_out_t & C_)408 	blas_dgemm< M, K, N, float_t >::compute_vv_outer(
409 											  const vector_left_t& A_,
410 											  const vector_right_t& B_,
411 											  matrix_out_t& C_ )
412 	{
413 		// blas needs non-const data
414 		vector_left_t* AA = new vector_left_t( A_ );
415 		vector_right_t* BB = new vector_right_t( B_ );
416 		C_.zero();
417 
418 		p.trans_a   = CblasTrans;
419 		p.a         = AA->array;
420 		p.b         = BB->array;
421 		p.c         = C_.array;
422 		p.lda       = K;
423 
424 		blas::dgemm_call< float_t >( p );
425 
426 		//std::cout << p << std::endl; //debug
427 
428 		delete AA;
429 		delete BB;
430 
431 		return true;
432 	}
433 
434 
435 } // namespace vmml
436 
437 #endif
438 
439