1/* fflas/fflas_pfgemm.inl 2 * Copyright (C) 2013 Jean Guillaume Dumas Clement Pernet Ziad Sultan 3 *<ziad.sultan@imag.fr> 4 * 5 * ========LICENCE======== 6 * This file is part of the library FFLAS-FFPACK. 7 * 8 * FFLAS-FFPACK is free software: you can redistribute it and/or modify 9 * it under the terms of the GNU Lesser General Public 10 * License as published by the Free Software Foundation; either 11 * version 2.1 of the License, or (at your option) any later version. 12 * 13 * This library is distributed in the hope that it will be useful, 14 * but WITHOUT ANY WARRANTY; without even the implied warranty of 15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU 16 * Lesser General Public License for more details. 17 * 18 * You should have received a copy of the GNU Lesser General Public 19 * License along with this library; if not, write to the Free Software 20 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA 21 * ========LICENCE======== 22 *. 23 */ 24 25 26 27namespace FFLAS 28{ 29 30 31 template<class Field, class AlgoT, class FieldTrait> 32 typename Field::Element* 33 pfgemm(const Field& F, 34 const FFLAS_TRANSPOSE ta, 35 const FFLAS_TRANSPOSE tb, 36 const size_t m, 37 const size_t n, 38 const size_t k, 39 const typename Field::Element alpha, 40 const typename Field::ConstElement_ptr A, const size_t lda, 41 const typename Field::ConstElement_ptr B, const size_t ldb, 42 const typename Field::Element beta, 43 typename Field::Element * C, const size_t ldc, 44 MMHelper<Field, AlgoT, FieldTrait, ParSeqHelper::Parallel<CuttingStrategy::Block, StrategyParameter::Threads> > & H){ 45 { 46 H.parseq.set_numthreads( std::min(H.parseq.numthreads(), std::max((size_t)1,(size_t)(m*n/(__FFLASFFPACK_SEQPARTHRESHOLD*__FFLASFFPACK_SEQPARTHRESHOLD)))) ); 47 48 MMHelper<Field, AlgoT, FieldTrait, ParSeqHelper::Sequential> SeqH (H); 49 size_t sa = (ta==FFLAS::FflasNoTrans)?lda:1; 50 size_t sb = (tb==FFLAS::FflasNoTrans)?1:ldb; 51 SYNCH_GROUP({FORBLOCK2D(iter,m,n,H.parseq, 52 TASK( MODE( 53 READ(A[iter.ibegin()*sa],B[iter.jbegin()*sb]) 54 CONSTREFERENCE(F, SeqH) 55 READWRITE(C[iter.ibegin()*ldc+iter.jbegin()])), 56 fgemm( F, ta, tb, iter.iend()-iter.ibegin(), iter.jend()-iter.jbegin(), k, alpha, A+iter.ibegin()*sa, lda, B+iter.jbegin()*sb, ldb, beta, C+iter.ibegin()*ldc+iter.jbegin(), ldc, SeqH);); 57 ); 58 }); 59 } 60 return C; 61 62 63 } 64 65 template<class Field, class AlgoT, class FieldTrait> 66 typename Field::Element* 67 pfgemm(const Field& F, 68 const FFLAS_TRANSPOSE ta, 69 const FFLAS_TRANSPOSE tb, 70 const size_t m, 71 const size_t n, 72 const size_t k, 73 const typename Field::Element alpha, 74 const typename Field::ConstElement_ptr AA, const size_t lda, 75 const typename Field::ConstElement_ptr BB, const size_t ldb, 76 const typename Field::Element beta, 77 typename Field::Element * C, const size_t ldc, 78 MMHelper<Field, AlgoT, FieldTrait, ParSeqHelper::Parallel<CuttingStrategy::Recursive, StrategyParameter::ThreeDAdaptive> > & H){ 79 80 typename Field::Element a = alpha; 81 typename Field::Element b = beta; 82 typename Field::ConstElement_ptr B = BB; 83 typename Field::ConstElement_ptr A = AA; 84 if (!m || !n) {return C;} 85 if (!k || F.isZero (alpha)){ 86 fscalin(F, m, n, beta, C, ldc); 87 return C; 88 } 89 90 if (H.parseq.numthreads()<=1 || std::min(m*n,std::min(m*k,k*n))<=__FFLASFFPACK_SEQPARTHRESHOLD*__FFLASFFPACK_SEQPARTHRESHOLD){ 91 MMHelper<Field,AlgoT,FieldTrait,ParSeqHelper::Sequential> SeqH(H); 92 return fgemm(F, ta, tb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc, SeqH); 93 } 94 95 typedef MMHelper<Field,AlgoT,FieldTrait,ParSeqHelper::Parallel<CuttingStrategy::Recursive, StrategyParameter::ThreeDAdaptive> > MMH_t; 96 MMH_t H1(H); 97 MMH_t H2(H); 98 if(__FFLASFFPACK_DIMKPENALTY*m > k && m >= n) { 99 SYNCH_GROUP(size_t M2= m>>1; 100 H1.parseq.set_numthreads(H1.parseq.numthreads() >> 1); 101 H2.parseq.set_numthreads(H.parseq.numthreads() - H1.parseq.numthreads()); 102 103 typename Field::ConstElement_ptr A1= A; 104 typename Field::ConstElement_ptr A2= A+M2*((ta==FFLAS::FflasTrans)?1:lda); 105 typename Field::Element_ptr C1= C; 106 typename Field::Element_ptr C2= C+M2*ldc; 107 108 // 2 multiply (1 split on dimension m) 109 110 TASK(MODE(CONSTREFERENCE(F, H1) READ(A1,B) READWRITE(C1)), 111 {pfgemm( F, ta, tb, M2, n, k, alpha, A1, lda, B, ldb, beta, C1, ldc, H1);} 112 ); 113 114 TASK(MODE(CONSTREFERENCE(F,H2) READ(A2,B) READWRITE(C2)), 115 {pfgemm(F, ta, tb, m-M2, n, k, alpha, A2, lda, B, ldb, beta, C2, ldc, H2);} 116 ); 117 ); 118 119 } else if (__FFLASFFPACK_DIMKPENALTY*n > k) { 120 SYNCH_GROUP( 121 size_t N2 = n>>1; 122 H1.parseq.set_numthreads( H1.parseq.numthreads() >> 1); 123 H2.parseq.set_numthreads(H.parseq.numthreads() - H1.parseq.numthreads()); 124 typename Field::ConstElement_ptr B1= B; 125 typename Field::ConstElement_ptr B2= B+N2*((tb==FFLAS::FflasTrans)?ldb:1); 126 127 typename Field::Element_ptr C1= C; 128 typename Field::Element_ptr C2= C+N2; 129 130 TASK(MODE(CONSTREFERENCE(F,H1) READ(A,B1) READWRITE(C1)), pfgemm(F, ta, tb, m, N2, k, a, A, lda, B1, ldb, b, C1, ldc, H1)); 131 TASK(MODE(CONSTREFERENCE(F,H2) READ(A,B2) READWRITE(C2)), pfgemm(F, ta, tb, m, n-N2, k, a, A, lda, B2, ldb, b,C2, ldc, H2)); 132 ); 133 134 } else { 135 136 size_t K2 = k>>1; 137 138 typename Field::ConstElement_ptr B1= B; 139 typename Field::ConstElement_ptr B2= B+K2*((tb==FFLAS::FflasTrans)?1:ldb); 140 typename Field::ConstElement_ptr A1= A; 141 typename Field::ConstElement_ptr A2= A+K2*((ta==FFLAS::FflasTrans)?lda:1); 142 typename Field::Element_ptr C2 = fflas_new (F, m, n,Alignment::CACHE_PAGESIZE); 143 144 H1.parseq.set_numthreads(H1.parseq.numthreads() >> 1); 145 H2.parseq.set_numthreads(H.parseq.numthreads()-H1.parseq.numthreads()); 146 SYNCH_GROUP( 147 TASK(MODE(CONSTREFERENCE(F,H1) READ(A1,B1) READWRITE(C)), pfgemm(F, ta, tb, m, n, K2, a, A1, lda, B1, ldb, b, C, ldc, H1)); 148 149 TASK(MODE(CONSTREFERENCE(F,H2) READ(A2,B2) READWRITE(C2)), pfgemm(F, ta, tb, m, n, k-K2, a, A2, lda, B2, ldb, F.zero, C2, n, H2)); 150 CHECK_DEPENDENCIES; 151 152 TASK(MODE(CONSTREFERENCE(F) READ(C2) READWRITE(C)),faddin(F, n, m, C2, n, C, ldc)); 153 154 ); 155 fflas_delete(C2); 156 } 157 158 return C; 159 } 160 161 template<class Field, class AlgoT, class FieldTrait> 162 typename Field::Element* 163 pfgemm (const Field& F, 164 const FFLAS_TRANSPOSE ta, 165 const FFLAS_TRANSPOSE tb, 166 const size_t m, 167 const size_t n, 168 const size_t k, 169 const typename Field::Element alpha, 170 const typename Field::ConstElement_ptr AA, const size_t lda, 171 const typename Field::ConstElement_ptr BB, const size_t ldb, 172 const typename Field::Element beta, 173 typename Field::Element * C, const size_t ldc, 174 MMHelper<Field, AlgoT, FieldTrait, ParSeqHelper::Parallel<CuttingStrategy::Recursive,StrategyParameter::TwoDAdaptive> > & H){ 175 176 typename Field::Element a = alpha; 177 typename Field::Element b = beta; 178 typename Field::ConstElement_ptr B = BB; 179 typename Field::ConstElement_ptr A = AA; 180 if (!m || !n) {return C;} 181 if (!k || F.isZero (alpha)){ 182 fscalin(F, m, n, beta, C, ldc); 183 return C; 184 } 185 if (H.parseq.numthreads()<=1 || m*n<=__FFLASFFPACK_SEQPARTHRESHOLD*__FFLASFFPACK_SEQPARTHRESHOLD){ 186 MMHelper<Field,AlgoT,FieldTrait,ParSeqHelper::Sequential> SeqH(H); 187 return fgemm(F, ta, tb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc, SeqH); 188 189 } 190 typedef MMHelper<Field, AlgoT, FieldTrait, ParSeqHelper::Parallel<CuttingStrategy::Recursive, StrategyParameter::TwoDAdaptive> > MMH_t; 191 MMH_t H1(H); 192 MMH_t H2(H); 193 H1.parseq.set_numthreads(H1.parseq.numthreads() >> 1); 194 H2.parseq.set_numthreads(H.parseq.numthreads() - H1.parseq.numthreads()); 195 if(m >= n) { 196 size_t M2= m>>1; 197 typename Field::ConstElement_ptr A1= A; 198 typename Field::ConstElement_ptr A2= A+M2*((ta==FFLAS::FflasTrans)?1:lda); 199 typename Field::Element_ptr C1= C; 200 typename Field::Element_ptr C2= C+M2*ldc; 201 SYNCH_GROUP( 202 TASK(MODE(CONSTREFERENCE(F,H1, A1, B) READ(M2, A1[0],B[0]) READWRITE(C1[0])), pfgemm(F, ta, tb, M2, n, k, alpha, A1, lda, B, ldb, beta, C1, ldc, H1)); 203 TASK(MODE(CONSTREFERENCE(F,H2, A2, B) READ(M2, A2[0],B[0]) READWRITE(C2[0])), pfgemm(F, ta, tb, m-M2, n, k, alpha, A2, lda, B, ldb, beta, C2, ldc, H2)); 204 205 ); 206 207 } else { 208 size_t N2 = n>>1; 209 typename Field::ConstElement_ptr B1= B; 210 typename Field::ConstElement_ptr B2= B+N2*((tb==FFLAS::FflasTrans)?ldb:1); 211 typename Field::Element_ptr C1= C; 212 typename Field::Element_ptr C2= C+N2; 213 SYNCH_GROUP( 214 TASK(MODE(CONSTREFERENCE(F,H1, A, B1) READ(N2, A[0], B1[0]) READWRITE(C1[0])), pfgemm(F, ta, tb, m, N2, k, a, A, lda, B1, ldb, b, C1, ldc, H1)); 215 TASK(MODE(CONSTREFERENCE(F,H2, A, B2) READ(N2, A[0], B2[0]) READWRITE(C2[0])), pfgemm(F, ta, tb, m, n-N2, k, a, A, lda, B2, ldb, b,C2, ldc, H2)); 216 ); 217 } 218 return C; 219 } 220 221 template<class Field, class AlgoT, class FieldTrait> 222 typename Field::Element* 223 pfgemm( const Field& F, 224 const FFLAS_TRANSPOSE ta, 225 const FFLAS_TRANSPOSE tb, 226 const size_t m, 227 const size_t n, 228 const size_t k, 229 const typename Field::Element alpha, 230 const typename Field::ConstElement_ptr AA, const size_t lda, 231 const typename Field::ConstElement_ptr BB, const size_t ldb, 232 const typename Field::Element beta, 233 typename Field::Element * C, const size_t ldc, 234 MMHelper<Field, AlgoT, FieldTrait, ParSeqHelper::Parallel<CuttingStrategy::Recursive,StrategyParameter::TwoD> > & H){ 235 236 typename Field::Element a = alpha; 237 typename Field::Element b = beta; 238 typename Field::ConstElement_ptr B = BB; 239 typename Field::ConstElement_ptr A = AA; 240 if (!m || !n) {return C;} 241 if (!k || F.isZero (alpha)){ 242 fscalin(F, m, n, beta, C, ldc); 243 return C; 244 } 245 246 if(H.parseq.numthreads()<=1|| m*n<=__FFLASFFPACK_SEQPARTHRESHOLD*__FFLASFFPACK_SEQPARTHRESHOLD){ 247 MMHelper<Field,AlgoT,FieldTrait,ParSeqHelper::Sequential> SeqH(H); 248 return fgemm(F, ta, tb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc, SeqH); 249 } else 250 { 251 size_t M2= m>>1; 252 size_t N2= n>>1; 253 254 typename Field::ConstElement_ptr A1= A; 255 typename Field::ConstElement_ptr A2= A+M2*((ta==FFLAS::FflasTrans)?1:lda); 256 typename Field::ConstElement_ptr B1= B; 257 typename Field::ConstElement_ptr B2= B+N2*((tb==FFLAS::FflasTrans)?ldb:1); 258 259 typename Field::Element_ptr C11= C; 260 typename Field::Element_ptr C21= C+M2*ldc; 261 typename Field::Element_ptr C12= C+N2; 262 typename Field::Element_ptr C22= C+N2+M2*ldc; 263 264 typedef MMHelper<Field, AlgoT, FieldTrait, ParSeqHelper::Parallel<CuttingStrategy::Recursive,StrategyParameter::TwoD> > MMH_t; 265 MMH_t H1(H); 266 MMH_t H2(H); 267 MMH_t H3(H); 268 MMH_t H4(H); 269 size_t nt = H.parseq.numthreads(); 270 size_t nt_rec = nt/4; 271 size_t nt_mod = nt%4; 272 H1.parseq.set_numthreads(std::max(size_t(1),nt_rec + ((nt_mod-- > 0)?1:0))); 273 H2.parseq.set_numthreads(std::max(size_t(1),nt_rec + ((nt_mod-- > 0)?1:0))); 274 H3.parseq.set_numthreads(std::max(size_t(1),nt_rec + ((nt_mod-- > 0)?1:0))); 275 H4.parseq.set_numthreads(std::max(size_t(1),nt_rec + ((nt_mod-- > 0)?1:0))); 276 SYNCH_GROUP( 277 TASK(MODE(CONSTREFERENCE(F,H1) READ(A1,B1) READWRITE(C11)), pfgemm(F, ta, tb, M2, N2, k, alpha, A1, lda, B1, ldb, beta, C11, ldc, H1)); 278 279 TASK(MODE(CONSTREFERENCE(F,H2) READ(A1,B2) READWRITE(C12)), pfgemm(F, ta, tb, M2, n-N2, k, alpha, A1, lda, B2, ldb, beta, C12, ldc, H2)); 280 281 TASK(MODE(CONSTREFERENCE(F,H3) READ(A2,B1) READWRITE(C21)), pfgemm(F, ta, tb, m-M2, N2, k, a, A2, lda, B1, ldb, b, C21, ldc, H3)); 282 283 TASK(MODE(CONSTREFERENCE(F,H4) READ(A2,B2) READWRITE(C22)), pfgemm(F, ta, tb, m-M2, n-N2, k, a, A2, lda, B2, ldb, b,C22, ldc, H4)); 284 ); 285 } 286 return C; 287 } 288 289 290 291 template<class Field, class AlgoT, class FieldTrait> 292 typename Field::Element_ptr 293 pfgemm(const Field& F, 294 const FFLAS_TRANSPOSE ta, 295 const FFLAS_TRANSPOSE tb, 296 const size_t m, 297 const size_t n, 298 const size_t k, 299 const typename Field::Element alpha, 300 const typename Field::ConstElement_ptr A, const size_t lda, 301 const typename Field::ConstElement_ptr B, const size_t ldb, 302 const typename Field::Element beta, 303 typename Field::Element_ptr C, const size_t ldc, 304 MMHelper<Field, AlgoT, FieldTrait, ParSeqHelper::Parallel<CuttingStrategy::Recursive,StrategyParameter::ThreeD> > & H){ 305 306 307 if (!m || !n) {return C;} 308 if (!k || F.isZero (alpha)){ 309 fscalin(F, m, n, beta, C, ldc); 310 return C; 311 } 312 if(H.parseq.numthreads() <= 1|| std::min(m*n,std::min(m*k,k*n))<=__FFLASFFPACK_SEQPARTHRESHOLD*__FFLASFFPACK_SEQPARTHRESHOLD){ 313 FFLAS::MMHelper<Field, AlgoT, FieldTrait,FFLAS::ParSeqHelper::Sequential> WH (H); 314 return fgemm(F, ta, tb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc, WH); 315 } 316 else 317 { 318 typename Field::Element a = alpha; 319 typename Field::Element b = 0; 320 321 size_t M2= m>>1; 322 size_t N2= n>>1; 323 size_t K2= k>>1; 324 typename Field::ConstElement_ptr A11= A; 325 typename Field::ConstElement_ptr A12= A+K2*((ta==FFLAS::FflasTrans)?lda:1); 326 typename Field::ConstElement_ptr A21= A+M2*((ta==FFLAS::FflasTrans)?1:lda); 327 typename Field::ConstElement_ptr A22= A12+M2*((ta==FFLAS::FflasTrans)?1:lda); 328 329 typename Field::ConstElement_ptr B11= B; 330 typename Field::ConstElement_ptr B12= B+N2*((tb==FFLAS::FflasTrans)?ldb:1); 331 typename Field::ConstElement_ptr B21= B+K2*((tb==FFLAS::FflasTrans)?1:ldb); 332 typename Field::ConstElement_ptr B22= B12+K2*((tb==FFLAS::FflasTrans)?1:ldb); 333 334 typename Field::Element_ptr C11= C; 335 typename Field::Element_ptr C_11 = fflas_new (F, M2, N2,Alignment::CACHE_PAGESIZE); 336 337 typename Field::Element_ptr C12= C+N2; 338 typename Field::Element_ptr C_12 = fflas_new (F, M2, n-N2,Alignment::CACHE_PAGESIZE); 339 340 typename Field::Element_ptr C21= C+M2*ldc; 341 typename Field::Element_ptr C_21 = fflas_new (F, m-M2, N2,Alignment::CACHE_PAGESIZE); 342 343 typename Field::Element_ptr C22= C+N2+M2*ldc; 344 typename Field::Element_ptr C_22 = fflas_new (F, m-M2, n-N2,Alignment::CACHE_PAGESIZE); 345 346 // 1/ 8 multiply in parallel 347 //omp_set_task_affinity(omp_get_locality_domain_num_for( C11)); 348 349 typedef MMHelper<Field, AlgoT, FieldTrait, ParSeqHelper::Parallel<CuttingStrategy::Recursive,StrategyParameter::ThreeD> > MMH_t; 350 MMH_t H1(H); 351 MMH_t H2(H); 352 MMH_t H3(H); 353 MMH_t H4(H); 354 MMH_t H5(H); 355 MMH_t H6(H); 356 MMH_t H7(H); 357 MMH_t H8(H); 358 size_t nt = H.parseq.numthreads(); 359 size_t nt_rec = nt/8; 360 size_t nt_mod = nt % 8 ; 361 H1.parseq.set_numthreads(std::max(size_t(1),nt_rec + ((nt_mod-- > 0)?1:0))); 362 H2.parseq.set_numthreads(std::max(size_t(1),nt_rec + ((nt_mod-- > 0)?1:0))); 363 H3.parseq.set_numthreads(std::max(size_t(1),nt_rec + ((nt_mod-- > 0)?1:0))); 364 H4.parseq.set_numthreads(std::max(size_t(1),nt_rec + ((nt_mod-- > 0)?1:0))); 365 H5.parseq.set_numthreads(std::max(size_t(1),nt_rec + ((nt_mod-- > 0)?1:0))); 366 H6.parseq.set_numthreads(std::max(size_t(1),nt_rec + ((nt_mod-- > 0)?1:0))); 367 H7.parseq.set_numthreads(std::max(size_t(1),nt_rec + ((nt_mod-- > 0)?1:0))); 368 H8.parseq.set_numthreads(std::max(size_t(1),nt_rec + ((nt_mod-- > 0)?1:0))); 369 370 SYNCH_GROUP( 371 TASK(MODE(CONSTREFERENCE(F,H1) READ(A11,B11) READWRITE(C11)), pfgemm(F, ta, tb, M2, N2, K2, alpha, A11, lda, B11, ldb, beta, C11, ldc, H1)); 372 //omp_set_task_affinity(omp_get_locality_domain_num_for( C_11)); 373 TASK(MODE(CONSTREFERENCE(F,H2) READ(A12,B21) WRITE(C_11)), pfgemm(F, ta, tb, M2, N2, k-K2, a, A12, lda, B21, ldb, b,C_11, N2, H2)); 374 //omp_set_task_affinity(omp_get_locality_domain_num_for( C12)); 375 TASK(MODE(CONSTREFERENCE(F,H3) READ(A12,B22) READWRITE(C12)), pfgemm(F, ta, tb, M2, n-N2, k-K2, alpha, A12, lda, B22, ldb, beta, C12, ldc, H3)); 376 //omp_set_task_affinity(omp_get_locality_domain_num_for( C_12)); 377 TASK(MODE(CONSTREFERENCE(F,H4) READ(A11,B12) WRITE(C_12)), pfgemm(F, ta, tb, M2, n-N2, K2, a, A11, lda, B12, ldb, b, C_12, n-N2, H4)); 378 //omp_set_task_affinity(omp_get_locality_domain_num_for( C21)); 379 TASK(MODE(CONSTREFERENCE(F,H5) READ(A22,B21) READWRITE(C21)), pfgemm(F, ta, tb, m-M2, N2, k-K2, alpha, A22, lda, B21, ldb, beta, C21, ldc, H5)); 380 //omp_set_task_affinity(omp_get_locality_domain_num_for( C_21)); 381 TASK(MODE(CONSTREFERENCE(F,H6) READ(A21,B11) WRITE(C_21)), pfgemm(F, ta, tb, m-M2, N2, K2, a, A21, lda, B11, ldb, b,C_21, N2, H6)); 382 //omp_set_task_affinity(omp_get_locality_domain_num_for( C22)); 383 TASK(MODE(CONSTREFERENCE(F,H7) READ(A21,B12) READWRITE(C22)), pfgemm(F, ta, tb, m-M2, n-N2, K2, alpha, A21, lda, B12, ldb, beta, C22, ldc, H7)); 384 //omp_set_task_affinity(omp_get_locality_domain_num_for( C_22)); 385 TASK(MODE(CONSTREFERENCE(F,H8) READ(A22,B22) WRITE(C_22)), pfgemm(F, ta, tb, m-M2, n-N2, k-K2, a, A22, lda, B22, ldb, b,C_22, n-N2, H8)); 386 387 CHECK_DEPENDENCIES; 388 // 2/ final add 389 //omp_set_task_affinity(omp_get_locality_domain_num_for( C11)); 390 TASK(MODE(CONSTREFERENCE(F) READ(C_11) READWRITE(C11)), faddin(F, M2, N2, C_11, N2, C11, ldc)); 391 //omp_set_task_affinity(omp_get_locality_domain_num_for( C12)); 392 TASK(MODE(CONSTREFERENCE(F) READ(C_12) READWRITE(C12)),faddin(F, M2, n-N2, C_12, n-N2, C12, ldc)); 393 //omp_set_task_affinity(omp_get_locality_domain_num_for( C21)); 394 TASK(MODE(CONSTREFERENCE(F) READ(C_21) READWRITE(C21)), faddin(F, m-M2, N2, C_21, N2, C21, ldc)); 395 //omp_set_task_affinity(omp_get_locality_domain_num_for( C22)); 396 TASK(MODE(CONSTREFERENCE(F) READ(C_22) READWRITE(C22)), faddin(F, m-M2, n-N2, C_22, n-N2, C22, ldc)); 397 398 ); 399 FFLAS::fflas_delete (C_11); 400 FFLAS::fflas_delete (C_12); 401 FFLAS::fflas_delete (C_21); 402 FFLAS::fflas_delete (C_22); 403 } 404 return C; 405 } 406 407 template<class Field, class AlgoT, class FieldTrait> 408 typename Field::Element* 409 pfgemm( const Field& F, 410 const FFLAS_TRANSPOSE ta, 411 const FFLAS_TRANSPOSE tb, 412 const size_t m, 413 const size_t n, 414 const size_t k, 415 const typename Field::Element alpha, 416 const typename Field::ConstElement_ptr A, const size_t lda, 417 const typename Field::ConstElement_ptr B, const size_t ldb, 418 const typename Field::Element beta, 419 typename Field::Element_ptr C, const size_t ldc, 420 MMHelper<Field, AlgoT, FieldTrait, ParSeqHelper::Parallel<CuttingStrategy::Recursive,StrategyParameter::ThreeDInPlace> > & H){ 421 422 423 if (!m || !n) {return C;} 424 if (!k || F.isZero (alpha)){ 425 fscalin(F, m, n, beta, C, ldc); 426 return C; 427 } 428 429 if(H.parseq.numthreads() <= 1|| std::min(m*n,std::min(m*k,k*n))<=__FFLASFFPACK_SEQPARTHRESHOLD*__FFLASFFPACK_SEQPARTHRESHOLD){ // threshold 430 FFLAS::MMHelper<Field, AlgoT, FieldTrait,FFLAS::ParSeqHelper::Sequential> WH (H); 431 return fgemm(F, ta, tb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc, WH); 432 }else{ 433 size_t M2= m>>1; 434 size_t N2= n>>1; 435 size_t K2= k>>1; 436 typename Field::ConstElement_ptr A11= A; 437 typename Field::ConstElement_ptr A12= A+K2*((ta==FFLAS::FflasTrans)?lda:1); 438 typename Field::ConstElement_ptr A21= A+M2*((ta==FFLAS::FflasTrans)?1:lda); 439 typename Field::ConstElement_ptr A22= A12+M2*((ta==FFLAS::FflasTrans)?1:lda); 440 441 typename Field::ConstElement_ptr B11= B; 442 typename Field::ConstElement_ptr B12= B+N2*((tb==FFLAS::FflasTrans)?ldb:1); 443 typename Field::ConstElement_ptr B21= B+K2*((tb==FFLAS::FflasTrans)?1:ldb); 444 typename Field::ConstElement_ptr B22= B12+K2*((tb==FFLAS::FflasTrans)?1:ldb); 445 446 447 typename Field::Element_ptr C11= C; 448 typename Field::Element_ptr C12= C+N2; 449 typename Field::Element_ptr C21= C+M2*ldc; 450 typename Field::Element_ptr C22= C+N2+M2*ldc; 451 typedef MMHelper<Field, AlgoT, FieldTrait, ParSeqHelper::Parallel<CuttingStrategy::Recursive,StrategyParameter::ThreeDInPlace> > MMH_t; 452 MMH_t H1(H); 453 MMH_t H2(H); 454 MMH_t H3(H); 455 MMH_t H4(H); 456 size_t nt = H.parseq.numthreads(); 457 size_t nt_rec = nt/4; 458 size_t nt_mod = nt%4; 459 H1.parseq.set_numthreads(std::max(size_t(1),nt_rec + ((nt_mod-- > 0)?1:0))); 460 H2.parseq.set_numthreads(std::max(size_t(1),nt_rec + ((nt_mod-- > 0)?1:0))); 461 H3.parseq.set_numthreads(std::max(size_t(1),nt_rec + ((nt_mod-- > 0)?1:0))); 462 H4.parseq.set_numthreads(std::max(size_t(1),nt_rec + ((nt_mod-- > 0)?1:0))); 463 SYNCH_GROUP( 464 // 1/ 4 multiply 465 TASK(MODE(CONSTREFERENCE(F,H1) READ(A11,B11) READWRITE(C11)), pfgemm(F, ta, tb, M2, N2, K2, alpha, A11, lda, B11, ldb, beta, C11, ldc, H1)); 466 TASK(MODE(CONSTREFERENCE(F,H2) READ(A12,B22) READWRITE(C12)), pfgemm(F, ta, tb, M2, n-N2, k-K2, alpha, A12, lda, B22, ldb, beta, C12, ldc, H2)); 467 TASK(MODE(CONSTREFERENCE(F,H3) READ(A22,B21) READWRITE(C21)), pfgemm(F, ta, tb, m-M2, N2, k-K2, alpha, A22, lda, B21, ldb, beta, C21, ldc, H3)); 468 TASK(MODE(CONSTREFERENCE(F,H4) READ(A21,B12) READWRITE(C22)), pfgemm(F, ta, tb, m-M2, n-N2, K2, alpha, A21, lda, B12, ldb, beta, C22, ldc, H4)); 469 470 CHECK_DEPENDENCIES; 471 // 2/ 4 add+multiply 472 TASK(MODE(CONSTREFERENCE(F,H1) READ(A12,B21) READWRITE(C11)), pfgemm(F, ta, tb, M2, N2, k-K2, alpha, A12, lda, B21, ldb, F.one, C11, ldc, H1)); 473 TASK(MODE(CONSTREFERENCE(F,H2) READ(A11,B12) READWRITE(C12)), pfgemm(F, ta, tb, M2, n-N2, K2, alpha, A11, lda, B12, ldb, F.one, C12, ldc, H2)); 474 TASK(MODE(CONSTREFERENCE(F,H3) READ(A21,B11) READWRITE(C21)), pfgemm(F, ta, tb, m-M2, N2, K2, alpha, A21, lda, B11, ldb, F.one, C21, ldc, H3)); 475 TASK(MODE(CONSTREFERENCE(F,H4) READ(A22,B22) READWRITE(C22)), pfgemm(F, ta, tb, m-M2, n-N2, k-K2, alpha, A22, lda, B22, ldb, F.one, C22, ldc, H4)); 476 ); 477 } 478 return C; 479 } 480 481 482 483} // FFLAS 484/* -*- mode: C++; tab-width: 4; indent-tabs-mode: nil; c-basic-offset: 4 -*- */ 485// vim:sts=4:sw=4:ts=4:et:sr:cino=>s,f0,{0,g0,(0,\:0,t0,+0,=s 486