1/*
2 * Copyright (C) 2014 the LinBox group
3 *
4 * Written by Clement Pernet <Clement.Pernet@imag.fr>
5 *            Brice Boyer (briceboyer) <boyer.brice@gmail.com>
6 *            Ziad Sultan <ziad.sultan@imag.fr>
7 *
8 * ========LICENCE========
9 * This file is part of the library FFLAS-FFPACK.
10 *
11 * FFLAS-FFPACK is free software: you can redistribute it and/or modify
12 * it under the terms of the  GNU Lesser General Public
13 * License as published by the Free Software Foundation; either
14 * version 2.1 of the License, or (at your option) any later version.
15 *
16 * This library is distributed in the hope that it will be useful,
17 * but WITHOUT ANY WARRANTY; without even the implied warranty of
18 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
19 * Lesser General Public License for more details.
20 *
21 * You should have received a copy of the GNU Lesser General Public
22 * License along with this library; if not, write to the Free Software
23 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
24 * ========LICENCE========
25 *.
26 */
27
/** @file fflas/fflas_fgemm/winograd.inl
 * @ingroup MMalgos
 * @brief Sequential (Winograd) and task-parallel (WinoPar) implementations of
 *        one recursion level of the Strassen-Winograd matrix product.
 * @bib ISSAC09 Scheduling
 */
33
34#ifndef __FFLASFFPACK_fgemm_winograd_INL
35#define __FFLASFFPACK_fgemm_winograd_INL
36
37namespace FFLAS { namespace BLAS3 {
38
39    template < class Field, class FieldTrait, class Strat, class Param >
40    inline typename Field::Element_ptr
41    WinoPar (const Field& F,
42             const FFLAS_TRANSPOSE ta,
43             const FFLAS_TRANSPOSE tb,
44             const size_t mr, const size_t nr, const size_t kr,
45             const typename Field::Element alpha,
46             typename Field::ConstElement_ptr A,const size_t lda,
47             typename Field::ConstElement_ptr B,const size_t ldb,
48             const typename Field::Element  beta,
49             typename Field::Element_ptr C, const size_t ldc,
50             // const size_t kmax, const size_t w, const FFLAS_BASE base
51             MMHelper<Field, MMHelperAlgo::WinogradPar, FieldTrait, ParSeqHelper::Parallel<Strat,Param> > & WH
52            )
53    {
54        FFLASFFPACK_check(F.isZero(beta));
55
56        //			typedef MMHelper<Field, MMHelperAlgo::WinogradPar, FieldTrait > MMH_t;
57        typedef MMHelper<Field, MMHelperAlgo::WinogradPar, FieldTrait, ParSeqHelper::Parallel<CuttingStrategy::Recursive,StrategyParameter::TwoDAdaptive> > MMH_t;
58        const typename MMH_t::DelayedField & DF = WH.delayedField;
59        typedef typename  MMH_t::DelayedField::Element DFElt;
60
61        size_t lb, cb, la, ca, ldX2;
62        // size_t x3rd = std::max(mr,kr);
63        typename Field::ConstElement_ptr A11=A, A12, A21, A22;
64        typename Field::ConstElement_ptr B11=B, B12, B21, B22;
65        typename Field::Element_ptr C11=C, C12=C+nr, C21=C+mr*ldc, C22=C21+nr;
66
67        size_t x1rd = std::max(nr,kr);
68        size_t ldX1;
69        if (ta == FflasTrans) {
70            A21 = A + mr;
71            A12 = A + kr*lda;
72            A22 = A12 + mr;
73            la = kr;
74            ca = mr;
75            ldX1 = mr;
76        } else {
77            A12 = A + kr;
78            A21 = A + mr*lda;
79            A22 = A21 + kr;
80            la = mr;
81            ca = kr;
82            ldX1  = x1rd;
83        }
84        if (tb == FflasTrans) {
85            B21 = B + kr;
86            B12 = B + nr*ldb;
87            B22 = B12 + kr;
88            lb = nr;
89            cb = kr;
90            ldX2 = kr;
91        } else {
92            B12 = B + nr;
93            B21 = B + kr*ldb;
94            B22 = B21 + nr;
95            lb = kr;
96            ldX2 = cb = nr;
97        }
98
99        // 11 temporary submatrices are required
100        typename Field::Element_ptr X21 = fflas_new (F, kr, nr);
101        typename Field::Element_ptr X11 = fflas_new (F,mr,x1rd);
102
103        typename Field::Element_ptr X22 = fflas_new (F, kr, nr);
104        typename Field::Element_ptr X12 = fflas_new (F,mr,x1rd);
105
106        typename Field::Element_ptr X23 = fflas_new (F, kr, nr);
107        typename Field::Element_ptr X13 = fflas_new (F,mr,x1rd);
108
109        typename Field::Element_ptr X24 = fflas_new (F, kr, nr);
110        typename Field::Element_ptr X14 = fflas_new (F,mr,x1rd);
111        typename Field::Element_ptr X15 = fflas_new (F,mr,x1rd);
112
113        typename Field::Element_ptr C_11 = fflas_new (F,mr,nr);
114        typename Field::Element_ptr CC_11 = fflas_new (F,mr,nr);
115        SYNCH_GROUP(
116
117                    // T3 = B22 - B12 in X21  and S3 = A11 - A21 in X11
118                    TASK(MODE(READ(B22, B12) WRITE(X21) CONSTREFERENCE(DF)),
119                         pfsub(DF,lb,cb,B22,ldb,B12,ldb,X21,ldX2, NUM_THREADS););
120                    TASK(MODE(READ(A11, A21) WRITE(X11) CONSTREFERENCE(DF)),
121                         pfsub(DF,la,ca,A11,lda,A21,lda,X11,ldX1, NUM_THREADS););
122
123                    // T1 = B12 - B11 in X22 and  S1 = A21 + A22 in X12
124                    TASK(MODE(READ(B11, B12) WRITE(X22) CONSTREFERENCE(DF)),
125                         pfsub(DF,lb,cb,B12,ldb,B11,ldb,X22,ldX2, NUM_THREADS););
126                    TASK(MODE(READ(A12, A22) WRITE(X12) CONSTREFERENCE(DF)),
127                         pfadd(DF,la,ca,A21,lda,A22,lda,X12,ldX1, NUM_THREADS););
128
129                    CHECK_DEPENDENCIES;
130
131                    // T2 = B22 - T1 in X23 and  S2 = S1 - A11 in X13
132                    TASK(MODE(READ(B22, X22) READWRITE(X23) CONSTREFERENCE(DF)),
133                         pfsub(DF,lb,cb,B22,ldb,X22,ldX2,X23,ldX2, NUM_THREADS););
134                    TASK(MODE(READ(A11, X12) READWRITE(X13) CONSTREFERENCE(DF)),
135                         //		     fsub(DF,la,ca,A11,lda,X12,ldX1,X13,ldX1););
136                    pfsub(DF,la,ca,X12,ldX1,A11,lda,X13,ldX1, NUM_THREADS););
137                    /*
138                       fsub(DF,lb,cb,B22,ldb,X2,ldX2,X2,ldX2);
139                       fsubin(DF,la,ca,A11,lda,X1,ldX1););
140                       */
141                    CHECK_DEPENDENCIES;
142
143                    // T4 = T2 - B21 in X2 and S4 = A12 -S2 in X1
144                    TASK(MODE(READ(B21, X23) READWRITE(X24) CONSTREFERENCE(DF)),
145                         //		     fsub(DF,lb,cb,B21,ldb,X23,ldX2,X24,ldX2);
146                         pfsub(DF,lb,cb,X23,ldX2,B21,ldb,X24,ldX2, NUM_THREADS););
147                    TASK(MODE(READ(A12, X13) READWRITE(X14) CONSTREFERENCE(DF)),
148                         pfsub(DF,la,ca,A12,lda,X13,ldX1,X14,ldX1, NUM_THREADS););
149
150                    /*
151                       fsubin(DF,lb,cb,B21,ldb,X2,ldX2);
152                       fsub(DF,la,ca,A12,lda,X1,ldX1,X1,ldX1););
153                       */
154                    CHECK_DEPENDENCIES;
155
156                    // P1 = alpha . A11 * B11 in X1
157
158                    MMH_t H1(F, WH.recLevel-1, WH.Amin, WH.Amax, WH.Bmin, WH.Bmax, 0, 0);
159                    MMH_t H7(F, WH.recLevel-1, -(WH.Amax-WH.Amin), WH.Amax-WH.Amin, -(WH.Bmax-WH.Bmin), WH.Bmax-WH.Bmin, 0,0);
160                    MMH_t H5(F, WH.recLevel-1, 2*WH.Amin, 2*WH.Amax, -(WH.Bmax-WH.Bmin), WH.Bmax-WH.Bmin, 0, 0);
161                    MMH_t H6(F, WH.recLevel-1, 2*WH.Amin-WH.Amax, 2*WH.Amax-WH.Amin, 2*WH.Bmin-WH.Bmax, 2*WH.Bmax-WH.Bmin, 0, 0);
162                    MMH_t H3(F, WH.recLevel-1, 2*WH.Amin-2*WH.Amax, 2*WH.Amax-2*WH.Amin, WH.Bmin, WH.Bmax, 0, 0);
163                    MMH_t H4(F, WH.recLevel-1, WH.Amin, WH.Amax, 2*WH.Bmin-2*WH.Bmax, 2*WH.Bmax-2*WH.Bmin, 0, 0);
164                    MMH_t H2(F, WH.recLevel-1, WH.Amin, WH.Amax, WH.Bmin, WH.Bmax, 0, 0);
165
166                    size_t nt = WH.parseq.numthreads();
167                    size_t nt_rec = nt/7;
168                    size_t nt_mod = nt % 7 ;
169                    H1.parseq.set_numthreads(std::max(size_t(1),nt_rec + ((nt_mod-- > 0)?1:0)));
170                    H2.parseq.set_numthreads(std::max(size_t(1),nt_rec + ((nt_mod-- > 0)?1:0)));
171                    H3.parseq.set_numthreads(std::max(size_t(1),nt_rec + ((nt_mod-- > 0)?1:0)));
172                    H4.parseq.set_numthreads(std::max(size_t(1),nt_rec + ((nt_mod-- > 0)?1:0)));
173                    H5.parseq.set_numthreads(std::max(size_t(1),nt_rec + ((nt_mod-- > 0)?1:0)));
174                    H6.parseq.set_numthreads(std::max(size_t(1),nt_rec + ((nt_mod-- > 0)?1:0)));
175                    H7.parseq.set_numthreads(std::max(size_t(1),nt_rec + ((nt_mod-- > 0)?1:0)));
176
177                    TASK(MODE(READ(A11, B11) WRITE(X15) CONSTREFERENCE(F,H1)),
178                         fgemm (F, ta, tb, mr, nr, kr, alpha, A11, lda, B11, ldb, F.zero, X15, x1rd, H1););
179                    // P7 = alpha . S3 * T3  in C21
180                    TASK(MODE(READ(X11, X21) WRITE(C21) CONSTREFERENCE(F,H7)),
181                         fgemm (F, ta, tb, mr, nr, kr, alpha, X11, ldX1, X21, ldX2, F.zero, C21, ldc, H7););
182
183                    // P5 = alpha . S1*T1 in C22
184                    TASK(MODE(READ(X12, X22) WRITE(C22) CONSTREFERENCE(F,H5)),
185                         fgemm (F, ta, tb, mr, nr, kr, alpha, X12, ldX1, X22, ldX2, F.zero, C22, ldc, H5););
186
187                    // P6 = alpha . S2 * T2 in C12
188                    TASK(MODE(READ(X13, X23) WRITE(C12) CONSTREFERENCE(F,H6)),
189                         fgemm (F, ta, tb, mr, nr, kr, alpha, X13, ldX1, X23, ldX2, F.zero, C12, ldc, H6););
190
191                    // P3 = alpha . S4*B22 in CC_11
192                    TASK(MODE(READ(X14, B22) WRITE(CC_11) CONSTREFERENCE(F,H3)),
193                         fgemm (F, ta, tb, mr, nr, kr, alpha, X14, ldX1, B22, ldb, F.zero, CC_11, nr, H3););
194
195                    // P4 = alpha . A22 * T4 in C_11
196                    TASK(MODE(READ(A22) WRITE(C_11) READWRITE(X24, X22, X23, X21) CONSTREFERENCE(F,H4)),
197                         fgemm (F, ta, tb, mr, nr, kr, alpha, A22, lda, X24, ldX2, F.zero, C_11, nr, H4);
198                        );
199
200                    // P2 = alpha . A12 * B21  in C11
201                    TASK(MODE(READ(A12, B21) WRITE(C11) CONSTREFERENCE(F,H2)),
202                         fgemm (F, ta, tb, mr, nr, kr, alpha, A12, lda, B21, ldb, F.zero, C11, ldc, H2););
203                    CHECK_DEPENDENCIES;
204
205                    DFElt U2Min, U2Max;
206                    DFElt U3Min, U3Max;
207                    DFElt U4Min, U4Max;
208                    DFElt U7Min, U7Max;
209                    DFElt U5Min, U5Max;
210                    // U2 = P1 + P6 in C12  and
211                    // U3 = P7 + U2 in C21  and
212                    // U4 = P5 + U2 in C12    and
213                    // U7 = P5 + U3 in C22    and
214                    // U5 = P3 + U4 in C12
215                    // BIG TASK with 5 Addin function calls
216                    //		TASK(MODE(READWRITE(X15, C12) CONSTREFERENCE(F, DF, WH, U2Min, U2Max, H1.Outmin, H1.Outmax, H6.Outmin, H6.Outmax)),
217                    if (Protected::NeedPreAddReduction(U2Min, U2Max, H1.Outmin, H1.Outmax, H6.Outmin, H6.Outmax, WH)){
218                        TASK(MODE(READWRITE(X15) CONSTREFERENCE(F)),
219                             pfreduce (F, mr, x1rd, X15, x1rd, NUM_THREADS);
220                            );
221                        TASK(MODE(READWRITE(C12) CONSTREFERENCE(F)),
222                             pfreduce (F, mr, nr, C12, ldc, NUM_THREADS);
223                            );
224                        CHECK_DEPENDENCIES;
225                    }
226                    TASK(MODE(READWRITE(X15, C12) CONSTREFERENCE(DF)),
227                         pfaddin(DF,mr,nr,X15,x1rd,C12,ldc, NUM_THREADS);
228                        );
229                    CHECK_DEPENDENCIES;
230                    //		TASK(MODE(READWRITE(C12, C21) CONSTREFERENCE(F, DF, WH, U3Min, U3Max, U2Min, U2Max)),
231                    if (Protected::NeedPreAddReduction(U3Min, U3Max, U2Min, U2Max, H7.Outmin, H7.Outmax, WH)){
232                        TASK(MODE(READWRITE(C12) CONSTREFERENCE(F)),
233                             pfreduce (F, mr, nr, C12, ldc, NUM_THREADS);
234                            );
235                        TASK(MODE(READWRITE(C21) CONSTREFERENCE(F)),
236                             pfreduce (F, mr, nr, C21, ldc, NUM_THREADS);
237                            );
238                        CHECK_DEPENDENCIES;
239                    }
240                    TASK(MODE(READWRITE(C12, C21) CONSTREFERENCE(DF)),
241                         pfaddin(DF,mr,nr,C12,ldc,C21,ldc, NUM_THREADS);
242                        );
243                    CHECK_DEPENDENCIES;
244                    //		TASK(MODE(READWRITE(C12, C22) CONSTREFERENCE(F, DF, WH) VALUE(U4Min, U4Max, U2Min, U2Max)),
245                    if (Protected::NeedPreAddReduction(U4Min, U4Max, U2Min, U2Max, H5.Outmin, H5.Outmax, WH)){
246                        TASK(MODE(READWRITE(C22) CONSTREFERENCE(F)),
247                             pfreduce (F, mr, nr, C22, ldc, NUM_THREADS);
248                            );
249                        TASK(MODE(READWRITE(C12) CONSTREFERENCE(F)),
250                             pfreduce (F, mr, nr, C12, ldc, NUM_THREADS);
251                            );
252                        CHECK_DEPENDENCIES;
253                    }
254                    TASK(MODE(READWRITE(C12, C22) CONSTREFERENCE(DF, WH)),
255                         pfaddin(DF,mr,nr,C22,ldc,C12,ldc, NUM_THREADS);
256                        );
257                    CHECK_DEPENDENCIES;
258                    //		TASK(MODE(READWRITE(C22, C21) CONSTREFERENCE(F, DF, WH) VALUE(U3Min, U3Max, U7Min, U7Max)),
259                    if (Protected::NeedPreAddReduction (U7Min,U7Max, U3Min, U3Max, H5.Outmin,H5.Outmax, WH) ){
260                        TASK(MODE(READWRITE(C21) CONSTREFERENCE(F)),
261                             pfreduce (F, mr, nr, C21, ldc, NUM_THREADS);
262                            );
263                        TASK(MODE(READWRITE(C22) CONSTREFERENCE(F)),
264                             pfreduce (F, mr, nr, C22, ldc, NUM_THREADS);
265                            );
266                        CHECK_DEPENDENCIES;
267                    }
268                    TASK(MODE(READWRITE(C22, C21) CONSTREFERENCE(DF, WH)),
269                         pfaddin(DF,mr,nr,C21,ldc,C22,ldc, NUM_THREADS);
270                        );
271                    //		TASK(MODE(READWRITE(C12, CC_11) CONSTREFERENCE(F, DF, WH) VALUE(U5Min, U5Max, U4Min, U4Max)),
272                    if (Protected::NeedPreAddReduction (U5Min,U5Max, U4Min, U4Max, H3.Outmin, H3.Outmax, WH) ){
273                        TASK(MODE(READWRITE(C12) CONSTREFERENCE(F)),
274                             pfreduce (F, mr, nr, C12, ldc, NUM_THREADS);
275                            );
276                        TASK(MODE(READWRITE(CC_11) CONSTREFERENCE(F)),
277                             pfreduce (F, mr, nr, CC_11, nr, NUM_THREADS);
278                            );
279                        CHECK_DEPENDENCIES;
280                    }
281                    TASK(MODE(READWRITE(C12, CC_11) CONSTREFERENCE(DF, WH)),
282                         pfaddin(DF,mr,nr,CC_11,nr,C12,ldc, NUM_THREADS);
283                        );
284                    CHECK_DEPENDENCIES;
285
286                    // U6 = U3 - P4 in C21
287                    DFElt U6Min, U6Max;
288                    //		TASK(MODE(READWRITE(C_11, C21) CONSTREFERENCE(F, DF, WH) VALUE(U6Min, U6Max, U3Min, U3Max)),
289                    if (Protected::NeedPreSubReduction (U6Min,U6Max, U3Min, U3Max, H4.Outmin,H4.Outmax, WH) ){
290                        TASK(MODE(READWRITE(CC_11) CONSTREFERENCE(F)),
291                             pfreduce (F, mr, nr, C_11, nr, NUM_THREADS);
292                            );
293                        TASK(MODE(READWRITE(C21) CONSTREFERENCE(F)),
294                             pfreduce (F, mr, nr, C21, ldc, NUM_THREADS);
295                            );
296                        CHECK_DEPENDENCIES
297                    }
298                    TASK(MODE(READWRITE(C_11, C21) CONSTREFERENCE(DF, WH) ),
299                         pfsubin(DF,mr,nr,C_11,nr,C21,ldc, NUM_THREADS);
300                        );
301
302                    //CHECK_DEPENDENCIES;
303
304                    //  U1 = P2 + P1 in C11
305                    DFElt U1Min, U1Max;
306                    //		TASK(MODE(READWRITE(C11, X15/*, X14, X13, X12, X11*/) CONSTREFERENCE(F, DF, WH) VALUE(U1Min, U1Max)),
307                    if (Protected::NeedPreAddReduction (U1Min, U1Max, H1.Outmin, H1.Outmax, H2.Outmin,H2.Outmax, WH) ){
308                        TASK(MODE(READWRITE(X15) CONSTREFERENCE(F)),
309                             pfreduce (F, mr, nr, X15, x1rd, NUM_THREADS);
310                            );
311                        TASK(MODE(READWRITE(C11) CONSTREFERENCE(F)),
312                             pfreduce (F, mr, nr, C11, ldc, NUM_THREADS);
313                            );
314                        CHECK_DEPENDENCIES
315                    }
316                    TASK(MODE(READWRITE(C11, X15) CONSTREFERENCE(DF, WH)),
317                         pfaddin(DF,mr,nr,X15,x1rd,C11,ldc, NUM_THREADS);
318                        );
319
320                    WH.Outmin = std::min (U1Min, std::min (U5Min, std::min (U6Min, U7Min)));
321                    WH.Outmax = std::max (U1Max, std::max (U5Max, std::max (U6Max, U7Max)));
322
323                    );
324                    //			WAIT;
325
326
327                    fflas_delete (CC_11);
328                    fflas_delete (C_11);
329                    fflas_delete (X15);
330                    fflas_delete (X14);
331                    fflas_delete (X24);
332                    fflas_delete (X13);
333                    fflas_delete (X23);
334                    fflas_delete (X12);
335                    fflas_delete (X22);
336                    fflas_delete (X11);
337                    fflas_delete (X21);
338
339                    return C;
340    } //wino parallel
341
342
343    template < class Field, class FieldTrait >
344    inline void Winograd (const Field& F,
345                          const FFLAS_TRANSPOSE ta,
346                          const FFLAS_TRANSPOSE tb,
347                          const size_t mr, const size_t nr, const size_t kr,
348                          const typename Field::Element alpha,
349                          typename Field::ConstElement_ptr A,const size_t lda,
350                          typename Field::ConstElement_ptr B,const size_t ldb,
351                          const typename Field::Element  beta,
352                          typename Field::Element_ptr C, const size_t ldc,
353                          // const size_t kmax, const size_t w, const FFLAS_BASE base
354                          MMHelper<Field, MMHelperAlgo::Winograd, FieldTrait> & WH
355                         )
356    {
357        FFLASFFPACK_check(F.isZero(beta));
358
359        typedef MMHelper<Field, MMHelperAlgo::Winograd, FieldTrait > MMH_t;
360        typedef typename  MMH_t::DelayedField::Element_ptr DFEptr;
361        typedef typename  MMH_t::DelayedField::ConstElement_ptr DFCEptr;
362        typedef typename  MMH_t::DelayedField::Element DFElt;
363
364        const typename MMH_t::DelayedField & DF = WH.delayedField;
365
366        size_t lb, cb, la, ca, ldX2;
367        // size_t x3rd = std::max(mr,kr);
368        typename Field::ConstElement_ptr A11=A, A12, A21, A22;
369        typename Field::ConstElement_ptr B11=B, B12, B21, B22;
370        typename Field::Element_ptr C11=C, C12=C+nr, C21=C+mr*ldc, C22=C21+nr;
371
372        size_t x1rd = std::max(nr,kr);
373        size_t ldX1;
374        if (ta == FflasTrans) {
375            A21 = A + mr;
376            A12 = A + kr*lda;
377            A22 = A12 + mr;
378            la = kr;
379            ca = mr;
380            ldX1 = mr;
381        } else {
382            A12 = A + kr;
383            A21 = A + mr*lda;
384            A22 = A21 + kr;
385            la = mr;
386            ca = kr;
387            ldX1  = x1rd;
388        }
389        if (tb == FflasTrans) {
390            B21 = B + kr;
391            B12 = B + nr*ldb;
392            B22 = B12 + kr;
393            lb = nr;
394            cb = kr;
395            ldX2 = kr;
396        } else {
397            B12 = B + nr;
398            B21 = B + kr*ldb;
399            B22 = B21 + nr;
400            lb = kr;
401            ldX2 = cb = nr;
402        }
403        // Two temporary submatrices are required
404        typename Field::Element_ptr X2 = fflas_new (F, kr, nr);
405
406        // T3 = B22 - B12 in X2
407        fsub(DF,lb,cb, (DFCEptr) B22,ldb, (DFCEptr) B12,ldb, (DFEptr)X2,ldX2);
408
409        // S3 = A11 - A21 in X1
410        typename Field::Element_ptr X1 = fflas_new (F,mr,x1rd);
411        fsub(DF,la,ca,(DFCEptr)A11,lda,(DFCEptr)A21,lda,(DFEptr)X1,ldX1);
412
413        // P7 = alpha . S3 * T3  in C21
414        MMH_t H7(F, WH.recLevel-1, -(WH.Amax-WH.Amin), WH.Amax-WH.Amin, -(WH.Bmax-WH.Bmin), WH.Bmax-WH.Bmin, 0,0);
415
416        fgemm (F, ta, tb, mr, nr, kr, alpha, X1, ldX1, X2, ldX2, F.zero, C21, ldc, H7);
417
418        // T1 = B12 - B11 in X2
419        fsub(DF,lb,cb,(DFCEptr)B12,ldb,(DFCEptr)B11,ldb,(DFEptr)X2,ldX2);
420
421        // S1 = A21 + A22 in X1
422        fadd(DF,la,ca,(DFCEptr)A21,lda,(DFCEptr)A22,lda,(DFEptr)X1,ldX1);
423
424        // P5 = alpha . S1*T1 in C22
425        MMH_t H5(F, WH.recLevel-1, 2*WH.Amin, 2*WH.Amax, -(WH.Bmax-WH.Bmin), WH.Bmax-WH.Bmin, 0, 0);
426
427        fgemm (F, ta, tb, mr, nr, kr, alpha, X1, ldX1, X2, ldX2, F.zero, C22, ldc, H5);
428
429        // T2 = B22 - T1 in X2
430        fsub(DF,lb,cb,(DFCEptr)B22,ldb,(DFCEptr)X2,ldX2,(DFEptr)X2,ldX2);
431
432        // S2 = S1 - A11 in X1
433        fsubin(DF,la,ca,(DFCEptr)A11,lda,(DFEptr)X1,ldX1);
434
435        // P6 = alpha . S2 * T2 in C12
436        MMH_t H6(F, WH.recLevel-1, 2*WH.Amin-WH.Amax, 2*WH.Amax-WH.Amin, 2*WH.Bmin-WH.Bmax, 2*WH.Bmax-WH.Bmin, 0, 0);
437
438        fgemm (F, ta, tb, mr, nr, kr, alpha, X1, ldX1, X2, ldX2, F.zero, C12, ldc, H6);
439
440        // S4 = A12 -S2 in X1
441        fsub(DF,la,ca,(DFCEptr)A12,lda,(DFCEptr)X1,ldX1,(DFEptr)X1,ldX1);
442
443        // P3 = alpha . S4*B22 in C11
444        MMH_t H3(F, WH.recLevel-1, 2*WH.Amin-2*WH.Amax, 2*WH.Amax-2*WH.Amin, WH.Bmin, WH.Bmax, 0, 0);
445
446        fgemm (F, ta, tb, mr, nr, kr, alpha, X1, ldX1, B22, ldb, F.zero, C11, ldc, H3);
447
448        // P1 = alpha . A11 * B11 in X1
449        MMH_t H1(F, WH.recLevel-1, WH.Amin, WH.Amax, WH.Bmin, WH.Bmax, 0, 0);
450
451        fgemm (F, ta, tb, mr, nr, kr, alpha, A11, lda, B11, ldb, F.zero, X1, nr, H1);
452
453        // U2 = P1 + P6 in C12  and
454        DFElt U2Min, U2Max;
455        // This test will be optimized out
456        if (Protected::NeedPreAddReduction(U2Min, U2Max, H1.Outmin, H1.Outmax, H6.Outmin, H6.Outmax, WH)){
457            freduce (F, mr, nr, X1, nr);
458            freduce (F, mr, nr, C12, ldc);
459        }
460        faddin(DF,mr,nr,(DFCEptr)X1,nr,(DFEptr)C12,ldc);
461
462        // U3 = P7 + U2 in C21  and
463        DFElt U3Min, U3Max;
464        // This test will be optimized out
465        if (Protected::NeedPreAddReduction(U3Min, U3Max, U2Min, U2Max, H7.Outmin, H7.Outmax, WH)){
466            freduce (F, mr, nr, C12, ldc);
467            freduce (F, mr, nr, C21, ldc);
468        }
469        faddin(DF,mr,nr,(DFCEptr)C12,ldc,(DFEptr)C21,ldc);
470
471
472        // U4 = P5 + U2 in C12    and
473        DFElt U4Min, U4Max;
474        // This test will be optimized out
475        if (Protected::NeedPreAddReduction(U4Min, U4Max, U2Min, U2Max, H5.Outmin, H5.Outmax, WH)){
476            freduce (F, mr, nr, C22, ldc);
477            freduce (F, mr, nr, C12, ldc);
478        }
479        faddin(DF,mr,nr,(DFCEptr)C22,ldc,(DFEptr)C12,ldc);
480
481        // U7 = P5 + U3 in C22    and
482        DFElt U7Min, U7Max;
483        // This test will be optimized out
484        if (Protected::NeedPreAddReduction (U7Min,U7Max, U3Min, U3Max, H5.Outmin,H5.Outmax, WH) ){
485            freduce (F, mr, nr, C21, ldc);
486            freduce (F, mr, nr, C22, ldc);
487        }
488        faddin(DF,mr,nr,(DFCEptr)C21,ldc,(DFEptr)C22,ldc);
489
490        // U5 = P3 + U4 in C12
491        DFElt U5Min, U5Max;
492        // This test will be optimized out
493        if (Protected::NeedPreAddReduction (U5Min,U5Max, U4Min, U4Max, H3.Outmin, H3.Outmax, WH) ){
494            freduce (F, mr, nr, C12, ldc);
495            freduce (F, mr, nr, C11, ldc);
496        }
497        faddin(DF,mr,nr,(DFCEptr)C11,ldc,(DFEptr)C12,ldc);
498
499        // T4 = T2 - B21 in X2
500        fsubin(DF,lb,cb,(DFCEptr)B21,ldb,(DFEptr)X2,ldX2);
501
502        // P4 = alpha . A22 * T4 in C11
503        MMH_t H4(F, WH.recLevel-1, WH.Amin, WH.Amax, 2*WH.Bmin-2*WH.Bmax, 2*WH.Bmax-2*WH.Bmin, 0, 0);
504
505        fgemm (F, ta, tb, mr, nr, kr, alpha, A22, lda, X2, ldX2, F.zero, C11, ldc, H4);
506
507        fflas_delete (X2);
508
509        // U6 = U3 - P4 in C21
510        DFElt U6Min, U6Max;
511        // This test will be optimized out
512        if (Protected::NeedPreSubReduction (U6Min,U6Max, U3Min, U3Max, H4.Outmin,H4.Outmax, WH) ){
513            freduce (F, mr, nr, C11, ldc);
514            freduce (F, mr, nr, C21, ldc);
515        }
516        fsubin(DF,mr,nr,(DFCEptr)C11,ldc,(DFEptr)C21,ldc);
517
518        // P2 = alpha . A12 * B21  in C11
519        MMH_t H2(F, WH.recLevel-1, WH.Amin, WH.Amax, WH.Bmin, WH.Bmax, 0, 0);
520
521        fgemm (F, ta, tb, mr, nr, kr, alpha, A12, lda, B21, ldb, F.zero, C11, ldc, H2);
522
523        //  U1 = P2 + P1 in C11
524        DFElt U1Min, U1Max;
525        // This test will be optimized out
526        if (Protected::NeedPreAddReduction (U1Min, U1Max, H1.Outmin, H1.Outmax, H2.Outmin,H2.Outmax, WH) ){
527            freduce (F, mr, nr, X1, nr);
528            freduce (F, mr, nr, C11, ldc);
529        }
530        faddin(DF,mr,nr,(DFCEptr)X1,nr,(DFEptr)C11,ldc);
531
532        fflas_delete (X1);
533
534        WH.Outmin = std::min (U1Min, std::min (U5Min, std::min (U6Min, U7Min)));
535        WH.Outmax = std::max (U1Max, std::max (U5Max, std::max (U6Max, U7Max)));
536
537    } // Winograd
538
539} // BLAS3
540
541
542} // FFLAS
543
544#endif // __FFLASFFPACK_fgemm_winograd_INL
545
546/* -*- mode: C++; tab-width: 4; indent-tabs-mode: nil; c-basic-offset: 4 -*- */
547// vim:sts=4:sw=4:ts=4:et:sr:cino=>s,f0,{0,g0,(0,\:0,t0,+0,=s
548