1 // SPDX-License-Identifier: Apache-2.0
2 //
3 // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
4 // Copyright 2008-2016 National ICT Australia (NICTA)
5 //
6 // Licensed under the Apache License, Version 2.0 (the "License");
7 // you may not use this file except in compliance with the License.
8 // You may obtain a copy of the License at
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // ------------------------------------------------------------------------
17 
18 
19 //! \addtogroup gemm
20 //! @{
21 
22 
23 
24 //! for tiny square matrices, size <= 4x4
25 template<const bool do_trans_A=false, const bool use_alpha=false, const bool use_beta=false>
26 class gemm_emul_tinysq
27   {
28   public:
29 
30 
31   template<typename eT, typename TA, typename TB>
32   arma_cold
33   inline
34   static
35   void
apply(Mat<eT> & C,const TA & A,const TB & B,const eT alpha=eT (1),const eT beta=eT (0))36   apply
37     (
38           Mat<eT>& C,
39     const TA&      A,
40     const TB&      B,
41     const eT       alpha = eT(1),
42     const eT       beta  = eT(0)
43     )
44     {
45     arma_extra_debug_sigprint();
46 
47     switch(A.n_rows)
48       {
49       case  4:  gemv_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply( C.colptr(3), A, B.colptr(3), alpha, beta );
50       // fallthrough
51       case  3:  gemv_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply( C.colptr(2), A, B.colptr(2), alpha, beta );
52       // fallthrough
53       case  2:  gemv_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply( C.colptr(1), A, B.colptr(1), alpha, beta );
54       // fallthrough
55       case  1:  gemv_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply( C.colptr(0), A, B.colptr(0), alpha, beta );
56       // fallthrough
57       default:  ;
58       }
59     }
60 
61   };
62 
63 
64 
65 //! emulation of gemm(), for non-complex matrices only, as it assumes only simple transposes (ie. doesn't do hermitian transposes)
66 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
67 class gemm_emul_large
68   {
69   public:
70 
71   template<typename eT, typename TA, typename TB>
72   arma_hot
73   inline
74   static
75   void
apply(Mat<eT> & C,const TA & A,const TB & B,const eT alpha=eT (1),const eT beta=eT (0))76   apply
77     (
78           Mat<eT>& C,
79     const TA&      A,
80     const TB&      B,
81     const eT       alpha = eT(1),
82     const eT       beta  = eT(0)
83     )
84     {
85     arma_extra_debug_sigprint();
86 
87     const uword A_n_rows = A.n_rows;
88     const uword A_n_cols = A.n_cols;
89 
90     const uword B_n_rows = B.n_rows;
91     const uword B_n_cols = B.n_cols;
92 
93     if( (do_trans_A == false) && (do_trans_B == false) )
94       {
95       arma_aligned podarray<eT> tmp(A_n_cols);
96 
97       eT* A_rowdata = tmp.memptr();
98 
99       for(uword row_A=0; row_A < A_n_rows; ++row_A)
100         {
101         tmp.copy_row(A, row_A);
102 
103         for(uword col_B=0; col_B < B_n_cols; ++col_B)
104           {
105           const eT acc = op_dot::direct_dot_arma(B_n_rows, A_rowdata, B.colptr(col_B));
106 
107                if( (use_alpha == false) && (use_beta == false) )  { C.at(row_A,col_B) =       acc;                          }
108           else if( (use_alpha == true ) && (use_beta == false) )  { C.at(row_A,col_B) = alpha*acc;                          }
109           else if( (use_alpha == false) && (use_beta == true ) )  { C.at(row_A,col_B) =       acc + beta*C.at(row_A,col_B); }
110           else if( (use_alpha == true ) && (use_beta == true ) )  { C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B); }
111           }
112         }
113       }
114     else
115     if( (do_trans_A == true) && (do_trans_B == false) )
116       {
117       for(uword col_A=0; col_A < A_n_cols; ++col_A)
118         {
119         // col_A is interpreted as row_A when storing the results in matrix C
120 
121         const eT* A_coldata = A.colptr(col_A);
122 
123         for(uword col_B=0; col_B < B_n_cols; ++col_B)
124           {
125           const eT acc = op_dot::direct_dot_arma(B_n_rows, A_coldata, B.colptr(col_B));
126 
127                if( (use_alpha == false) && (use_beta == false) )  { C.at(col_A,col_B) =       acc;                          }
128           else if( (use_alpha == true ) && (use_beta == false) )  { C.at(col_A,col_B) = alpha*acc;                          }
129           else if( (use_alpha == false) && (use_beta == true ) )  { C.at(col_A,col_B) =       acc + beta*C.at(col_A,col_B); }
130           else if( (use_alpha == true ) && (use_beta == true ) )  { C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B); }
131           }
132         }
133       }
134     else
135     if( (do_trans_A == false) && (do_trans_B == true) )
136       {
137       Mat<eT> BB;
138       op_strans::apply_mat_noalias(BB, B);
139 
140       gemm_emul_large<false, false, use_alpha, use_beta>::apply(C, A, BB, alpha, beta);
141       }
142     else
143     if( (do_trans_A == true) && (do_trans_B == true) )
144       {
145       // mat B_tmp = trans(B);
146       // dgemm_arma<true, false,  use_alpha, use_beta>::apply(C, A, B_tmp, alpha, beta);
147 
148 
149       // By using the trans(A)*trans(B) = trans(B*A) equivalency,
150       // transpose operations are not needed
151 
152       arma_aligned podarray<eT> tmp(B.n_cols);
153       eT* B_rowdata = tmp.memptr();
154 
155       for(uword row_B=0; row_B < B_n_rows; ++row_B)
156         {
157         tmp.copy_row(B, row_B);
158 
159         for(uword col_A=0; col_A < A_n_cols; ++col_A)
160           {
161           const eT acc = op_dot::direct_dot_arma(A_n_rows, B_rowdata, A.colptr(col_A));
162 
163                if( (use_alpha == false) && (use_beta == false) )  { C.at(col_A,row_B) =       acc;                          }
164           else if( (use_alpha == true ) && (use_beta == false) )  { C.at(col_A,row_B) = alpha*acc;                          }
165           else if( (use_alpha == false) && (use_beta == true ) )  { C.at(col_A,row_B) =       acc + beta*C.at(col_A,row_B); }
166           else if( (use_alpha == true ) && (use_beta == true ) )  { C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B); }
167           }
168         }
169       }
170     }
171 
172   };
173 
174 
175 
176 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
177 class gemm_emul
178   {
179   public:
180 
181 
182   template<typename eT, typename TA, typename TB>
183   arma_hot
184   inline
185   static
186   void
apply(Mat<eT> & C,const TA & A,const TB & B,const eT alpha=eT (1),const eT beta=eT (0),const typename arma_not_cx<eT>::result * junk=nullptr)187   apply
188     (
189           Mat<eT>& C,
190     const TA&      A,
191     const TB&      B,
192     const eT       alpha = eT(1),
193     const eT       beta  = eT(0),
194     const typename arma_not_cx<eT>::result* junk = nullptr
195     )
196     {
197     arma_extra_debug_sigprint();
198     arma_ignore(junk);
199 
200     gemm_emul_large<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C, A, B, alpha, beta);
201     }
202 
203 
204 
205   template<typename eT>
206   arma_hot
207   inline
208   static
209   void
apply(Mat<eT> & C,const Mat<eT> & A,const Mat<eT> & B,const eT alpha=eT (1),const eT beta=eT (0),const typename arma_cx_only<eT>::result * junk=nullptr)210   apply
211     (
212           Mat<eT>& C,
213     const Mat<eT>& A,
214     const Mat<eT>& B,
215     const eT       alpha = eT(1),
216     const eT       beta  = eT(0),
217     const typename arma_cx_only<eT>::result* junk = nullptr
218     )
219     {
220     arma_extra_debug_sigprint();
221     arma_ignore(junk);
222 
223     // "better than nothing" handling of hermitian transposes for complex number matrices
224 
225     Mat<eT> tmp_A;
226     Mat<eT> tmp_B;
227 
228     if(do_trans_A)  { op_htrans::apply_mat_noalias(tmp_A, A); }
229     if(do_trans_B)  { op_htrans::apply_mat_noalias(tmp_B, B); }
230 
231     const Mat<eT>& AA = (do_trans_A == false) ? A : tmp_A;
232     const Mat<eT>& BB = (do_trans_B == false) ? B : tmp_B;
233 
234     gemm_emul_large<false, false, use_alpha, use_beta>::apply(C, AA, BB, alpha, beta);
235     }
236 
237   };
238 
239 
240 
241 //! \brief
242 //! Wrapper for ATLAS/BLAS dgemm function, using template arguments to control the arguments passed to dgemm.
243 //! Matrix 'C' is assumed to have been set to the correct size (ie. taking into account transposes)
244 
245 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
246 class gemm
247   {
248   public:
249 
250   template<typename eT, typename TA, typename TB>
251   inline
252   static
253   void
apply_blas_type(Mat<eT> & C,const TA & A,const TB & B,const eT alpha=eT (1),const eT beta=eT (0))254   apply_blas_type( Mat<eT>& C, const TA& A, const TB& B, const eT alpha = eT(1), const eT beta = eT(0) )
255     {
256     arma_extra_debug_sigprint();
257 
258     if( (A.n_rows <= 4) && (A.n_rows == A.n_cols) && (A.n_rows == B.n_rows) && (B.n_rows == B.n_cols) && (is_cx<eT>::no) )
259       {
260       if(do_trans_B == false)
261         {
262         gemm_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply(C, A, B, alpha, beta);
263         }
264       else
265         {
266         Mat<eT> BB(B.n_rows, B.n_rows, arma_nozeros_indicator());
267 
268         op_strans::apply_mat_noalias_tinysq(BB, B);
269 
270         gemm_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply(C, A, BB, alpha, beta);
271         }
272       }
273     else
274       {
275       #if defined(ARMA_USE_ATLAS)
276         {
277         arma_extra_debug_print("atlas::cblas_gemm()");
278 
279         arma_debug_assert_atlas_size(A,B);
280 
281         atlas::cblas_gemm<eT>
282           (
283           atlas::CblasColMajor,
284           (do_trans_A) ? ( is_cx<eT>::yes ? CblasConjTrans : atlas::CblasTrans ) : atlas::CblasNoTrans,
285           (do_trans_B) ? ( is_cx<eT>::yes ? CblasConjTrans : atlas::CblasTrans ) : atlas::CblasNoTrans,
286           C.n_rows,
287           C.n_cols,
288           (do_trans_A) ? A.n_rows : A.n_cols,
289           (use_alpha) ? alpha : eT(1),
290           A.mem,
291           (do_trans_A) ? A.n_rows : C.n_rows,
292           B.mem,
293           (do_trans_B) ? C.n_cols : ( (do_trans_A) ? A.n_rows : A.n_cols ),
294           (use_beta) ? beta : eT(0),
295           C.memptr(),
296           C.n_rows
297           );
298         }
299       #elif defined(ARMA_USE_BLAS)
300         {
301         arma_extra_debug_print("blas::gemm()");
302 
303         arma_debug_assert_blas_size(A,B);
304 
305         const char trans_A = (do_trans_A) ? ( is_cx<eT>::yes ? 'C' : 'T' ) : 'N';
306         const char trans_B = (do_trans_B) ? ( is_cx<eT>::yes ? 'C' : 'T' ) : 'N';
307 
308         const blas_int m   = blas_int(C.n_rows);
309         const blas_int n   = blas_int(C.n_cols);
310         const blas_int k   = (do_trans_A) ? blas_int(A.n_rows) : blas_int(A.n_cols);
311 
312         const eT local_alpha = (use_alpha) ? alpha : eT(1);
313 
314         const blas_int lda = (do_trans_A) ? k : m;
315         const blas_int ldb = (do_trans_B) ? n : k;
316 
317         const eT local_beta  = (use_beta) ? beta : eT(0);
318 
319         arma_extra_debug_print( arma_str::format("blas::gemm(): trans_A = %c") % trans_A );
320         arma_extra_debug_print( arma_str::format("blas::gemm(): trans_B = %c") % trans_B );
321 
322         blas::gemm<eT>
323           (
324           &trans_A,
325           &trans_B,
326           &m,
327           &n,
328           &k,
329           &local_alpha,
330           A.mem,
331           &lda,
332           B.mem,
333           &ldb,
334           &local_beta,
335           C.memptr(),
336           &m
337           );
338         }
339       #else
340         {
341         gemm_emul<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C,A,B,alpha,beta);
342         }
343       #endif
344       }
345     }
346 
347 
348 
349   //! immediate multiplication of matrices A and B, storing the result in C
350   template<typename eT, typename TA, typename TB>
351   inline
352   static
353   void
apply(Mat<eT> & C,const TA & A,const TB & B,const eT alpha=eT (1),const eT beta=eT (0))354   apply( Mat<eT>& C, const TA& A, const TB& B, const eT alpha = eT(1), const eT beta = eT(0) )
355     {
356     gemm_emul<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C,A,B,alpha,beta);
357     }
358 
359 
360 
361   template<typename TA, typename TB>
362   arma_inline
363   static
364   void
apply(Mat<float> & C,const TA & A,const TB & B,const float alpha=float (1),const float beta=float (0))365   apply
366     (
367           Mat<float>& C,
368     const TA&         A,
369     const TB&         B,
370     const float alpha = float(1),
371     const float beta  = float(0)
372     )
373     {
374     gemm<do_trans_A, do_trans_B, use_alpha, use_beta>::apply_blas_type(C,A,B,alpha,beta);
375     }
376 
377 
378 
379   template<typename TA, typename TB>
380   arma_inline
381   static
382   void
apply(Mat<double> & C,const TA & A,const TB & B,const double alpha=double (1),const double beta=double (0))383   apply
384     (
385           Mat<double>& C,
386     const TA&          A,
387     const TB&          B,
388     const double alpha = double(1),
389     const double beta  = double(0)
390     )
391     {
392     gemm<do_trans_A, do_trans_B, use_alpha, use_beta>::apply_blas_type(C,A,B,alpha,beta);
393     }
394 
395 
396 
397   template<typename TA, typename TB>
398   arma_inline
399   static
400   void
apply(Mat<std::complex<float>> & C,const TA & A,const TB & B,const std::complex<float> alpha=std::complex<float> (1),const std::complex<float> beta=std::complex<float> (0))401   apply
402     (
403           Mat< std::complex<float> >& C,
404     const TA&                         A,
405     const TB&                         B,
406     const std::complex<float> alpha = std::complex<float>(1),
407     const std::complex<float> beta  = std::complex<float>(0)
408     )
409     {
410     gemm<do_trans_A, do_trans_B, use_alpha, use_beta>::apply_blas_type(C,A,B,alpha,beta);
411     }
412 
413 
414 
415   template<typename TA, typename TB>
416   arma_inline
417   static
418   void
apply(Mat<std::complex<double>> & C,const TA & A,const TB & B,const std::complex<double> alpha=std::complex<double> (1),const std::complex<double> beta=std::complex<double> (0))419   apply
420     (
421           Mat< std::complex<double> >& C,
422     const TA&                          A,
423     const TB&                          B,
424     const std::complex<double> alpha = std::complex<double>(1),
425     const std::complex<double> beta  = std::complex<double>(0)
426     )
427     {
428     gemm<do_trans_A, do_trans_B, use_alpha, use_beta>::apply_blas_type(C,A,B,alpha,beta);
429     }
430 
431   };
432 
433 
434 
435 //! @}
436