1/* fflas/fflas_pfgemm.inl
2 * Copyright (C) 2013 Jean Guillaume Dumas Clement Pernet Ziad Sultan
3 *<ziad.sultan@imag.fr>
4 *
5 * ========LICENCE========
6 * This file is part of the library FFLAS-FFPACK.
7 *
8 * FFLAS-FFPACK is free software: you can redistribute it and/or modify
9 * it under the terms of the  GNU Lesser General Public
10 * License as published by the Free Software Foundation; either
11 * version 2.1 of the License, or (at your option) any later version.
12 *
13 * This library is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
16 * Lesser General Public License for more details.
17 *
18 * You should have received a copy of the GNU Lesser General Public
19 * License along with this library; if not, write to the Free Software
20 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
21 * ========LICENCE========
22 *.
23 */
24
25
26
27namespace FFLAS
28{
29
30
31    template<class Field, class AlgoT, class FieldTrait>
32    typename Field::Element*
33    pfgemm(const Field& F,
34           const FFLAS_TRANSPOSE ta,
35           const FFLAS_TRANSPOSE tb,
36           const size_t m,
37           const size_t n,
38           const size_t k,
39           const typename Field::Element alpha,
40           const typename Field::ConstElement_ptr A, const size_t lda,
41           const typename Field::ConstElement_ptr B, const size_t ldb,
42           const typename Field::Element beta,
43           typename Field::Element * C, const size_t ldc,
44           MMHelper<Field, AlgoT, FieldTrait, ParSeqHelper::Parallel<CuttingStrategy::Block, StrategyParameter::Threads> > & H){
45        {
46            H.parseq.set_numthreads( std::min(H.parseq.numthreads(), std::max((size_t)1,(size_t)(m*n/(__FFLASFFPACK_SEQPARTHRESHOLD*__FFLASFFPACK_SEQPARTHRESHOLD)))) );
47
48            MMHelper<Field, AlgoT, FieldTrait, ParSeqHelper::Sequential> SeqH (H);
49            size_t sa = (ta==FFLAS::FflasNoTrans)?lda:1;
50            size_t sb = (tb==FFLAS::FflasNoTrans)?1:ldb;
51            SYNCH_GROUP({FORBLOCK2D(iter,m,n,H.parseq,
52                                    TASK( MODE(
53                                               READ(A[iter.ibegin()*sa],B[iter.jbegin()*sb])
54                                               CONSTREFERENCE(F, SeqH)
55                                               READWRITE(C[iter.ibegin()*ldc+iter.jbegin()])),
56                                          fgemm( F, ta, tb, iter.iend()-iter.ibegin(), iter.jend()-iter.jbegin(), k, alpha, A+iter.ibegin()*sa, lda, B+iter.jbegin()*sb, ldb, beta, C+iter.ibegin()*ldc+iter.jbegin(), ldc, SeqH););
57                                   );
58                        });
59        }
60        return C;
61
62
63    }
64
65    template<class Field, class AlgoT, class FieldTrait>
66    typename Field::Element*
67    pfgemm(const Field& F,
68           const FFLAS_TRANSPOSE ta,
69           const FFLAS_TRANSPOSE tb,
70           const size_t m,
71           const size_t n,
72           const size_t k,
73           const typename Field::Element alpha,
74           const typename Field::ConstElement_ptr AA, const size_t lda,
75           const typename Field::ConstElement_ptr BB, const size_t ldb,
76           const typename Field::Element beta,
77           typename Field::Element * C, const size_t ldc,
78           MMHelper<Field, AlgoT, FieldTrait, ParSeqHelper::Parallel<CuttingStrategy::Recursive, StrategyParameter::ThreeDAdaptive> > & H){
79
80        typename Field::Element a = alpha;
81        typename Field::Element b = beta;
82        typename Field::ConstElement_ptr B = BB;
83        typename Field::ConstElement_ptr A = AA;
84        if (!m || !n) {return C;}
85        if (!k || F.isZero (alpha)){
86            fscalin(F, m, n, beta, C, ldc);
87            return C;
88        }
89
90        if (H.parseq.numthreads()<=1 || std::min(m*n,std::min(m*k,k*n))<=__FFLASFFPACK_SEQPARTHRESHOLD*__FFLASFFPACK_SEQPARTHRESHOLD){
91            MMHelper<Field,AlgoT,FieldTrait,ParSeqHelper::Sequential> SeqH(H);
92            return fgemm(F, ta, tb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc, SeqH);
93        }
94
95        typedef MMHelper<Field,AlgoT,FieldTrait,ParSeqHelper::Parallel<CuttingStrategy::Recursive, StrategyParameter::ThreeDAdaptive> > MMH_t;
96        MMH_t H1(H);
97        MMH_t H2(H);
98        if(__FFLASFFPACK_DIMKPENALTY*m > k && m >= n) {
99            SYNCH_GROUP(size_t M2= m>>1;
100                        H1.parseq.set_numthreads(H1.parseq.numthreads() >> 1);
101                        H2.parseq.set_numthreads(H.parseq.numthreads() - H1.parseq.numthreads());
102
103                        typename Field::ConstElement_ptr A1= A;
104                        typename Field::ConstElement_ptr A2= A+M2*((ta==FFLAS::FflasTrans)?1:lda);
105                        typename Field::Element_ptr C1= C;
106                        typename Field::Element_ptr C2= C+M2*ldc;
107
108                        // 2 multiply (1 split on dimension m)
109
110                        TASK(MODE(CONSTREFERENCE(F, H1) READ(A1,B) READWRITE(C1)),
111                             {pfgemm( F, ta, tb, M2, n, k, alpha, A1, lda, B, ldb, beta, C1, ldc, H1);}
112                            );
113
114                        TASK(MODE(CONSTREFERENCE(F,H2) READ(A2,B) READWRITE(C2)),
115                             {pfgemm(F, ta, tb, m-M2, n, k, alpha, A2, lda, B, ldb, beta, C2, ldc, H2);}
116                            );
117                       );
118
119        } else if (__FFLASFFPACK_DIMKPENALTY*n > k) {
120            SYNCH_GROUP(
121                        size_t N2 = n>>1;
122                        H1.parseq.set_numthreads( H1.parseq.numthreads() >> 1);
123                        H2.parseq.set_numthreads(H.parseq.numthreads() - H1.parseq.numthreads());
124                        typename Field::ConstElement_ptr B1= B;
125                        typename Field::ConstElement_ptr B2= B+N2*((tb==FFLAS::FflasTrans)?ldb:1);
126
127                        typename Field::Element_ptr C1= C;
128                        typename Field::Element_ptr C2= C+N2;
129
130                        TASK(MODE(CONSTREFERENCE(F,H1) READ(A,B1) READWRITE(C1)), pfgemm(F, ta, tb, m, N2, k, a, A, lda, B1, ldb, b, C1, ldc, H1));
131                        TASK(MODE(CONSTREFERENCE(F,H2) READ(A,B2) READWRITE(C2)), pfgemm(F, ta, tb, m, n-N2, k, a, A, lda, B2, ldb, b,C2, ldc, H2));
132                       );
133
134        } else {
135
136            size_t K2 = k>>1;
137
138            typename Field::ConstElement_ptr B1= B;
139            typename Field::ConstElement_ptr B2= B+K2*((tb==FFLAS::FflasTrans)?1:ldb);
140            typename Field::ConstElement_ptr A1= A;
141            typename Field::ConstElement_ptr A2= A+K2*((ta==FFLAS::FflasTrans)?lda:1);
142            typename Field::Element_ptr C2 = fflas_new (F, m, n,Alignment::CACHE_PAGESIZE);
143
144            H1.parseq.set_numthreads(H1.parseq.numthreads() >> 1);
145            H2.parseq.set_numthreads(H.parseq.numthreads()-H1.parseq.numthreads());
146            SYNCH_GROUP(
147                        TASK(MODE(CONSTREFERENCE(F,H1) READ(A1,B1) READWRITE(C)), pfgemm(F, ta, tb, m, n, K2, a, A1, lda, B1, ldb, b, C, ldc, H1));
148
149                        TASK(MODE(CONSTREFERENCE(F,H2) READ(A2,B2) READWRITE(C2)), pfgemm(F, ta, tb, m, n, k-K2, a, A2, lda, B2, ldb, F.zero, C2, n, H2));
150                        CHECK_DEPENDENCIES;
151
152                        TASK(MODE(CONSTREFERENCE(F) READ(C2) READWRITE(C)),faddin(F, n, m, C2, n, C, ldc));
153
154                       );
155            fflas_delete(C2);
156        }
157
158        return C;
159    }
160
161    template<class Field, class AlgoT, class FieldTrait>
162    typename Field::Element*
163    pfgemm (const Field& F,
164            const FFLAS_TRANSPOSE ta,
165            const FFLAS_TRANSPOSE tb,
166            const size_t m,
167            const size_t n,
168            const size_t k,
169            const typename Field::Element alpha,
170            const typename Field::ConstElement_ptr AA, const size_t lda,
171            const typename Field::ConstElement_ptr BB, const size_t ldb,
172            const typename Field::Element beta,
173            typename Field::Element * C, const size_t ldc,
174            MMHelper<Field, AlgoT, FieldTrait, ParSeqHelper::Parallel<CuttingStrategy::Recursive,StrategyParameter::TwoDAdaptive> > & H){
175
176        typename Field::Element a = alpha;
177        typename Field::Element b = beta;
178        typename Field::ConstElement_ptr B = BB;
179        typename Field::ConstElement_ptr A = AA;
180        if (!m || !n) {return C;}
181        if (!k || F.isZero (alpha)){
182            fscalin(F, m, n, beta, C, ldc);
183            return C;
184        }
185        if (H.parseq.numthreads()<=1 || m*n<=__FFLASFFPACK_SEQPARTHRESHOLD*__FFLASFFPACK_SEQPARTHRESHOLD){
186            MMHelper<Field,AlgoT,FieldTrait,ParSeqHelper::Sequential> SeqH(H);
187            return fgemm(F, ta, tb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc, SeqH);
188
189        }
190        typedef MMHelper<Field, AlgoT, FieldTrait, ParSeqHelper::Parallel<CuttingStrategy::Recursive, StrategyParameter::TwoDAdaptive> > MMH_t;
191        MMH_t H1(H);
192        MMH_t H2(H);
193        H1.parseq.set_numthreads(H1.parseq.numthreads() >> 1);
194        H2.parseq.set_numthreads(H.parseq.numthreads() - H1.parseq.numthreads());
195        if(m >= n) {
196            size_t M2= m>>1;
197            typename Field::ConstElement_ptr A1= A;
198            typename Field::ConstElement_ptr A2= A+M2*((ta==FFLAS::FflasTrans)?1:lda);
199            typename Field::Element_ptr C1= C;
200            typename Field::Element_ptr C2= C+M2*ldc;
201            SYNCH_GROUP(
202                        TASK(MODE(CONSTREFERENCE(F,H1, A1, B) READ(M2, A1[0],B[0]) READWRITE(C1[0])), pfgemm(F, ta, tb, M2, n, k, alpha, A1, lda, B, ldb, beta, C1, ldc, H1));
203                        TASK(MODE(CONSTREFERENCE(F,H2, A2, B) READ(M2, A2[0],B[0]) READWRITE(C2[0])), pfgemm(F, ta, tb, m-M2, n, k, alpha, A2, lda, B, ldb, beta, C2, ldc, H2));
204
205                       );
206
207        } else {
208            size_t N2 = n>>1;
209            typename Field::ConstElement_ptr B1= B;
210            typename Field::ConstElement_ptr B2= B+N2*((tb==FFLAS::FflasTrans)?ldb:1);
211            typename Field::Element_ptr C1= C;
212            typename Field::Element_ptr C2= C+N2;
213            SYNCH_GROUP(
214                        TASK(MODE(CONSTREFERENCE(F,H1, A, B1) READ(N2, A[0], B1[0]) READWRITE(C1[0])), pfgemm(F, ta, tb, m, N2, k, a, A, lda, B1, ldb, b, C1, ldc, H1));
215                        TASK(MODE(CONSTREFERENCE(F,H2, A, B2) READ(N2, A[0], B2[0]) READWRITE(C2[0])), pfgemm(F, ta, tb, m, n-N2, k, a, A, lda, B2, ldb, b,C2, ldc, H2));
216                       );
217        }
218        return C;
219    }
220
221    template<class Field, class AlgoT, class FieldTrait>
222    typename Field::Element*
223    pfgemm( const Field& F,
224            const FFLAS_TRANSPOSE ta,
225            const FFLAS_TRANSPOSE tb,
226            const size_t m,
227            const size_t n,
228            const size_t k,
229            const typename Field::Element alpha,
230            const typename Field::ConstElement_ptr AA, const size_t lda,
231            const typename Field::ConstElement_ptr BB, const size_t ldb,
232            const typename Field::Element beta,
233            typename Field::Element * C, const size_t ldc,
234            MMHelper<Field, AlgoT, FieldTrait, ParSeqHelper::Parallel<CuttingStrategy::Recursive,StrategyParameter::TwoD> > & H){
235
236        typename Field::Element a = alpha;
237        typename Field::Element b = beta;
238        typename Field::ConstElement_ptr B = BB;
239        typename Field::ConstElement_ptr A = AA;
240        if (!m || !n) {return C;}
241        if (!k || F.isZero (alpha)){
242            fscalin(F, m, n, beta, C, ldc);
243            return C;
244        }
245
246        if(H.parseq.numthreads()<=1|| m*n<=__FFLASFFPACK_SEQPARTHRESHOLD*__FFLASFFPACK_SEQPARTHRESHOLD){
247            MMHelper<Field,AlgoT,FieldTrait,ParSeqHelper::Sequential> SeqH(H);
248            return fgemm(F, ta, tb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc, SeqH);
249        } else
250        {
251            size_t M2= m>>1;
252            size_t N2= n>>1;
253
254            typename Field::ConstElement_ptr A1= A;
255            typename Field::ConstElement_ptr A2= A+M2*((ta==FFLAS::FflasTrans)?1:lda);
256            typename Field::ConstElement_ptr B1= B;
257            typename Field::ConstElement_ptr B2= B+N2*((tb==FFLAS::FflasTrans)?ldb:1);
258
259            typename Field::Element_ptr C11= C;
260            typename Field::Element_ptr C21= C+M2*ldc;
261            typename Field::Element_ptr C12= C+N2;
262            typename Field::Element_ptr C22= C+N2+M2*ldc;
263
264            typedef MMHelper<Field, AlgoT, FieldTrait, ParSeqHelper::Parallel<CuttingStrategy::Recursive,StrategyParameter::TwoD> > MMH_t;
265            MMH_t H1(H);
266            MMH_t H2(H);
267            MMH_t H3(H);
268            MMH_t H4(H);
269            size_t nt = H.parseq.numthreads();
270            size_t nt_rec = nt/4;
271            size_t nt_mod = nt%4;
272            H1.parseq.set_numthreads(std::max(size_t(1),nt_rec + ((nt_mod-- > 0)?1:0)));
273            H2.parseq.set_numthreads(std::max(size_t(1),nt_rec + ((nt_mod-- > 0)?1:0)));
274            H3.parseq.set_numthreads(std::max(size_t(1),nt_rec + ((nt_mod-- > 0)?1:0)));
275            H4.parseq.set_numthreads(std::max(size_t(1),nt_rec + ((nt_mod-- > 0)?1:0)));
276            SYNCH_GROUP(
277                        TASK(MODE(CONSTREFERENCE(F,H1) READ(A1,B1) READWRITE(C11)), pfgemm(F, ta, tb, M2, N2, k, alpha, A1, lda, B1, ldb, beta, C11, ldc, H1));
278
279                        TASK(MODE(CONSTREFERENCE(F,H2) READ(A1,B2) READWRITE(C12)), pfgemm(F, ta, tb, M2, n-N2, k, alpha, A1, lda, B2, ldb, beta, C12, ldc, H2));
280
281                        TASK(MODE(CONSTREFERENCE(F,H3) READ(A2,B1) READWRITE(C21)), pfgemm(F, ta, tb, m-M2, N2, k, a, A2, lda, B1, ldb, b, C21, ldc, H3));
282
283                        TASK(MODE(CONSTREFERENCE(F,H4) READ(A2,B2) READWRITE(C22)), pfgemm(F, ta, tb, m-M2, n-N2, k, a, A2, lda, B2, ldb, b,C22, ldc, H4));
284                       );
285        }
286        return C;
287    }
288
289
290
291    template<class Field, class AlgoT, class FieldTrait>
292    typename Field::Element_ptr
293    pfgemm(const Field& F,
294           const FFLAS_TRANSPOSE ta,
295           const FFLAS_TRANSPOSE tb,
296           const size_t m,
297           const size_t n,
298           const size_t k,
299           const typename Field::Element alpha,
300           const typename Field::ConstElement_ptr A, const size_t lda,
301           const typename Field::ConstElement_ptr B, const size_t ldb,
302           const typename Field::Element beta,
303           typename Field::Element_ptr C, const size_t ldc,
304           MMHelper<Field, AlgoT, FieldTrait, ParSeqHelper::Parallel<CuttingStrategy::Recursive,StrategyParameter::ThreeD> > & H){
305
306
307        if (!m || !n) {return C;}
308        if (!k || F.isZero (alpha)){
309            fscalin(F, m, n, beta, C, ldc);
310            return C;
311        }
312        if(H.parseq.numthreads() <= 1|| std::min(m*n,std::min(m*k,k*n))<=__FFLASFFPACK_SEQPARTHRESHOLD*__FFLASFFPACK_SEQPARTHRESHOLD){
313            FFLAS::MMHelper<Field, AlgoT, FieldTrait,FFLAS::ParSeqHelper::Sequential> WH (H);
314            return fgemm(F, ta, tb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc, WH);
315        }
316        else
317        {
318            typename Field::Element a = alpha;
319            typename Field::Element b = 0;
320
321            size_t M2= m>>1;
322            size_t N2= n>>1;
323            size_t K2= k>>1;
324            typename Field::ConstElement_ptr A11= A;
325            typename Field::ConstElement_ptr A12= A+K2*((ta==FFLAS::FflasTrans)?lda:1);
326            typename Field::ConstElement_ptr A21= A+M2*((ta==FFLAS::FflasTrans)?1:lda);
327            typename Field::ConstElement_ptr A22= A12+M2*((ta==FFLAS::FflasTrans)?1:lda);
328
329            typename Field::ConstElement_ptr B11= B;
330            typename Field::ConstElement_ptr B12= B+N2*((tb==FFLAS::FflasTrans)?ldb:1);
331            typename Field::ConstElement_ptr B21= B+K2*((tb==FFLAS::FflasTrans)?1:ldb);
332            typename Field::ConstElement_ptr B22= B12+K2*((tb==FFLAS::FflasTrans)?1:ldb);
333
334            typename Field::Element_ptr C11= C;
335            typename Field::Element_ptr C_11 = fflas_new (F, M2, N2,Alignment::CACHE_PAGESIZE);
336
337            typename Field::Element_ptr C12= C+N2;
338            typename Field::Element_ptr C_12 = fflas_new (F, M2, n-N2,Alignment::CACHE_PAGESIZE);
339
340            typename Field::Element_ptr C21= C+M2*ldc;
341            typename Field::Element_ptr C_21 = fflas_new (F, m-M2, N2,Alignment::CACHE_PAGESIZE);
342
343            typename Field::Element_ptr C22= C+N2+M2*ldc;
344            typename Field::Element_ptr C_22 = fflas_new (F, m-M2, n-N2,Alignment::CACHE_PAGESIZE);
345
346            // 1/ 8 multiply in parallel
347            //omp_set_task_affinity(omp_get_locality_domain_num_for( C11));
348
349            typedef MMHelper<Field, AlgoT, FieldTrait, ParSeqHelper::Parallel<CuttingStrategy::Recursive,StrategyParameter::ThreeD> > MMH_t;
350            MMH_t H1(H);
351            MMH_t H2(H);
352            MMH_t H3(H);
353            MMH_t H4(H);
354            MMH_t H5(H);
355            MMH_t H6(H);
356            MMH_t H7(H);
357            MMH_t H8(H);
358            size_t nt = H.parseq.numthreads();
359            size_t nt_rec = nt/8;
360            size_t nt_mod = nt % 8 ;
361            H1.parseq.set_numthreads(std::max(size_t(1),nt_rec + ((nt_mod-- > 0)?1:0)));
362            H2.parseq.set_numthreads(std::max(size_t(1),nt_rec + ((nt_mod-- > 0)?1:0)));
363            H3.parseq.set_numthreads(std::max(size_t(1),nt_rec + ((nt_mod-- > 0)?1:0)));
364            H4.parseq.set_numthreads(std::max(size_t(1),nt_rec + ((nt_mod-- > 0)?1:0)));
365            H5.parseq.set_numthreads(std::max(size_t(1),nt_rec + ((nt_mod-- > 0)?1:0)));
366            H6.parseq.set_numthreads(std::max(size_t(1),nt_rec + ((nt_mod-- > 0)?1:0)));
367            H7.parseq.set_numthreads(std::max(size_t(1),nt_rec + ((nt_mod-- > 0)?1:0)));
368            H8.parseq.set_numthreads(std::max(size_t(1),nt_rec + ((nt_mod-- > 0)?1:0)));
369
370            SYNCH_GROUP(
371                        TASK(MODE(CONSTREFERENCE(F,H1) READ(A11,B11) READWRITE(C11)), pfgemm(F, ta, tb, M2, N2, K2, alpha, A11, lda, B11, ldb, beta, C11, ldc, H1));
372                        //omp_set_task_affinity(omp_get_locality_domain_num_for( C_11));
373                        TASK(MODE(CONSTREFERENCE(F,H2) READ(A12,B21) WRITE(C_11)), pfgemm(F, ta, tb, M2, N2, k-K2, a, A12, lda, B21, ldb, b,C_11, N2, H2));
374                        //omp_set_task_affinity(omp_get_locality_domain_num_for( C12));
375                        TASK(MODE(CONSTREFERENCE(F,H3) READ(A12,B22) READWRITE(C12)), pfgemm(F, ta, tb, M2, n-N2, k-K2, alpha, A12, lda, B22, ldb, beta, C12, ldc, H3));
376                        //omp_set_task_affinity(omp_get_locality_domain_num_for( C_12));
377                        TASK(MODE(CONSTREFERENCE(F,H4) READ(A11,B12) WRITE(C_12)), pfgemm(F, ta, tb, M2, n-N2, K2, a, A11, lda, B12, ldb, b, C_12, n-N2, H4));
378                        //omp_set_task_affinity(omp_get_locality_domain_num_for( C21));
379                        TASK(MODE(CONSTREFERENCE(F,H5) READ(A22,B21) READWRITE(C21)), pfgemm(F, ta, tb, m-M2, N2, k-K2, alpha, A22, lda, B21, ldb, beta, C21, ldc, H5));
380                        //omp_set_task_affinity(omp_get_locality_domain_num_for( C_21));
381                        TASK(MODE(CONSTREFERENCE(F,H6) READ(A21,B11) WRITE(C_21)), pfgemm(F, ta, tb, m-M2, N2, K2, a, A21, lda, B11, ldb, b,C_21, N2, H6));
382                        //omp_set_task_affinity(omp_get_locality_domain_num_for( C22));
383                        TASK(MODE(CONSTREFERENCE(F,H7) READ(A21,B12) READWRITE(C22)), pfgemm(F, ta, tb, m-M2, n-N2, K2, alpha, A21, lda, B12, ldb, beta, C22, ldc, H7));
384                        //omp_set_task_affinity(omp_get_locality_domain_num_for( C_22));
385                        TASK(MODE(CONSTREFERENCE(F,H8) READ(A22,B22) WRITE(C_22)), pfgemm(F, ta, tb, m-M2, n-N2, k-K2, a, A22, lda, B22, ldb, b,C_22, n-N2, H8));
386
387                        CHECK_DEPENDENCIES;
388                        // 2/ final add
389                        //omp_set_task_affinity(omp_get_locality_domain_num_for( C11));
390                        TASK(MODE(CONSTREFERENCE(F) READ(C_11) READWRITE(C11)), faddin(F, M2, N2, C_11, N2, C11, ldc));
391                        //omp_set_task_affinity(omp_get_locality_domain_num_for( C12));
392                        TASK(MODE(CONSTREFERENCE(F) READ(C_12) READWRITE(C12)),faddin(F, M2, n-N2, C_12, n-N2, C12, ldc));
393                        //omp_set_task_affinity(omp_get_locality_domain_num_for( C21));
394                        TASK(MODE(CONSTREFERENCE(F) READ(C_21) READWRITE(C21)), faddin(F, m-M2, N2, C_21, N2, C21, ldc));
395                        //omp_set_task_affinity(omp_get_locality_domain_num_for( C22));
396                        TASK(MODE(CONSTREFERENCE(F) READ(C_22) READWRITE(C22)), faddin(F, m-M2, n-N2, C_22, n-N2, C22, ldc));
397
398                        );
399                        FFLAS::fflas_delete (C_11);
400                        FFLAS::fflas_delete (C_12);
401                        FFLAS::fflas_delete (C_21);
402                        FFLAS::fflas_delete (C_22);
403        }
404        return C;
405    }
406
407    template<class Field, class AlgoT, class FieldTrait>
408    typename Field::Element*
409    pfgemm( const Field& F,
410            const FFLAS_TRANSPOSE ta,
411            const FFLAS_TRANSPOSE tb,
412            const size_t m,
413            const size_t n,
414            const size_t k,
415            const typename Field::Element alpha,
416            const typename Field::ConstElement_ptr A, const size_t lda,
417            const typename Field::ConstElement_ptr B, const size_t ldb,
418            const typename Field::Element beta,
419            typename Field::Element_ptr C, const size_t ldc,
420            MMHelper<Field, AlgoT, FieldTrait, ParSeqHelper::Parallel<CuttingStrategy::Recursive,StrategyParameter::ThreeDInPlace> > & H){
421
422
423        if (!m || !n) {return C;}
424        if (!k || F.isZero (alpha)){
425            fscalin(F, m, n, beta, C, ldc);
426            return C;
427        }
428
429        if(H.parseq.numthreads() <= 1|| std::min(m*n,std::min(m*k,k*n))<=__FFLASFFPACK_SEQPARTHRESHOLD*__FFLASFFPACK_SEQPARTHRESHOLD){	// threshold
430            FFLAS::MMHelper<Field, AlgoT, FieldTrait,FFLAS::ParSeqHelper::Sequential> WH (H);
431            return fgemm(F, ta, tb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc, WH);
432        }else{
433            size_t M2= m>>1;
434            size_t N2= n>>1;
435            size_t K2= k>>1;
436            typename Field::ConstElement_ptr A11= A;
437            typename Field::ConstElement_ptr A12= A+K2*((ta==FFLAS::FflasTrans)?lda:1);
438            typename Field::ConstElement_ptr A21= A+M2*((ta==FFLAS::FflasTrans)?1:lda);
439            typename Field::ConstElement_ptr A22= A12+M2*((ta==FFLAS::FflasTrans)?1:lda);
440
441            typename Field::ConstElement_ptr B11= B;
442            typename Field::ConstElement_ptr B12= B+N2*((tb==FFLAS::FflasTrans)?ldb:1);
443            typename Field::ConstElement_ptr B21= B+K2*((tb==FFLAS::FflasTrans)?1:ldb);
444            typename Field::ConstElement_ptr B22= B12+K2*((tb==FFLAS::FflasTrans)?1:ldb);
445
446
447            typename Field::Element_ptr C11= C;
448            typename Field::Element_ptr C12= C+N2;
449            typename Field::Element_ptr C21= C+M2*ldc;
450            typename Field::Element_ptr C22= C+N2+M2*ldc;
451            typedef MMHelper<Field, AlgoT, FieldTrait, ParSeqHelper::Parallel<CuttingStrategy::Recursive,StrategyParameter::ThreeDInPlace> > MMH_t;
452            MMH_t H1(H);
453            MMH_t H2(H);
454            MMH_t H3(H);
455            MMH_t H4(H);
456            size_t nt = H.parseq.numthreads();
457            size_t nt_rec = nt/4;
458            size_t nt_mod = nt%4;
459            H1.parseq.set_numthreads(std::max(size_t(1),nt_rec + ((nt_mod-- > 0)?1:0)));
460            H2.parseq.set_numthreads(std::max(size_t(1),nt_rec + ((nt_mod-- > 0)?1:0)));
461            H3.parseq.set_numthreads(std::max(size_t(1),nt_rec + ((nt_mod-- > 0)?1:0)));
462            H4.parseq.set_numthreads(std::max(size_t(1),nt_rec + ((nt_mod-- > 0)?1:0)));
463            SYNCH_GROUP(
464                        // 1/ 4 multiply
465                        TASK(MODE(CONSTREFERENCE(F,H1) READ(A11,B11) READWRITE(C11)), pfgemm(F, ta, tb, M2, N2, K2, alpha, A11, lda, B11, ldb, beta, C11, ldc, H1));
466                        TASK(MODE(CONSTREFERENCE(F,H2) READ(A12,B22) READWRITE(C12)), pfgemm(F, ta, tb, M2, n-N2, k-K2, alpha, A12, lda, B22, ldb, beta, C12, ldc, H2));
467                        TASK(MODE(CONSTREFERENCE(F,H3) READ(A22,B21) READWRITE(C21)), pfgemm(F, ta, tb, m-M2, N2, k-K2, alpha, A22, lda, B21, ldb, beta, C21, ldc, H3));
468                        TASK(MODE(CONSTREFERENCE(F,H4) READ(A21,B12) READWRITE(C22)), pfgemm(F, ta, tb, m-M2, n-N2, K2, alpha, A21, lda, B12, ldb, beta, C22, ldc, H4));
469
470                        CHECK_DEPENDENCIES;
471                        // 2/ 4 add+multiply
472                        TASK(MODE(CONSTREFERENCE(F,H1) READ(A12,B21) READWRITE(C11)), pfgemm(F, ta, tb, M2, N2, k-K2, alpha, A12, lda, B21, ldb, F.one, C11, ldc, H1));
473                        TASK(MODE(CONSTREFERENCE(F,H2) READ(A11,B12) READWRITE(C12)), pfgemm(F, ta, tb, M2, n-N2, K2, alpha, A11, lda, B12, ldb, F.one, C12, ldc, H2));
474                        TASK(MODE(CONSTREFERENCE(F,H3) READ(A21,B11) READWRITE(C21)), pfgemm(F, ta, tb, m-M2, N2, K2, alpha, A21, lda, B11, ldb, F.one, C21, ldc, H3));
475                        TASK(MODE(CONSTREFERENCE(F,H4) READ(A22,B22) READWRITE(C22)), pfgemm(F, ta, tb, m-M2, n-N2, k-K2, alpha, A22, lda, B22, ldb, F.one, C22, ldc, H4));
476                       );
477        }
478        return C;
479    }
480
481
482
483} // FFLAS
484/* -*- mode: C++; tab-width: 4; indent-tabs-mode: nil; c-basic-offset: 4 -*- */
485// vim:sts=4:sw=4:ts=4:et:sr:cino=>s,f0,{0,g0,(0,\:0,t0,+0,=s
486