1 // SPDX-License-Identifier: Apache-2.0 2 // 3 // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) 4 // Copyright 2008-2016 National ICT Australia (NICTA) 5 // 6 // Licensed under the Apache License, Version 2.0 (the "License"); 7 // you may not use this file except in compliance with the License. 8 // You may obtain a copy of the License at 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ------------------------------------------------------------------------ 17 18 19 //! \addtogroup gemm 20 //! @{ 21 22 23 24 //! for tiny square matrices, size <= 4x4 25 template<const bool do_trans_A=false, const bool use_alpha=false, const bool use_beta=false> 26 class gemm_emul_tinysq 27 { 28 public: 29 30 31 template<typename eT, typename TA, typename TB> 32 arma_cold 33 inline 34 static 35 void apply(Mat<eT> & C,const TA & A,const TB & B,const eT alpha=eT (1),const eT beta=eT (0))36 apply 37 ( 38 Mat<eT>& C, 39 const TA& A, 40 const TB& B, 41 const eT alpha = eT(1), 42 const eT beta = eT(0) 43 ) 44 { 45 arma_extra_debug_sigprint(); 46 47 switch(A.n_rows) 48 { 49 case 4: gemv_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply( C.colptr(3), A, B.colptr(3), alpha, beta ); 50 // fallthrough 51 case 3: gemv_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply( C.colptr(2), A, B.colptr(2), alpha, beta ); 52 // fallthrough 53 case 2: gemv_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply( C.colptr(1), A, B.colptr(1), alpha, beta ); 54 // fallthrough 55 case 1: gemv_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply( C.colptr(0), A, B.colptr(0), alpha, beta ); 56 // fallthrough 57 default: ; 58 } 59 } 60 61 }; 62 63 64 65 //! emulation of gemm(), for non-complex matrices only, as it assumes only simple transposes (ie. doesn't do hermitian transposes) 66 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false> 67 class gemm_emul_large 68 { 69 public: 70 71 template<typename eT, typename TA, typename TB> 72 arma_hot 73 inline 74 static 75 void apply(Mat<eT> & C,const TA & A,const TB & B,const eT alpha=eT (1),const eT beta=eT (0))76 apply 77 ( 78 Mat<eT>& C, 79 const TA& A, 80 const TB& B, 81 const eT alpha = eT(1), 82 const eT beta = eT(0) 83 ) 84 { 85 arma_extra_debug_sigprint(); 86 87 const uword A_n_rows = A.n_rows; 88 const uword A_n_cols = A.n_cols; 89 90 const uword B_n_rows = B.n_rows; 91 const uword B_n_cols = B.n_cols; 92 93 if( (do_trans_A == false) && (do_trans_B == false) ) 94 { 95 arma_aligned podarray<eT> tmp(A_n_cols); 96 97 eT* A_rowdata = tmp.memptr(); 98 99 for(uword row_A=0; row_A < A_n_rows; ++row_A) 100 { 101 tmp.copy_row(A, row_A); 102 103 for(uword col_B=0; col_B < B_n_cols; ++col_B) 104 { 105 const eT acc = op_dot::direct_dot_arma(B_n_rows, A_rowdata, B.colptr(col_B)); 106 107 if( (use_alpha == false) && (use_beta == false) ) { C.at(row_A,col_B) = acc; } 108 else if( (use_alpha == true ) && (use_beta == false) ) { C.at(row_A,col_B) = alpha*acc; } 109 else if( (use_alpha == false) && (use_beta == true ) ) { C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B); } 110 else if( (use_alpha == true ) && (use_beta == true ) ) { C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B); } 111 } 112 } 113 } 114 else 115 if( (do_trans_A == true) && (do_trans_B == false) ) 116 { 117 for(uword col_A=0; col_A < A_n_cols; ++col_A) 118 { 119 // col_A is interpreted as row_A when storing the results in matrix C 120 121 const eT* A_coldata = A.colptr(col_A); 122 123 for(uword col_B=0; col_B < B_n_cols; ++col_B) 124 { 125 const eT acc = op_dot::direct_dot_arma(B_n_rows, A_coldata, B.colptr(col_B)); 126 127 if( (use_alpha == false) && (use_beta == false) ) { C.at(col_A,col_B) = acc; } 128 else if( (use_alpha == true ) && (use_beta == false) ) { C.at(col_A,col_B) = alpha*acc; } 129 else if( (use_alpha == false) && (use_beta == true ) ) { C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B); } 130 else if( (use_alpha == true ) && (use_beta == true ) ) { C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B); } 131 } 132 } 133 } 134 else 135 if( (do_trans_A == false) && (do_trans_B == true) ) 136 { 137 Mat<eT> BB; 138 op_strans::apply_mat_noalias(BB, B); 139 140 gemm_emul_large<false, false, use_alpha, use_beta>::apply(C, A, BB, alpha, beta); 141 } 142 else 143 if( (do_trans_A == true) && (do_trans_B == true) ) 144 { 145 // mat B_tmp = trans(B); 146 // dgemm_arma<true, false, use_alpha, use_beta>::apply(C, A, B_tmp, alpha, beta); 147 148 149 // By using the trans(A)*trans(B) = trans(B*A) equivalency, 150 // transpose operations are not needed 151 152 arma_aligned podarray<eT> tmp(B.n_cols); 153 eT* B_rowdata = tmp.memptr(); 154 155 for(uword row_B=0; row_B < B_n_rows; ++row_B) 156 { 157 tmp.copy_row(B, row_B); 158 159 for(uword col_A=0; col_A < A_n_cols; ++col_A) 160 { 161 const eT acc = op_dot::direct_dot_arma(A_n_rows, B_rowdata, A.colptr(col_A)); 162 163 if( (use_alpha == false) && (use_beta == false) ) { C.at(col_A,row_B) = acc; } 164 else if( (use_alpha == true ) && (use_beta == false) ) { C.at(col_A,row_B) = alpha*acc; } 165 else if( (use_alpha == false) && (use_beta == true ) ) { C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B); } 166 else if( (use_alpha == true ) && (use_beta == true ) ) { C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B); } 167 } 168 } 169 } 170 } 171 172 }; 173 174 175 176 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false> 177 class gemm_emul 178 { 179 public: 180 181 182 template<typename eT, typename TA, typename TB> 183 arma_hot 184 inline 185 static 186 void apply(Mat<eT> & C,const TA & A,const TB & B,const eT alpha=eT (1),const eT beta=eT (0),const typename arma_not_cx<eT>::result * junk=nullptr)187 apply 188 ( 189 Mat<eT>& C, 190 const TA& A, 191 const TB& B, 192 const eT alpha = eT(1), 193 const eT beta = eT(0), 194 const typename arma_not_cx<eT>::result* junk = nullptr 195 ) 196 { 197 arma_extra_debug_sigprint(); 198 arma_ignore(junk); 199 200 gemm_emul_large<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C, A, B, alpha, beta); 201 } 202 203 204 205 template<typename eT> 206 arma_hot 207 inline 208 static 209 void apply(Mat<eT> & C,const Mat<eT> & A,const Mat<eT> & B,const eT alpha=eT (1),const eT beta=eT (0),const typename arma_cx_only<eT>::result * junk=nullptr)210 apply 211 ( 212 Mat<eT>& C, 213 const Mat<eT>& A, 214 const Mat<eT>& B, 215 const eT alpha = eT(1), 216 const eT beta = eT(0), 217 const typename arma_cx_only<eT>::result* junk = nullptr 218 ) 219 { 220 arma_extra_debug_sigprint(); 221 arma_ignore(junk); 222 223 // "better than nothing" handling of hermitian transposes for complex number matrices 224 225 Mat<eT> tmp_A; 226 Mat<eT> tmp_B; 227 228 if(do_trans_A) { op_htrans::apply_mat_noalias(tmp_A, A); } 229 if(do_trans_B) { op_htrans::apply_mat_noalias(tmp_B, B); } 230 231 const Mat<eT>& AA = (do_trans_A == false) ? A : tmp_A; 232 const Mat<eT>& BB = (do_trans_B == false) ? B : tmp_B; 233 234 gemm_emul_large<false, false, use_alpha, use_beta>::apply(C, AA, BB, alpha, beta); 235 } 236 237 }; 238 239 240 241 //! \brief 242 //! Wrapper for ATLAS/BLAS dgemm function, using template arguments to control the arguments passed to dgemm. 243 //! Matrix 'C' is assumed to have been set to the correct size (ie. taking into account transposes) 244 245 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false> 246 class gemm 247 { 248 public: 249 250 template<typename eT, typename TA, typename TB> 251 inline 252 static 253 void apply_blas_type(Mat<eT> & C,const TA & A,const TB & B,const eT alpha=eT (1),const eT beta=eT (0))254 apply_blas_type( Mat<eT>& C, const TA& A, const TB& B, const eT alpha = eT(1), const eT beta = eT(0) ) 255 { 256 arma_extra_debug_sigprint(); 257 258 if( (A.n_rows <= 4) && (A.n_rows == A.n_cols) && (A.n_rows == B.n_rows) && (B.n_rows == B.n_cols) && (is_cx<eT>::no) ) 259 { 260 if(do_trans_B == false) 261 { 262 gemm_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply(C, A, B, alpha, beta); 263 } 264 else 265 { 266 Mat<eT> BB(B.n_rows, B.n_rows, arma_nozeros_indicator()); 267 268 op_strans::apply_mat_noalias_tinysq(BB, B); 269 270 gemm_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply(C, A, BB, alpha, beta); 271 } 272 } 273 else 274 { 275 #if defined(ARMA_USE_ATLAS) 276 { 277 arma_extra_debug_print("atlas::cblas_gemm()"); 278 279 arma_debug_assert_atlas_size(A,B); 280 281 atlas::cblas_gemm<eT> 282 ( 283 atlas::CblasColMajor, 284 (do_trans_A) ? ( is_cx<eT>::yes ? CblasConjTrans : atlas::CblasTrans ) : atlas::CblasNoTrans, 285 (do_trans_B) ? ( is_cx<eT>::yes ? CblasConjTrans : atlas::CblasTrans ) : atlas::CblasNoTrans, 286 C.n_rows, 287 C.n_cols, 288 (do_trans_A) ? A.n_rows : A.n_cols, 289 (use_alpha) ? alpha : eT(1), 290 A.mem, 291 (do_trans_A) ? A.n_rows : C.n_rows, 292 B.mem, 293 (do_trans_B) ? C.n_cols : ( (do_trans_A) ? A.n_rows : A.n_cols ), 294 (use_beta) ? beta : eT(0), 295 C.memptr(), 296 C.n_rows 297 ); 298 } 299 #elif defined(ARMA_USE_BLAS) 300 { 301 arma_extra_debug_print("blas::gemm()"); 302 303 arma_debug_assert_blas_size(A,B); 304 305 const char trans_A = (do_trans_A) ? ( is_cx<eT>::yes ? 'C' : 'T' ) : 'N'; 306 const char trans_B = (do_trans_B) ? ( is_cx<eT>::yes ? 'C' : 'T' ) : 'N'; 307 308 const blas_int m = blas_int(C.n_rows); 309 const blas_int n = blas_int(C.n_cols); 310 const blas_int k = (do_trans_A) ? blas_int(A.n_rows) : blas_int(A.n_cols); 311 312 const eT local_alpha = (use_alpha) ? alpha : eT(1); 313 314 const blas_int lda = (do_trans_A) ? k : m; 315 const blas_int ldb = (do_trans_B) ? n : k; 316 317 const eT local_beta = (use_beta) ? beta : eT(0); 318 319 arma_extra_debug_print( arma_str::format("blas::gemm(): trans_A = %c") % trans_A ); 320 arma_extra_debug_print( arma_str::format("blas::gemm(): trans_B = %c") % trans_B ); 321 322 blas::gemm<eT> 323 ( 324 &trans_A, 325 &trans_B, 326 &m, 327 &n, 328 &k, 329 &local_alpha, 330 A.mem, 331 &lda, 332 B.mem, 333 &ldb, 334 &local_beta, 335 C.memptr(), 336 &m 337 ); 338 } 339 #else 340 { 341 gemm_emul<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C,A,B,alpha,beta); 342 } 343 #endif 344 } 345 } 346 347 348 349 //! immediate multiplication of matrices A and B, storing the result in C 350 template<typename eT, typename TA, typename TB> 351 inline 352 static 353 void apply(Mat<eT> & C,const TA & A,const TB & B,const eT alpha=eT (1),const eT beta=eT (0))354 apply( Mat<eT>& C, const TA& A, const TB& B, const eT alpha = eT(1), const eT beta = eT(0) ) 355 { 356 gemm_emul<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C,A,B,alpha,beta); 357 } 358 359 360 361 template<typename TA, typename TB> 362 arma_inline 363 static 364 void apply(Mat<float> & C,const TA & A,const TB & B,const float alpha=float (1),const float beta=float (0))365 apply 366 ( 367 Mat<float>& C, 368 const TA& A, 369 const TB& B, 370 const float alpha = float(1), 371 const float beta = float(0) 372 ) 373 { 374 gemm<do_trans_A, do_trans_B, use_alpha, use_beta>::apply_blas_type(C,A,B,alpha,beta); 375 } 376 377 378 379 template<typename TA, typename TB> 380 arma_inline 381 static 382 void apply(Mat<double> & C,const TA & A,const TB & B,const double alpha=double (1),const double beta=double (0))383 apply 384 ( 385 Mat<double>& C, 386 const TA& A, 387 const TB& B, 388 const double alpha = double(1), 389 const double beta = double(0) 390 ) 391 { 392 gemm<do_trans_A, do_trans_B, use_alpha, use_beta>::apply_blas_type(C,A,B,alpha,beta); 393 } 394 395 396 397 template<typename TA, typename TB> 398 arma_inline 399 static 400 void apply(Mat<std::complex<float>> & C,const TA & A,const TB & B,const std::complex<float> alpha=std::complex<float> (1),const std::complex<float> beta=std::complex<float> (0))401 apply 402 ( 403 Mat< std::complex<float> >& C, 404 const TA& A, 405 const TB& B, 406 const std::complex<float> alpha = std::complex<float>(1), 407 const std::complex<float> beta = std::complex<float>(0) 408 ) 409 { 410 gemm<do_trans_A, do_trans_B, use_alpha, use_beta>::apply_blas_type(C,A,B,alpha,beta); 411 } 412 413 414 415 template<typename TA, typename TB> 416 arma_inline 417 static 418 void apply(Mat<std::complex<double>> & C,const TA & A,const TB & B,const std::complex<double> alpha=std::complex<double> (1),const std::complex<double> beta=std::complex<double> (0))419 apply 420 ( 421 Mat< std::complex<double> >& C, 422 const TA& A, 423 const TB& B, 424 const std::complex<double> alpha = std::complex<double>(1), 425 const std::complex<double> beta = std::complex<double>(0) 426 ) 427 { 428 gemm<do_trans_A, do_trans_B, use_alpha, use_beta>::apply_blas_type(C,A,B,alpha,beta); 429 } 430 431 }; 432 433 434 435 //! @} 436