1 // SPDX-License-Identifier: Apache-2.0 2 // 3 // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) 4 // Copyright 2008-2016 National ICT Australia (NICTA) 5 // 6 // Licensed under the Apache License, Version 2.0 (the "License"); 7 // you may not use this file except in compliance with the License. 8 // You may obtain a copy of the License at 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ------------------------------------------------------------------------ 17 18 19 //! \addtogroup herk 20 //! @{ 21 22 23 24 class herk_helper 25 { 26 public: 27 28 template<typename eT> 29 inline 30 static 31 void inplace_conj_copy_upper_tri_to_lower_tri(Mat<eT> & C)32 inplace_conj_copy_upper_tri_to_lower_tri(Mat<eT>& C) 33 { 34 // under the assumption that C is a square matrix 35 36 const uword N = C.n_rows; 37 38 for(uword k=0; k < N; ++k) 39 { 40 eT* colmem = C.colptr(k); 41 42 for(uword i=(k+1); i < N; ++i) 43 { 44 colmem[i] = std::conj( C.at(k,i) ); 45 } 46 } 47 } 48 49 50 template<typename eT> 51 static 52 arma_hot 53 inline 54 eT dot_conj_row(const uword n_elem,const eT * const A,const Mat<eT> & B,const uword row)55 dot_conj_row(const uword n_elem, const eT* const A, const Mat<eT>& B, const uword row) 56 { 57 arma_extra_debug_sigprint(); 58 59 typedef typename get_pod_type<eT>::result T; 60 61 T val_real = T(0); 62 T val_imag = T(0); 63 64 for(uword i=0; i<n_elem; ++i) 65 { 66 const std::complex<T>& X = A[i]; 67 const std::complex<T>& Y = B.at(row,i); 68 69 const T a = X.real(); 70 const T b = X.imag(); 71 72 const T c = Y.real(); 73 const T d = Y.imag(); 74 75 val_real += (a*c) + (b*d); 76 val_imag += (b*c) - (a*d); 77 } 78 79 return std::complex<T>(val_real, val_imag); 80 } 81 82 }; 83 84 85 86 template<const bool do_trans_A=false, const bool use_alpha=false, const bool use_beta=false> 87 class herk_vec 88 { 89 public: 90 91 template<typename T, typename TA> 92 arma_hot 93 inline 94 static 95 void apply(Mat<std::complex<T>> & C,const TA & A,const T alpha=T (1),const T beta=T (0))96 apply 97 ( 98 Mat< std::complex<T> >& C, 99 const TA& A, 100 const T alpha = T(1), 101 const T beta = T(0) 102 ) 103 { 104 arma_extra_debug_sigprint(); 105 106 typedef std::complex<T> eT; 107 108 const uword A_n_rows = A.n_rows; 109 const uword A_n_cols = A.n_cols; 110 111 // for beta != 0, C is assumed to be hermitian 112 113 // do_trans_A == false -> C = alpha * A * A^H + beta*C 114 // do_trans_A == true -> C = alpha * A^H * A + beta*C 115 116 const eT* A_mem = A.memptr(); 117 118 if(do_trans_A == false) 119 { 120 if(A_n_rows == 1) 121 { 122 const eT acc = op_cdot::direct_cdot(A_n_cols, A_mem, A_mem); 123 124 if( (use_alpha == false) && (use_beta == false) ) { C[0] = acc; } 125 else if( (use_alpha == true ) && (use_beta == false) ) { C[0] = alpha*acc; } 126 else if( (use_alpha == false) && (use_beta == true ) ) { C[0] = acc + beta*C[0]; } 127 else if( (use_alpha == true ) && (use_beta == true ) ) { C[0] = alpha*acc + beta*C[0]; } 128 } 129 else 130 for(uword row_A=0; row_A < A_n_rows; ++row_A) 131 { 132 const eT& A_rowdata = A_mem[row_A]; 133 134 for(uword k=row_A; k < A_n_rows; ++k) 135 { 136 const eT acc = A_rowdata * std::conj( A_mem[k] ); 137 138 if( (use_alpha == false) && (use_beta == false) ) 139 { 140 C.at(row_A, k) = acc; 141 if(row_A != k) { C.at(k, row_A) = std::conj(acc); } 142 } 143 else 144 if( (use_alpha == true) && (use_beta == false) ) 145 { 146 const eT val = alpha*acc; 147 148 C.at(row_A, k) = val; 149 if(row_A != k) { C.at(k, row_A) = std::conj(val); } 150 } 151 else 152 if( (use_alpha == false) && (use_beta == true) ) 153 { 154 C.at(row_A, k) = acc + beta*C.at(row_A, k); 155 if(row_A != k) { C.at(k, row_A) = std::conj(acc) + beta*C.at(k, row_A); } 156 } 157 else 158 if( (use_alpha == true) && (use_beta == true) ) 159 { 160 const eT val = alpha*acc; 161 162 C.at(row_A, k) = val + beta*C.at(row_A, k); 163 if(row_A != k) { C.at(k, row_A) = std::conj(val) + beta*C.at(k, row_A); } 164 } 165 } 166 } 167 } 168 else 169 if(do_trans_A == true) 170 { 171 if(A_n_cols == 1) 172 { 173 const eT acc = op_cdot::direct_cdot(A_n_rows, A_mem, A_mem); 174 175 if( (use_alpha == false) && (use_beta == false) ) { C[0] = acc; } 176 else if( (use_alpha == true ) && (use_beta == false) ) { C[0] = alpha*acc; } 177 else if( (use_alpha == false) && (use_beta == true ) ) { C[0] = acc + beta*C[0]; } 178 else if( (use_alpha == true ) && (use_beta == true ) ) { C[0] = alpha*acc + beta*C[0]; } 179 } 180 else 181 for(uword col_A=0; col_A < A_n_cols; ++col_A) 182 { 183 // col_A is interpreted as row_A when storing the results in matrix C 184 185 const eT A_coldata = std::conj( A_mem[col_A] ); 186 187 for(uword k=col_A; k < A_n_cols ; ++k) 188 { 189 const eT acc = A_coldata * A_mem[k]; 190 191 if( (use_alpha == false) && (use_beta == false) ) 192 { 193 C.at(col_A, k) = acc; 194 if(col_A != k) { C.at(k, col_A) = std::conj(acc); } 195 } 196 else 197 if( (use_alpha == true ) && (use_beta == false) ) 198 { 199 const eT val = alpha*acc; 200 201 C.at(col_A, k) = val; 202 if(col_A != k) { C.at(k, col_A) = std::conj(val); } 203 } 204 else 205 if( (use_alpha == false) && (use_beta == true ) ) 206 { 207 C.at(col_A, k) = acc + beta*C.at(col_A, k); 208 if(col_A != k) { C.at(k, col_A) = std::conj(acc) + beta*C.at(k, col_A); } 209 } 210 else 211 if( (use_alpha == true ) && (use_beta == true ) ) 212 { 213 const eT val = alpha*acc; 214 215 C.at(col_A, k) = val + beta*C.at(col_A, k); 216 if(col_A != k) { C.at(k, col_A) = std::conj(val) + beta*C.at(k, col_A); } 217 } 218 } 219 } 220 } 221 } 222 223 }; 224 225 226 227 template<const bool do_trans_A=false, const bool use_alpha=false, const bool use_beta=false> 228 class herk_emul 229 { 230 public: 231 232 template<typename T, typename TA> 233 arma_hot 234 inline 235 static 236 void apply(Mat<std::complex<T>> & C,const TA & A,const T alpha=T (1),const T beta=T (0))237 apply 238 ( 239 Mat< std::complex<T> >& C, 240 const TA& A, 241 const T alpha = T(1), 242 const T beta = T(0) 243 ) 244 { 245 arma_extra_debug_sigprint(); 246 247 typedef std::complex<T> eT; 248 249 // do_trans_A == false -> C = alpha * A * A^H + beta*C 250 // do_trans_A == true -> C = alpha * A^H * A + beta*C 251 252 if(do_trans_A == false) 253 { 254 Mat<eT> AA; 255 256 op_htrans::apply_mat_noalias(AA, A); 257 258 herk_emul<true, use_alpha, use_beta>::apply(C, AA, alpha, beta); 259 } 260 else 261 if(do_trans_A == true) 262 { 263 const uword A_n_rows = A.n_rows; 264 const uword A_n_cols = A.n_cols; 265 266 for(uword col_A=0; col_A < A_n_cols; ++col_A) 267 { 268 // col_A is interpreted as row_A when storing the results in matrix C 269 270 const eT* A_coldata = A.colptr(col_A); 271 272 for(uword k=col_A; k < A_n_cols ; ++k) 273 { 274 const eT acc = op_cdot::direct_cdot(A_n_rows, A_coldata, A.colptr(k)); 275 276 if( (use_alpha == false) && (use_beta == false) ) 277 { 278 C.at(col_A, k) = acc; 279 if(col_A != k) { C.at(k, col_A) = std::conj(acc); } 280 } 281 else 282 if( (use_alpha == true) && (use_beta == false) ) 283 { 284 const eT val = alpha*acc; 285 286 C.at(col_A, k) = val; 287 if(col_A != k) { C.at(k, col_A) = std::conj(val); } 288 } 289 else 290 if( (use_alpha == false) && (use_beta == true) ) 291 { 292 C.at(col_A, k) = acc + beta*C.at(col_A, k); 293 if(col_A != k) { C.at(k, col_A) = std::conj(acc) + beta*C.at(k, col_A); } 294 } 295 else 296 if( (use_alpha == true) && (use_beta == true) ) 297 { 298 const eT val = alpha*acc; 299 300 C.at(col_A, k) = val + beta*C.at(col_A, k); 301 if(col_A != k) { C.at(k, col_A) = std::conj(val) + beta*C.at(k, col_A); } 302 } 303 } 304 } 305 } 306 } 307 308 }; 309 310 311 312 template<const bool do_trans_A=false, const bool use_alpha=false, const bool use_beta=false> 313 class herk 314 { 315 public: 316 317 template<typename T, typename TA> 318 inline 319 static 320 void apply_blas_type(Mat<std::complex<T>> & C,const TA & A,const T alpha=T (1),const T beta=T (0))321 apply_blas_type( Mat<std::complex<T>>& C, const TA& A, const T alpha = T(1), const T beta = T(0) ) 322 { 323 arma_extra_debug_sigprint(); 324 325 const uword threshold = 16; 326 327 if(A.is_vec()) 328 { 329 // work around poor handling of vectors by herk() in ATLAS 3.8.4 and standard BLAS 330 331 herk_vec<do_trans_A, use_alpha, use_beta>::apply(C,A,alpha,beta); 332 333 return; 334 } 335 336 337 if( (A.n_elem <= threshold) ) 338 { 339 herk_emul<do_trans_A, use_alpha, use_beta>::apply(C,A,alpha,beta); 340 } 341 else 342 { 343 #if defined(ARMA_USE_ATLAS) 344 { 345 if(use_beta == true) 346 { 347 typedef typename std::complex<T> eT; 348 349 // use a temporary matrix, as we can't assume that matrix C is already symmetric 350 Mat<eT> D(C.n_rows, C.n_cols, arma_nozeros_indicator()); 351 352 herk<do_trans_A, use_alpha, false>::apply_blas_type(D,A,alpha); 353 354 // NOTE: assuming beta=1; this is okay for now, as currently glue_times only uses beta=1 355 arrayops::inplace_plus(C.memptr(), D.memptr(), C.n_elem); 356 357 return; 358 } 359 360 atlas::cblas_herk<T> 361 ( 362 atlas::CblasColMajor, 363 atlas::CblasUpper, 364 (do_trans_A) ? CblasConjTrans : atlas::CblasNoTrans, 365 C.n_cols, 366 (do_trans_A) ? A.n_rows : A.n_cols, 367 (use_alpha) ? alpha : T(1), 368 A.mem, 369 (do_trans_A) ? A.n_rows : C.n_cols, 370 (use_beta) ? beta : T(0), 371 C.memptr(), 372 C.n_cols 373 ); 374 375 herk_helper::inplace_conj_copy_upper_tri_to_lower_tri(C); 376 } 377 #elif defined(ARMA_USE_BLAS) 378 { 379 if(use_beta == true) 380 { 381 typedef typename std::complex<T> eT; 382 383 // use a temporary matrix, as we can't assume that matrix C is already symmetric 384 Mat<eT> D(C.n_rows, C.n_cols, arma_nozeros_indicator()); 385 386 herk<do_trans_A, use_alpha, false>::apply_blas_type(D,A,alpha); 387 388 // NOTE: assuming beta=1; this is okay for now, as currently glue_times only uses beta=1 389 arrayops::inplace_plus(C.memptr(), D.memptr(), C.n_elem); 390 391 return; 392 } 393 394 arma_extra_debug_print("blas::herk()"); 395 396 const char uplo = 'U'; 397 398 const char trans_A = (do_trans_A) ? 'C' : 'N'; 399 400 const blas_int n = blas_int(C.n_cols); 401 const blas_int k = (do_trans_A) ? blas_int(A.n_rows) : blas_int(A.n_cols); 402 403 const T local_alpha = (use_alpha) ? alpha : T(1); 404 const T local_beta = (use_beta) ? beta : T(0); 405 406 const blas_int lda = (do_trans_A) ? k : n; 407 408 arma_extra_debug_print( arma_str::format("blas::herk(): trans_A = %c") % trans_A ); 409 410 blas::herk<T> 411 ( 412 &uplo, 413 &trans_A, 414 &n, 415 &k, 416 &local_alpha, 417 A.mem, 418 &lda, 419 &local_beta, 420 C.memptr(), 421 &n // &ldc 422 ); 423 424 herk_helper::inplace_conj_copy_upper_tri_to_lower_tri(C); 425 } 426 #else 427 { 428 herk_emul<do_trans_A, use_alpha, use_beta>::apply(C,A,alpha,beta); 429 } 430 #endif 431 } 432 433 } 434 435 436 437 template<typename eT, typename TA> 438 inline 439 static 440 void apply(Mat<eT> & C,const TA & A,const eT alpha=eT (1),const eT beta=eT (0),const typename arma_not_cx<eT>::result * junk=nullptr)441 apply( Mat<eT>& C, const TA& A, const eT alpha = eT(1), const eT beta = eT(0), const typename arma_not_cx<eT>::result* junk = nullptr ) 442 { 443 arma_ignore(C); 444 arma_ignore(A); 445 arma_ignore(alpha); 446 arma_ignore(beta); 447 arma_ignore(junk); 448 449 // herk() cannot be used by non-complex matrices 450 451 return; 452 } 453 454 455 456 template<typename TA> 457 arma_inline 458 static 459 void apply(Mat<std::complex<float>> & C,const TA & A,const float alpha=float (1),const float beta=float (0))460 apply 461 ( 462 Mat< std::complex<float> >& C, 463 const TA& A, 464 const float alpha = float(1), 465 const float beta = float(0) 466 ) 467 { 468 herk<do_trans_A, use_alpha, use_beta>::apply_blas_type(C,A,alpha,beta); 469 } 470 471 472 473 template<typename TA> 474 arma_inline 475 static 476 void apply(Mat<std::complex<double>> & C,const TA & A,const double alpha=double (1),const double beta=double (0))477 apply 478 ( 479 Mat< std::complex<double> >& C, 480 const TA& A, 481 const double alpha = double(1), 482 const double beta = double(0) 483 ) 484 { 485 herk<do_trans_A, use_alpha, use_beta>::apply_blas_type(C,A,alpha,beta); 486 } 487 488 }; 489 490 491 492 //! @} 493