1/* 2 * Copyright (C) 2014 the LinBox group 3 * 4 * Written by Clement Pernet <Clement.Pernet@imag.fr> 5 * Brice Boyer (briceboyer) <boyer.brice@gmail.com> 6 * Ziad Sultan <ziad.sultan@imag.fr> 7 * 8 * ========LICENCE======== 9 * This file is part of the library FFLAS-FFPACK. 10 * 11 * FFLAS-FFPACK is free software: you can redistribute it and/or modify 12 * it under the terms of the GNU Lesser General Public 13 * License as published by the Free Software Foundation; either 14 * version 2.1 of the License, or (at your option) any later version. 15 * 16 * This library is distributed in the hope that it will be useful, 17 * but WITHOUT ANY WARRANTY; without even the implied warranty of 18 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU 19 * Lesser General Public License for more details. 20 * 21 * You should have received a copy of the GNU Lesser General Public 22 * License along with this library; if not, write to the Free Software 23 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA 24 * ========LICENCE======== 25 *. 26 */ 27 28/** @file fflas/fflas_fgemm/winograd.inl 29 * @ingroup MMalgos 30 * @brief Winograd implementation 31 * @bib ISSAC09 Scheduling 32 */ 33 34#ifndef __FFLASFFPACK_fgemm_winograd_INL 35#define __FFLASFFPACK_fgemm_winograd_INL 36 37namespace FFLAS { namespace BLAS3 { 38 39 template < class Field, class FieldTrait, class Strat, class Param > 40 inline typename Field::Element_ptr 41 WinoPar (const Field& F, 42 const FFLAS_TRANSPOSE ta, 43 const FFLAS_TRANSPOSE tb, 44 const size_t mr, const size_t nr, const size_t kr, 45 const typename Field::Element alpha, 46 typename Field::ConstElement_ptr A,const size_t lda, 47 typename Field::ConstElement_ptr B,const size_t ldb, 48 const typename Field::Element beta, 49 typename Field::Element_ptr C, const size_t ldc, 50 // const size_t kmax, const size_t w, const FFLAS_BASE base 51 MMHelper<Field, MMHelperAlgo::WinogradPar, FieldTrait, ParSeqHelper::Parallel<Strat,Param> > & WH 52 ) 53 { 54 FFLASFFPACK_check(F.isZero(beta)); 55 56 // typedef MMHelper<Field, MMHelperAlgo::WinogradPar, FieldTrait > MMH_t; 57 typedef MMHelper<Field, MMHelperAlgo::WinogradPar, FieldTrait, ParSeqHelper::Parallel<CuttingStrategy::Recursive,StrategyParameter::TwoDAdaptive> > MMH_t; 58 const typename MMH_t::DelayedField & DF = WH.delayedField; 59 typedef typename MMH_t::DelayedField::Element DFElt; 60 61 size_t lb, cb, la, ca, ldX2; 62 // size_t x3rd = std::max(mr,kr); 63 typename Field::ConstElement_ptr A11=A, A12, A21, A22; 64 typename Field::ConstElement_ptr B11=B, B12, B21, B22; 65 typename Field::Element_ptr C11=C, C12=C+nr, C21=C+mr*ldc, C22=C21+nr; 66 67 size_t x1rd = std::max(nr,kr); 68 size_t ldX1; 69 if (ta == FflasTrans) { 70 A21 = A + mr; 71 A12 = A + kr*lda; 72 A22 = A12 + mr; 73 la = kr; 74 ca = mr; 75 ldX1 = mr; 76 } else { 77 A12 = A + kr; 78 A21 = A + mr*lda; 79 A22 = A21 + kr; 80 la = mr; 81 ca = kr; 82 ldX1 = x1rd; 83 } 84 if (tb == FflasTrans) { 85 B21 = B + kr; 86 B12 = B + nr*ldb; 87 B22 = B12 + kr; 88 lb = nr; 89 cb = kr; 90 ldX2 = kr; 91 } else { 92 B12 = B + nr; 93 B21 = B + kr*ldb; 94 B22 = B21 + nr; 95 lb = kr; 96 ldX2 = cb = nr; 97 } 98 99 // 11 temporary submatrices are required 100 typename Field::Element_ptr X21 = fflas_new (F, kr, nr); 101 typename Field::Element_ptr X11 = fflas_new (F,mr,x1rd); 102 103 typename Field::Element_ptr X22 = fflas_new (F, kr, nr); 104 typename Field::Element_ptr X12 = fflas_new (F,mr,x1rd); 105 106 typename Field::Element_ptr X23 = fflas_new (F, kr, nr); 107 typename Field::Element_ptr X13 = fflas_new (F,mr,x1rd); 108 109 typename Field::Element_ptr X24 = fflas_new (F, kr, nr); 110 typename Field::Element_ptr X14 = fflas_new (F,mr,x1rd); 111 typename Field::Element_ptr X15 = fflas_new (F,mr,x1rd); 112 113 typename Field::Element_ptr C_11 = fflas_new (F,mr,nr); 114 typename Field::Element_ptr CC_11 = fflas_new (F,mr,nr); 115 SYNCH_GROUP( 116 117 // T3 = B22 - B12 in X21 and S3 = A11 - A21 in X11 118 TASK(MODE(READ(B22, B12) WRITE(X21) CONSTREFERENCE(DF)), 119 pfsub(DF,lb,cb,B22,ldb,B12,ldb,X21,ldX2, NUM_THREADS);); 120 TASK(MODE(READ(A11, A21) WRITE(X11) CONSTREFERENCE(DF)), 121 pfsub(DF,la,ca,A11,lda,A21,lda,X11,ldX1, NUM_THREADS);); 122 123 // T1 = B12 - B11 in X22 and S1 = A21 + A22 in X12 124 TASK(MODE(READ(B11, B12) WRITE(X22) CONSTREFERENCE(DF)), 125 pfsub(DF,lb,cb,B12,ldb,B11,ldb,X22,ldX2, NUM_THREADS);); 126 TASK(MODE(READ(A12, A22) WRITE(X12) CONSTREFERENCE(DF)), 127 pfadd(DF,la,ca,A21,lda,A22,lda,X12,ldX1, NUM_THREADS);); 128 129 CHECK_DEPENDENCIES; 130 131 // T2 = B22 - T1 in X23 and S2 = S1 - A11 in X13 132 TASK(MODE(READ(B22, X22) READWRITE(X23) CONSTREFERENCE(DF)), 133 pfsub(DF,lb,cb,B22,ldb,X22,ldX2,X23,ldX2, NUM_THREADS);); 134 TASK(MODE(READ(A11, X12) READWRITE(X13) CONSTREFERENCE(DF)), 135 // fsub(DF,la,ca,A11,lda,X12,ldX1,X13,ldX1);); 136 pfsub(DF,la,ca,X12,ldX1,A11,lda,X13,ldX1, NUM_THREADS);); 137 /* 138 fsub(DF,lb,cb,B22,ldb,X2,ldX2,X2,ldX2); 139 fsubin(DF,la,ca,A11,lda,X1,ldX1);); 140 */ 141 CHECK_DEPENDENCIES; 142 143 // T4 = T2 - B21 in X2 and S4 = A12 -S2 in X1 144 TASK(MODE(READ(B21, X23) READWRITE(X24) CONSTREFERENCE(DF)), 145 // fsub(DF,lb,cb,B21,ldb,X23,ldX2,X24,ldX2); 146 pfsub(DF,lb,cb,X23,ldX2,B21,ldb,X24,ldX2, NUM_THREADS);); 147 TASK(MODE(READ(A12, X13) READWRITE(X14) CONSTREFERENCE(DF)), 148 pfsub(DF,la,ca,A12,lda,X13,ldX1,X14,ldX1, NUM_THREADS);); 149 150 /* 151 fsubin(DF,lb,cb,B21,ldb,X2,ldX2); 152 fsub(DF,la,ca,A12,lda,X1,ldX1,X1,ldX1);); 153 */ 154 CHECK_DEPENDENCIES; 155 156 // P1 = alpha . A11 * B11 in X1 157 158 MMH_t H1(F, WH.recLevel-1, WH.Amin, WH.Amax, WH.Bmin, WH.Bmax, 0, 0); 159 MMH_t H7(F, WH.recLevel-1, -(WH.Amax-WH.Amin), WH.Amax-WH.Amin, -(WH.Bmax-WH.Bmin), WH.Bmax-WH.Bmin, 0,0); 160 MMH_t H5(F, WH.recLevel-1, 2*WH.Amin, 2*WH.Amax, -(WH.Bmax-WH.Bmin), WH.Bmax-WH.Bmin, 0, 0); 161 MMH_t H6(F, WH.recLevel-1, 2*WH.Amin-WH.Amax, 2*WH.Amax-WH.Amin, 2*WH.Bmin-WH.Bmax, 2*WH.Bmax-WH.Bmin, 0, 0); 162 MMH_t H3(F, WH.recLevel-1, 2*WH.Amin-2*WH.Amax, 2*WH.Amax-2*WH.Amin, WH.Bmin, WH.Bmax, 0, 0); 163 MMH_t H4(F, WH.recLevel-1, WH.Amin, WH.Amax, 2*WH.Bmin-2*WH.Bmax, 2*WH.Bmax-2*WH.Bmin, 0, 0); 164 MMH_t H2(F, WH.recLevel-1, WH.Amin, WH.Amax, WH.Bmin, WH.Bmax, 0, 0); 165 166 size_t nt = WH.parseq.numthreads(); 167 size_t nt_rec = nt/7; 168 size_t nt_mod = nt % 7 ; 169 H1.parseq.set_numthreads(std::max(size_t(1),nt_rec + ((nt_mod-- > 0)?1:0))); 170 H2.parseq.set_numthreads(std::max(size_t(1),nt_rec + ((nt_mod-- > 0)?1:0))); 171 H3.parseq.set_numthreads(std::max(size_t(1),nt_rec + ((nt_mod-- > 0)?1:0))); 172 H4.parseq.set_numthreads(std::max(size_t(1),nt_rec + ((nt_mod-- > 0)?1:0))); 173 H5.parseq.set_numthreads(std::max(size_t(1),nt_rec + ((nt_mod-- > 0)?1:0))); 174 H6.parseq.set_numthreads(std::max(size_t(1),nt_rec + ((nt_mod-- > 0)?1:0))); 175 H7.parseq.set_numthreads(std::max(size_t(1),nt_rec + ((nt_mod-- > 0)?1:0))); 176 177 TASK(MODE(READ(A11, B11) WRITE(X15) CONSTREFERENCE(F,H1)), 178 fgemm (F, ta, tb, mr, nr, kr, alpha, A11, lda, B11, ldb, F.zero, X15, x1rd, H1);); 179 // P7 = alpha . S3 * T3 in C21 180 TASK(MODE(READ(X11, X21) WRITE(C21) CONSTREFERENCE(F,H7)), 181 fgemm (F, ta, tb, mr, nr, kr, alpha, X11, ldX1, X21, ldX2, F.zero, C21, ldc, H7);); 182 183 // P5 = alpha . S1*T1 in C22 184 TASK(MODE(READ(X12, X22) WRITE(C22) CONSTREFERENCE(F,H5)), 185 fgemm (F, ta, tb, mr, nr, kr, alpha, X12, ldX1, X22, ldX2, F.zero, C22, ldc, H5);); 186 187 // P6 = alpha . S2 * T2 in C12 188 TASK(MODE(READ(X13, X23) WRITE(C12) CONSTREFERENCE(F,H6)), 189 fgemm (F, ta, tb, mr, nr, kr, alpha, X13, ldX1, X23, ldX2, F.zero, C12, ldc, H6);); 190 191 // P3 = alpha . S4*B22 in CC_11 192 TASK(MODE(READ(X14, B22) WRITE(CC_11) CONSTREFERENCE(F,H3)), 193 fgemm (F, ta, tb, mr, nr, kr, alpha, X14, ldX1, B22, ldb, F.zero, CC_11, nr, H3);); 194 195 // P4 = alpha . A22 * T4 in C_11 196 TASK(MODE(READ(A22) WRITE(C_11) READWRITE(X24, X22, X23, X21) CONSTREFERENCE(F,H4)), 197 fgemm (F, ta, tb, mr, nr, kr, alpha, A22, lda, X24, ldX2, F.zero, C_11, nr, H4); 198 ); 199 200 // P2 = alpha . A12 * B21 in C11 201 TASK(MODE(READ(A12, B21) WRITE(C11) CONSTREFERENCE(F,H2)), 202 fgemm (F, ta, tb, mr, nr, kr, alpha, A12, lda, B21, ldb, F.zero, C11, ldc, H2);); 203 CHECK_DEPENDENCIES; 204 205 DFElt U2Min, U2Max; 206 DFElt U3Min, U3Max; 207 DFElt U4Min, U4Max; 208 DFElt U7Min, U7Max; 209 DFElt U5Min, U5Max; 210 // U2 = P1 + P6 in C12 and 211 // U3 = P7 + U2 in C21 and 212 // U4 = P5 + U2 in C12 and 213 // U7 = P5 + U3 in C22 and 214 // U5 = P3 + U4 in C12 215 // BIG TASK with 5 Addin function calls 216 // TASK(MODE(READWRITE(X15, C12) CONSTREFERENCE(F, DF, WH, U2Min, U2Max, H1.Outmin, H1.Outmax, H6.Outmin, H6.Outmax)), 217 if (Protected::NeedPreAddReduction(U2Min, U2Max, H1.Outmin, H1.Outmax, H6.Outmin, H6.Outmax, WH)){ 218 TASK(MODE(READWRITE(X15) CONSTREFERENCE(F)), 219 pfreduce (F, mr, x1rd, X15, x1rd, NUM_THREADS); 220 ); 221 TASK(MODE(READWRITE(C12) CONSTREFERENCE(F)), 222 pfreduce (F, mr, nr, C12, ldc, NUM_THREADS); 223 ); 224 CHECK_DEPENDENCIES; 225 } 226 TASK(MODE(READWRITE(X15, C12) CONSTREFERENCE(DF)), 227 pfaddin(DF,mr,nr,X15,x1rd,C12,ldc, NUM_THREADS); 228 ); 229 CHECK_DEPENDENCIES; 230 // TASK(MODE(READWRITE(C12, C21) CONSTREFERENCE(F, DF, WH, U3Min, U3Max, U2Min, U2Max)), 231 if (Protected::NeedPreAddReduction(U3Min, U3Max, U2Min, U2Max, H7.Outmin, H7.Outmax, WH)){ 232 TASK(MODE(READWRITE(C12) CONSTREFERENCE(F)), 233 pfreduce (F, mr, nr, C12, ldc, NUM_THREADS); 234 ); 235 TASK(MODE(READWRITE(C21) CONSTREFERENCE(F)), 236 pfreduce (F, mr, nr, C21, ldc, NUM_THREADS); 237 ); 238 CHECK_DEPENDENCIES; 239 } 240 TASK(MODE(READWRITE(C12, C21) CONSTREFERENCE(DF)), 241 pfaddin(DF,mr,nr,C12,ldc,C21,ldc, NUM_THREADS); 242 ); 243 CHECK_DEPENDENCIES; 244 // TASK(MODE(READWRITE(C12, C22) CONSTREFERENCE(F, DF, WH) VALUE(U4Min, U4Max, U2Min, U2Max)), 245 if (Protected::NeedPreAddReduction(U4Min, U4Max, U2Min, U2Max, H5.Outmin, H5.Outmax, WH)){ 246 TASK(MODE(READWRITE(C22) CONSTREFERENCE(F)), 247 pfreduce (F, mr, nr, C22, ldc, NUM_THREADS); 248 ); 249 TASK(MODE(READWRITE(C12) CONSTREFERENCE(F)), 250 pfreduce (F, mr, nr, C12, ldc, NUM_THREADS); 251 ); 252 CHECK_DEPENDENCIES; 253 } 254 TASK(MODE(READWRITE(C12, C22) CONSTREFERENCE(DF, WH)), 255 pfaddin(DF,mr,nr,C22,ldc,C12,ldc, NUM_THREADS); 256 ); 257 CHECK_DEPENDENCIES; 258 // TASK(MODE(READWRITE(C22, C21) CONSTREFERENCE(F, DF, WH) VALUE(U3Min, U3Max, U7Min, U7Max)), 259 if (Protected::NeedPreAddReduction (U7Min,U7Max, U3Min, U3Max, H5.Outmin,H5.Outmax, WH) ){ 260 TASK(MODE(READWRITE(C21) CONSTREFERENCE(F)), 261 pfreduce (F, mr, nr, C21, ldc, NUM_THREADS); 262 ); 263 TASK(MODE(READWRITE(C22) CONSTREFERENCE(F)), 264 pfreduce (F, mr, nr, C22, ldc, NUM_THREADS); 265 ); 266 CHECK_DEPENDENCIES; 267 } 268 TASK(MODE(READWRITE(C22, C21) CONSTREFERENCE(DF, WH)), 269 pfaddin(DF,mr,nr,C21,ldc,C22,ldc, NUM_THREADS); 270 ); 271 // TASK(MODE(READWRITE(C12, CC_11) CONSTREFERENCE(F, DF, WH) VALUE(U5Min, U5Max, U4Min, U4Max)), 272 if (Protected::NeedPreAddReduction (U5Min,U5Max, U4Min, U4Max, H3.Outmin, H3.Outmax, WH) ){ 273 TASK(MODE(READWRITE(C12) CONSTREFERENCE(F)), 274 pfreduce (F, mr, nr, C12, ldc, NUM_THREADS); 275 ); 276 TASK(MODE(READWRITE(CC_11) CONSTREFERENCE(F)), 277 pfreduce (F, mr, nr, CC_11, nr, NUM_THREADS); 278 ); 279 CHECK_DEPENDENCIES; 280 } 281 TASK(MODE(READWRITE(C12, CC_11) CONSTREFERENCE(DF, WH)), 282 pfaddin(DF,mr,nr,CC_11,nr,C12,ldc, NUM_THREADS); 283 ); 284 CHECK_DEPENDENCIES; 285 286 // U6 = U3 - P4 in C21 287 DFElt U6Min, U6Max; 288 // TASK(MODE(READWRITE(C_11, C21) CONSTREFERENCE(F, DF, WH) VALUE(U6Min, U6Max, U3Min, U3Max)), 289 if (Protected::NeedPreSubReduction (U6Min,U6Max, U3Min, U3Max, H4.Outmin,H4.Outmax, WH) ){ 290 TASK(MODE(READWRITE(CC_11) CONSTREFERENCE(F)), 291 pfreduce (F, mr, nr, C_11, nr, NUM_THREADS); 292 ); 293 TASK(MODE(READWRITE(C21) CONSTREFERENCE(F)), 294 pfreduce (F, mr, nr, C21, ldc, NUM_THREADS); 295 ); 296 CHECK_DEPENDENCIES 297 } 298 TASK(MODE(READWRITE(C_11, C21) CONSTREFERENCE(DF, WH) ), 299 pfsubin(DF,mr,nr,C_11,nr,C21,ldc, NUM_THREADS); 300 ); 301 302 //CHECK_DEPENDENCIES; 303 304 // U1 = P2 + P1 in C11 305 DFElt U1Min, U1Max; 306 // TASK(MODE(READWRITE(C11, X15/*, X14, X13, X12, X11*/) CONSTREFERENCE(F, DF, WH) VALUE(U1Min, U1Max)), 307 if (Protected::NeedPreAddReduction (U1Min, U1Max, H1.Outmin, H1.Outmax, H2.Outmin,H2.Outmax, WH) ){ 308 TASK(MODE(READWRITE(X15) CONSTREFERENCE(F)), 309 pfreduce (F, mr, nr, X15, x1rd, NUM_THREADS); 310 ); 311 TASK(MODE(READWRITE(C11) CONSTREFERENCE(F)), 312 pfreduce (F, mr, nr, C11, ldc, NUM_THREADS); 313 ); 314 CHECK_DEPENDENCIES 315 } 316 TASK(MODE(READWRITE(C11, X15) CONSTREFERENCE(DF, WH)), 317 pfaddin(DF,mr,nr,X15,x1rd,C11,ldc, NUM_THREADS); 318 ); 319 320 WH.Outmin = std::min (U1Min, std::min (U5Min, std::min (U6Min, U7Min))); 321 WH.Outmax = std::max (U1Max, std::max (U5Max, std::max (U6Max, U7Max))); 322 323 ); 324 // WAIT; 325 326 327 fflas_delete (CC_11); 328 fflas_delete (C_11); 329 fflas_delete (X15); 330 fflas_delete (X14); 331 fflas_delete (X24); 332 fflas_delete (X13); 333 fflas_delete (X23); 334 fflas_delete (X12); 335 fflas_delete (X22); 336 fflas_delete (X11); 337 fflas_delete (X21); 338 339 return C; 340 } //wino parallel 341 342 343 template < class Field, class FieldTrait > 344 inline void Winograd (const Field& F, 345 const FFLAS_TRANSPOSE ta, 346 const FFLAS_TRANSPOSE tb, 347 const size_t mr, const size_t nr, const size_t kr, 348 const typename Field::Element alpha, 349 typename Field::ConstElement_ptr A,const size_t lda, 350 typename Field::ConstElement_ptr B,const size_t ldb, 351 const typename Field::Element beta, 352 typename Field::Element_ptr C, const size_t ldc, 353 // const size_t kmax, const size_t w, const FFLAS_BASE base 354 MMHelper<Field, MMHelperAlgo::Winograd, FieldTrait> & WH 355 ) 356 { 357 FFLASFFPACK_check(F.isZero(beta)); 358 359 typedef MMHelper<Field, MMHelperAlgo::Winograd, FieldTrait > MMH_t; 360 typedef typename MMH_t::DelayedField::Element_ptr DFEptr; 361 typedef typename MMH_t::DelayedField::ConstElement_ptr DFCEptr; 362 typedef typename MMH_t::DelayedField::Element DFElt; 363 364 const typename MMH_t::DelayedField & DF = WH.delayedField; 365 366 size_t lb, cb, la, ca, ldX2; 367 // size_t x3rd = std::max(mr,kr); 368 typename Field::ConstElement_ptr A11=A, A12, A21, A22; 369 typename Field::ConstElement_ptr B11=B, B12, B21, B22; 370 typename Field::Element_ptr C11=C, C12=C+nr, C21=C+mr*ldc, C22=C21+nr; 371 372 size_t x1rd = std::max(nr,kr); 373 size_t ldX1; 374 if (ta == FflasTrans) { 375 A21 = A + mr; 376 A12 = A + kr*lda; 377 A22 = A12 + mr; 378 la = kr; 379 ca = mr; 380 ldX1 = mr; 381 } else { 382 A12 = A + kr; 383 A21 = A + mr*lda; 384 A22 = A21 + kr; 385 la = mr; 386 ca = kr; 387 ldX1 = x1rd; 388 } 389 if (tb == FflasTrans) { 390 B21 = B + kr; 391 B12 = B + nr*ldb; 392 B22 = B12 + kr; 393 lb = nr; 394 cb = kr; 395 ldX2 = kr; 396 } else { 397 B12 = B + nr; 398 B21 = B + kr*ldb; 399 B22 = B21 + nr; 400 lb = kr; 401 ldX2 = cb = nr; 402 } 403 // Two temporary submatrices are required 404 typename Field::Element_ptr X2 = fflas_new (F, kr, nr); 405 406 // T3 = B22 - B12 in X2 407 fsub(DF,lb,cb, (DFCEptr) B22,ldb, (DFCEptr) B12,ldb, (DFEptr)X2,ldX2); 408 409 // S3 = A11 - A21 in X1 410 typename Field::Element_ptr X1 = fflas_new (F,mr,x1rd); 411 fsub(DF,la,ca,(DFCEptr)A11,lda,(DFCEptr)A21,lda,(DFEptr)X1,ldX1); 412 413 // P7 = alpha . S3 * T3 in C21 414 MMH_t H7(F, WH.recLevel-1, -(WH.Amax-WH.Amin), WH.Amax-WH.Amin, -(WH.Bmax-WH.Bmin), WH.Bmax-WH.Bmin, 0,0); 415 416 fgemm (F, ta, tb, mr, nr, kr, alpha, X1, ldX1, X2, ldX2, F.zero, C21, ldc, H7); 417 418 // T1 = B12 - B11 in X2 419 fsub(DF,lb,cb,(DFCEptr)B12,ldb,(DFCEptr)B11,ldb,(DFEptr)X2,ldX2); 420 421 // S1 = A21 + A22 in X1 422 fadd(DF,la,ca,(DFCEptr)A21,lda,(DFCEptr)A22,lda,(DFEptr)X1,ldX1); 423 424 // P5 = alpha . S1*T1 in C22 425 MMH_t H5(F, WH.recLevel-1, 2*WH.Amin, 2*WH.Amax, -(WH.Bmax-WH.Bmin), WH.Bmax-WH.Bmin, 0, 0); 426 427 fgemm (F, ta, tb, mr, nr, kr, alpha, X1, ldX1, X2, ldX2, F.zero, C22, ldc, H5); 428 429 // T2 = B22 - T1 in X2 430 fsub(DF,lb,cb,(DFCEptr)B22,ldb,(DFCEptr)X2,ldX2,(DFEptr)X2,ldX2); 431 432 // S2 = S1 - A11 in X1 433 fsubin(DF,la,ca,(DFCEptr)A11,lda,(DFEptr)X1,ldX1); 434 435 // P6 = alpha . S2 * T2 in C12 436 MMH_t H6(F, WH.recLevel-1, 2*WH.Amin-WH.Amax, 2*WH.Amax-WH.Amin, 2*WH.Bmin-WH.Bmax, 2*WH.Bmax-WH.Bmin, 0, 0); 437 438 fgemm (F, ta, tb, mr, nr, kr, alpha, X1, ldX1, X2, ldX2, F.zero, C12, ldc, H6); 439 440 // S4 = A12 -S2 in X1 441 fsub(DF,la,ca,(DFCEptr)A12,lda,(DFCEptr)X1,ldX1,(DFEptr)X1,ldX1); 442 443 // P3 = alpha . S4*B22 in C11 444 MMH_t H3(F, WH.recLevel-1, 2*WH.Amin-2*WH.Amax, 2*WH.Amax-2*WH.Amin, WH.Bmin, WH.Bmax, 0, 0); 445 446 fgemm (F, ta, tb, mr, nr, kr, alpha, X1, ldX1, B22, ldb, F.zero, C11, ldc, H3); 447 448 // P1 = alpha . A11 * B11 in X1 449 MMH_t H1(F, WH.recLevel-1, WH.Amin, WH.Amax, WH.Bmin, WH.Bmax, 0, 0); 450 451 fgemm (F, ta, tb, mr, nr, kr, alpha, A11, lda, B11, ldb, F.zero, X1, nr, H1); 452 453 // U2 = P1 + P6 in C12 and 454 DFElt U2Min, U2Max; 455 // This test will be optimized out 456 if (Protected::NeedPreAddReduction(U2Min, U2Max, H1.Outmin, H1.Outmax, H6.Outmin, H6.Outmax, WH)){ 457 freduce (F, mr, nr, X1, nr); 458 freduce (F, mr, nr, C12, ldc); 459 } 460 faddin(DF,mr,nr,(DFCEptr)X1,nr,(DFEptr)C12,ldc); 461 462 // U3 = P7 + U2 in C21 and 463 DFElt U3Min, U3Max; 464 // This test will be optimized out 465 if (Protected::NeedPreAddReduction(U3Min, U3Max, U2Min, U2Max, H7.Outmin, H7.Outmax, WH)){ 466 freduce (F, mr, nr, C12, ldc); 467 freduce (F, mr, nr, C21, ldc); 468 } 469 faddin(DF,mr,nr,(DFCEptr)C12,ldc,(DFEptr)C21,ldc); 470 471 472 // U4 = P5 + U2 in C12 and 473 DFElt U4Min, U4Max; 474 // This test will be optimized out 475 if (Protected::NeedPreAddReduction(U4Min, U4Max, U2Min, U2Max, H5.Outmin, H5.Outmax, WH)){ 476 freduce (F, mr, nr, C22, ldc); 477 freduce (F, mr, nr, C12, ldc); 478 } 479 faddin(DF,mr,nr,(DFCEptr)C22,ldc,(DFEptr)C12,ldc); 480 481 // U7 = P5 + U3 in C22 and 482 DFElt U7Min, U7Max; 483 // This test will be optimized out 484 if (Protected::NeedPreAddReduction (U7Min,U7Max, U3Min, U3Max, H5.Outmin,H5.Outmax, WH) ){ 485 freduce (F, mr, nr, C21, ldc); 486 freduce (F, mr, nr, C22, ldc); 487 } 488 faddin(DF,mr,nr,(DFCEptr)C21,ldc,(DFEptr)C22,ldc); 489 490 // U5 = P3 + U4 in C12 491 DFElt U5Min, U5Max; 492 // This test will be optimized out 493 if (Protected::NeedPreAddReduction (U5Min,U5Max, U4Min, U4Max, H3.Outmin, H3.Outmax, WH) ){ 494 freduce (F, mr, nr, C12, ldc); 495 freduce (F, mr, nr, C11, ldc); 496 } 497 faddin(DF,mr,nr,(DFCEptr)C11,ldc,(DFEptr)C12,ldc); 498 499 // T4 = T2 - B21 in X2 500 fsubin(DF,lb,cb,(DFCEptr)B21,ldb,(DFEptr)X2,ldX2); 501 502 // P4 = alpha . A22 * T4 in C11 503 MMH_t H4(F, WH.recLevel-1, WH.Amin, WH.Amax, 2*WH.Bmin-2*WH.Bmax, 2*WH.Bmax-2*WH.Bmin, 0, 0); 504 505 fgemm (F, ta, tb, mr, nr, kr, alpha, A22, lda, X2, ldX2, F.zero, C11, ldc, H4); 506 507 fflas_delete (X2); 508 509 // U6 = U3 - P4 in C21 510 DFElt U6Min, U6Max; 511 // This test will be optimized out 512 if (Protected::NeedPreSubReduction (U6Min,U6Max, U3Min, U3Max, H4.Outmin,H4.Outmax, WH) ){ 513 freduce (F, mr, nr, C11, ldc); 514 freduce (F, mr, nr, C21, ldc); 515 } 516 fsubin(DF,mr,nr,(DFCEptr)C11,ldc,(DFEptr)C21,ldc); 517 518 // P2 = alpha . A12 * B21 in C11 519 MMH_t H2(F, WH.recLevel-1, WH.Amin, WH.Amax, WH.Bmin, WH.Bmax, 0, 0); 520 521 fgemm (F, ta, tb, mr, nr, kr, alpha, A12, lda, B21, ldb, F.zero, C11, ldc, H2); 522 523 // U1 = P2 + P1 in C11 524 DFElt U1Min, U1Max; 525 // This test will be optimized out 526 if (Protected::NeedPreAddReduction (U1Min, U1Max, H1.Outmin, H1.Outmax, H2.Outmin,H2.Outmax, WH) ){ 527 freduce (F, mr, nr, X1, nr); 528 freduce (F, mr, nr, C11, ldc); 529 } 530 faddin(DF,mr,nr,(DFCEptr)X1,nr,(DFEptr)C11,ldc); 531 532 fflas_delete (X1); 533 534 WH.Outmin = std::min (U1Min, std::min (U5Min, std::min (U6Min, U7Min))); 535 WH.Outmax = std::max (U1Max, std::max (U5Max, std::max (U6Max, U7Max))); 536 537 } // Winograd 538 539} // BLAS3 540 541 542} // FFLAS 543 544#endif // __FFLASFFPACK_fgemm_winograd_INL 545 546/* -*- mode: C++; tab-width: 4; indent-tabs-mode: nil; c-basic-offset: 4 -*- */ 547// vim:sts=4:sw=4:ts=4:et:sr:cino=>s,f0,{0,g0,(0,\:0,t0,+0,=s 548