/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved.                      *
* This file is part of the LIBXSMM library.                                   *
*                                                                             *
* For information on the license, see the LICENSE file.                       *
* Further information: https://github.com/hfp/libxsmm/                        *
* SPDX-License-Identifier: BSD-3-Clause                                       *
******************************************************************************/
/* Evangelos Georganas, Alexander Heinecke (Intel Corp.)
******************************************************************************/

/* size variables, all const */
/* here we assume that input and output use the same blocking */
const int nBlocksIFm = handle->desc.C / handle->bc;
const int nBlocksOFm = handle->desc.K / handle->bk;
const int nBlocksMB  = handle->desc.N / handle->bn;
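/* low-precision pack factor: pairs of bf16 values are packed along the input-channel
   dimension of the filter (VNNI-style layout), hence the bc/lpb x bk x lpb filter block */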
const int lpb = 2;
const int bc_lp = handle->bc/lpb;
const int use_2d_blocking = handle->fwd_2d_blocking;

/* computing first logical thread */
const int ltid = tid - start_thread;
/* number of tasks that could be run in parallel */
const int work = nBlocksOFm * nBlocksMB;
/* compute chunk size (ceiling division so that all work items are covered) */
const int chunksize = (work % handle->desc.threads == 0) ? (work / handle->desc.threads) : ((work / handle->desc.threads) + 1);
/* compute thr_begin and thr_end */
const int thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work;
const int thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work;

/* loop variables */
int mb1ofm1 = 0, mb1 = 0, ofm1 = 0, ifm1 = 0;
int im_tasks_per_thread = 0, in_tasks_per_thread = 0, my_in_start = 0, my_in_end = 0, my_im_start = 0, my_im_end = 0, my_row_id = 0, my_col_id = 0, row_teams = 0, column_teams = 0;
#ifndef LIBXSMM_DNN_FC_FWD_FUSE_NONE
int mb2 = 0, ofm2 = 0;
#endif
LIBXSMM_VLA_DECL(4, element_output_type,       output,  (element_output_type*)handle->reg_output->data, nBlocksOFm, handle->bn, handle->bk);
LIBXSMM_VLA_DECL(4, const element_input_type,  input,   (element_input_type* )handle->reg_input->data,  nBlocksIFm, handle->bn, handle->bc);
LIBXSMM_VLA_DECL(5, const element_filter_type, filter,  (element_filter_type*)handle->reg_filter->data, nBlocksIFm, bc_lp, handle->bk, lpb);
float* temp_output = (float*)handle->scratch;
LIBXSMM_VLA_DECL(4,       float,               output_f32, (float*)temp_output, nBlocksOFm, handle->bn, handle->bk);
#ifndef LIBXSMM_DNN_FC_FWD_FUSE_NONE
#ifdef LIBXSMM_DNN_FC_FWD_FUSE_BIAS
LIBXSMM_VLA_DECL(2, const element_input_type,  bias,    (element_input_type*)handle->reg_bias->data, handle->bk);
#endif
#ifdef LIBXSMM_DNN_FC_FWD_FUSE_RELU
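/* two views of the same relumask buffer: a byte mask for the scalar path and a
   16-bit bitmask view (one bit per output element) for the AVX-512 path */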
LIBXSMM_VLA_DECL(4, unsigned char, relumask,    (unsigned char*)handle->relumask->data, nBlocksOFm, handle->bn, handle->bk);
LIBXSMM_VLA_DECL(4, __mmask16,     relubitmask, (__mmask16*)    handle->relumask->data, nBlocksOFm, handle->bn, handle->bk/16);
#endif
#endif
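/* reduction blocking: the nBlocksIFm reduction blocks are split into BF chunks of
   CB_BLOCKS blocks each; one batch-reduce GEMM call accumulates over a single chunk */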
const int BF = handle->fwd_bf;
const int CB_BLOCKS = nBlocksIFm / BF;
unsigned long long blocks = CB_BLOCKS;

if (use_2d_blocking == 1) {
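  /* 2D decomposition: threads form a row_teams x column_teams grid; row teams
     partition the minibatch blocks, column teams the output-feature blocks */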
  row_teams = handle->fwd_row_teams;
  column_teams = handle->fwd_column_teams;
  my_col_id = ltid % column_teams;
  my_row_id = ltid / column_teams;
  im_tasks_per_thread = LIBXSMM_UPDIV(nBlocksMB, row_teams);
  in_tasks_per_thread = LIBXSMM_UPDIV(nBlocksOFm, column_teams);
  my_im_start = LIBXSMM_MIN(my_row_id * im_tasks_per_thread, nBlocksMB);
  my_im_end = LIBXSMM_MIN((my_row_id+1) * im_tasks_per_thread, nBlocksMB);
  my_in_start = LIBXSMM_MIN(my_col_id * in_tasks_per_thread, nBlocksOFm);
  my_in_end = LIBXSMM_MIN((my_col_id+1) * in_tasks_per_thread, nBlocksOFm);
}

/* lazy barrier init */
libxsmm_barrier_init(handle->barrier, ltid);

if (use_2d_blocking == 1) {
  if (BF > 1) {
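    /* with reduction blocking, ifm1 walks the BF chunks: partial products accumulate
       into the f32 scratch tensor, which is initialized on the first chunk (bias or
       zeros) and downconverted to bf16 after the last chunk */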
    for ( ifm1 = 0; ifm1 < BF; ++ifm1 ) {
      for (ofm1 = my_in_start; ofm1 < my_in_end; ++ofm1) {
        for (mb1 = my_im_start; mb1 < my_im_end; ++mb1) {
          /* Initialize intermediate f32 tensor */
          if ( ifm1 == 0 ) {
#ifdef LIBXSMM_DNN_FC_FWD_FUSE_BIAS
            for ( mb2 = 0; mb2 < handle->bn; ++mb2 ) {
              LIBXSMM_DNN_CONVERT_BUFFER_BF16_F32( &LIBXSMM_VLA_ACCESS(2, bias, ofm1, 0, handle->bk), &LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, mb2, 0, nBlocksOFm, handle->bn, handle->bk), handle->bk );
            }
#else
            memset(&LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, 0, 0, nBlocksOFm, handle->bn, handle->bk), 0, handle->bn*handle->bk*sizeof(float));
#endif
          }
          batchreduce_kernel( &LIBXSMM_VLA_ACCESS(5, filter, ofm1, ifm1*CB_BLOCKS, 0, 0, 0, nBlocksIFm, bc_lp, handle->bk, lpb),
              &LIBXSMM_VLA_ACCESS(4, input,  mb1, ifm1*CB_BLOCKS, 0, 0, nBlocksIFm, handle->bn, handle->bc),
              &LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, 0, 0, nBlocksOFm, handle->bn, handle->bk), &blocks);
          /* downconvert the intermediate f32 tensor to bf16 and store to the final output */
          if ( ifm1 == BF-1 ) {
#ifndef LIBXSMM_DNN_FC_FWD_FUSE_NONE
            if (handle->bk % 32 == 0) {
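              /* vectorized eltwise on the f32 scratch: two 16-lane vectors (32 outputs) per iteration */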
              __m512 cur_out_0 = _mm512_setzero_ps();
              __m512 cur_out_1 = _mm512_setzero_ps();
#ifdef LIBXSMM_DNN_FC_FWD_FUSE_RELU
              __mmask16 relumask0;
              __mmask16 relumask1;
#endif
#ifdef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID
              __m512 ones = _mm512_set1_ps(1.0f);
              __m512 halves = _mm512_set1_ps(0.5f);
#endif
              for ( mb2 = 0; mb2 < handle->bn; ++mb2 ) {
                for ( ofm2 = 0; ofm2 < handle->bk; ofm2 += 32 ) {
                  cur_out_0 = LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk));
                  cur_out_1 = LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, mb2, ofm2+16, nBlocksOFm, handle->bn, handle->bk));
#ifdef LIBXSMM_DNN_FC_FWD_FUSE_RELU
                  relumask0 = _mm512_cmp_ps_mask( cur_out_0, _mm512_setzero_ps(), _CMP_GT_OQ );
                  relumask1 = _mm512_cmp_ps_mask( cur_out_1, _mm512_setzero_ps(), _CMP_GT_OQ );
                  cur_out_0 = _mm512_mask_blend_ps( relumask0, _mm512_setzero_ps(), cur_out_0 );
                  cur_out_1 = _mm512_mask_blend_ps( relumask1, _mm512_setzero_ps(), cur_out_1 );
                  LIBXSMM_INTRINSICS_MM512_STORE_MASK16( &LIBXSMM_VLA_ACCESS(4, relubitmask, mb1, ofm1, mb2, ofm2/16, nBlocksOFm, handle->bn, handle->bk/16), relumask0 );
                  LIBXSMM_INTRINSICS_MM512_STORE_MASK16( &LIBXSMM_VLA_ACCESS(4, relubitmask, mb1, ofm1, mb2, ofm2/16+1, nBlocksOFm, handle->bn, handle->bk/16), relumask1 );
#endif
#ifdef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID
                  /* sigmoid(x) = (tanh(x/2) + 1) / 2, using the Pade 7/8 rational approximation of tanh */
                  cur_out_0 = _mm512_mul_ps(_mm512_add_ps(LIBXSMM_INTRINSICS_MM512_TANH_PS_RATIONAL_78(_mm512_mul_ps(cur_out_0, halves)), ones), halves);
                  cur_out_1 = _mm512_mul_ps(_mm512_add_ps(LIBXSMM_INTRINSICS_MM512_TANH_PS_RATIONAL_78(_mm512_mul_ps(cur_out_1, halves)), ones), halves);
#endif
                  _mm512_storeu_ps(&LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk), cur_out_0);
                  _mm512_storeu_ps(&LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, mb2, ofm2+16, nBlocksOFm, handle->bn, handle->bk), cur_out_1);
                }
              }
            } else {
              for ( mb2 = 0; mb2 < handle->bn; ++mb2 ) {
                for ( ofm2 = 0; ofm2 < handle->bk; ++ofm2 ) {
                  float l_cur_out = LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk);
#ifdef LIBXSMM_DNN_FC_FWD_FUSE_RELU
                  LIBXSMM_VLA_ACCESS(4, relumask, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk) = (unsigned char)(( l_cur_out > 0.0f ) ? 1 : 0);
                  l_cur_out = (l_cur_out > 0.0f) ? l_cur_out : 0.0f;
#endif
#ifdef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID
                  /* sigmoid(x) = (tanh(x/2) + 1) / 2, using the Pade 7/8 rational approximation of tanh */
                  l_cur_out = (libxsmm_stanh_pade78( l_cur_out / 2.0f ) + 1.0f) / 2.0f;
#endif
                  LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk) = l_cur_out;
                }
              }
            }
#endif
            LIBXSMM_DNN_CONVERT_BUFFER_F32_BF16(&LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, 0, 0, nBlocksOFm, handle->bn, handle->bk), &LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, 0, 0, nBlocksOFm, handle->bn, handle->bk), handle->bn*handle->bk);
          }
        }
      }
    }
  } else {
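    /* BF == 1: the full IFM reduction runs in a single batch-reduce call and writes
       bf16 directly; with fused bias the output is pre-seeded with the bias and a
       beta = 1 kernel is used, otherwise a zero-beta kernel */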
    for (ofm1 = my_in_start; ofm1 < my_in_end; ++ofm1) {
      for (mb1 = my_im_start; mb1 < my_im_end; ++mb1) {
#ifdef LIBXSMM_DNN_FC_FWD_FUSE_BIAS
        for ( mb2 = 0; mb2 < handle->bn; ++mb2 ) {
          for ( ofm2 = 0; ofm2 < handle->bk; ++ofm2 ) {
            LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk) = LIBXSMM_VLA_ACCESS(2, bias, ofm1, ofm2, handle->bk);
          }
        }
        batchreduce_kernel_beta( &LIBXSMM_VLA_ACCESS(5, filter, ofm1, 0, 0, 0, 0, nBlocksIFm, bc_lp, handle->bk, lpb),
            &LIBXSMM_VLA_ACCESS(4, input,  mb1, 0,  0, 0, nBlocksIFm, handle->bn, handle->bc),
            &LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, 0, 0, nBlocksOFm, handle->bn, handle->bk), &blocks);
#else
        batchreduce_kernel_zerobeta( &LIBXSMM_VLA_ACCESS(5, filter, ofm1, 0, 0, 0, 0, nBlocksIFm, bc_lp, handle->bk, lpb),
            &LIBXSMM_VLA_ACCESS(4, input,  mb1, 0,  0, 0, nBlocksIFm, handle->bn, handle->bc),
            &LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, 0, 0, nBlocksOFm, handle->bn, handle->bk), &blocks);
#endif
#ifndef LIBXSMM_DNN_FC_FWD_FUSE_NONE
        if (handle->bk % 32 == 0) {
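          /* vectorized eltwise directly on the bf16 output: upconvert 2x16 lanes to f32,
             apply the fused op, and write back 32 bf16 values per 512-bit store */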
          __m512 cur_out_0 = _mm512_setzero_ps();
          __m512 cur_out_1 = _mm512_setzero_ps();
#ifdef LIBXSMM_DNN_FC_FWD_FUSE_RELU
          __mmask16 relumask0;
          __mmask16 relumask1;
#endif
#ifdef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID
          __m512 ones = _mm512_set1_ps(1.0f);
          __m512 halves = _mm512_set1_ps(0.5f);
#endif
          for ( mb2 = 0; mb2 < handle->bn; ++mb2 ) {
            for ( ofm2 = 0; ofm2 < handle->bk; ofm2 += 32 ) {
              cur_out_0 = LIBXSMM_INTRINSICS_MM512_CVTPBH_PS(_mm256_loadu_si256((__m256i*)&LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk)));
              cur_out_1 = LIBXSMM_INTRINSICS_MM512_CVTPBH_PS(_mm256_loadu_si256((__m256i*)&LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, mb2, ofm2+16, nBlocksOFm, handle->bn, handle->bk)));
#ifdef LIBXSMM_DNN_FC_FWD_FUSE_RELU
              relumask0 = _mm512_cmp_ps_mask( cur_out_0, _mm512_setzero_ps(), _CMP_GT_OQ );
              relumask1 = _mm512_cmp_ps_mask( cur_out_1, _mm512_setzero_ps(), _CMP_GT_OQ );
              cur_out_0 = _mm512_mask_blend_ps( relumask0, _mm512_setzero_ps(), cur_out_0 );
              cur_out_1 = _mm512_mask_blend_ps( relumask1, _mm512_setzero_ps(), cur_out_1 );
              LIBXSMM_INTRINSICS_MM512_STORE_MASK16( &LIBXSMM_VLA_ACCESS(4, relubitmask, mb1, ofm1, mb2, ofm2/16, nBlocksOFm, handle->bn, handle->bk/16), relumask0 );
              LIBXSMM_INTRINSICS_MM512_STORE_MASK16( &LIBXSMM_VLA_ACCESS(4, relubitmask, mb1, ofm1, mb2, ofm2/16+1, nBlocksOFm, handle->bn, handle->bk/16), relumask1 );
#endif
#ifdef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID
              /* sigmoid(x) = (tanh(x/2) + 1) / 2, using the Pade 7/8 rational approximation of tanh */
              cur_out_0 = _mm512_mul_ps(_mm512_add_ps(LIBXSMM_INTRINSICS_MM512_TANH_PS_RATIONAL_78(_mm512_mul_ps(cur_out_0, halves)), ones), halves);
              cur_out_1 = _mm512_mul_ps(_mm512_add_ps(LIBXSMM_INTRINSICS_MM512_TANH_PS_RATIONAL_78(_mm512_mul_ps(cur_out_1, halves)), ones), halves);
#endif
              _mm512_storeu_si512(&LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk), LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH( cur_out_1, cur_out_0 ));
            }
          }
        } else {
          for ( mb2 = 0; mb2 < handle->bn; ++mb2 ) {
            for ( ofm2 = 0; ofm2 < handle->bk; ++ofm2 ) {
#ifdef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID
              libxsmm_bfloat16_hp t;
#endif
              libxsmm_bfloat16 l_cur_out = LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk);
#ifdef LIBXSMM_DNN_FC_FWD_FUSE_RELU
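              /* ReLU on raw bf16 bits: testing bit 0x8000 (the sign bit) zeroes all
                 negative values without an upconvert */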
              LIBXSMM_VLA_ACCESS(4, relumask, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk) = (unsigned char)(( (l_cur_out & 0x8000) > 0 ) ? 0 : 1);
              l_cur_out = (libxsmm_bfloat16)(( (l_cur_out & 0x8000) > 0 ) ? 0 : l_cur_out);
#endif
#ifdef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID
              /* upconvert bf16 to f32 by placing its bits in the high half of the union,
                 then apply sigmoid(x) = (tanh(x/2) + 1) / 2 via the Pade 7/8 tanh approximation */
              t.i[1] = l_cur_out;
              t.i[0] = 0;
              t.f = (libxsmm_stanh_pade78( t.f / 2.0f ) + 1.0f) / 2.0f;
              l_cur_out = t.i[1];
#endif
              LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk) = l_cur_out;
            }
          }
        }
#endif
      }
    }
  }
} else {
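  /* 1D decomposition: (MB x OFM) block pairs are linearized into mb1ofm1 work items
     and split evenly across threads via thr_begin/thr_end */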
  if (BF > 1) {
    for ( ifm1 = 0; ifm1 < BF; ++ifm1 ) {
      for ( mb1ofm1 = thr_begin; mb1ofm1 < thr_end; ++mb1ofm1 ) {
        mb1  = mb1ofm1%nBlocksMB;
        ofm1 = mb1ofm1/nBlocksMB;
        /* Initialize intermediate f32 tensor */
        if ( ifm1 == 0 ) {
#ifdef LIBXSMM_DNN_FC_FWD_FUSE_BIAS
          for ( mb2 = 0; mb2 < handle->bn; ++mb2 ) {
            LIBXSMM_DNN_CONVERT_BUFFER_BF16_F32( &LIBXSMM_VLA_ACCESS(2, bias, ofm1, 0, handle->bk), &LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, mb2, 0, nBlocksOFm, handle->bn, handle->bk), handle->bk );
          }
#else
          memset(&LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, 0, 0, nBlocksOFm, handle->bn, handle->bk), 0, handle->bn*handle->bk*sizeof(float));
#endif
        }
        batchreduce_kernel( &LIBXSMM_VLA_ACCESS(5, filter, ofm1, ifm1*CB_BLOCKS, 0, 0, 0, nBlocksIFm, bc_lp, handle->bk, lpb),
            &LIBXSMM_VLA_ACCESS(4, input,  mb1, ifm1*CB_BLOCKS, 0, 0, nBlocksIFm, handle->bn, handle->bc),
            &LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, 0, 0, nBlocksOFm, handle->bn, handle->bk), &blocks);
        /* downconvert the intermediate f32 tensor to bf16 and store to the final output */
        if ( ifm1 == BF-1 ) {
#ifndef LIBXSMM_DNN_FC_FWD_FUSE_NONE
          if (handle->bk % 32 == 0) {
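            /* vectorized eltwise on the f32 scratch: two 16-lane vectors (32 outputs) per iteration */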
            __m512 cur_out_0 = _mm512_setzero_ps();
            __m512 cur_out_1 = _mm512_setzero_ps();
#ifdef LIBXSMM_DNN_FC_FWD_FUSE_RELU
            __mmask16 relumask0;
            __mmask16 relumask1;
#endif
#ifdef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID
            __m512 ones = _mm512_set1_ps(1.0f);
            __m512 halves = _mm512_set1_ps(0.5f);
#endif
            for ( mb2 = 0; mb2 < handle->bn; ++mb2 ) {
              for ( ofm2 = 0; ofm2 < handle->bk; ofm2 += 32 ) {
                cur_out_0 = LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk));
                cur_out_1 = LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, mb2, ofm2+16, nBlocksOFm, handle->bn, handle->bk));
#ifdef LIBXSMM_DNN_FC_FWD_FUSE_RELU
                relumask0 = _mm512_cmp_ps_mask( cur_out_0, _mm512_setzero_ps(), _CMP_GT_OQ );
                relumask1 = _mm512_cmp_ps_mask( cur_out_1, _mm512_setzero_ps(), _CMP_GT_OQ );
                cur_out_0 = _mm512_mask_blend_ps( relumask0, _mm512_setzero_ps(), cur_out_0 );
                cur_out_1 = _mm512_mask_blend_ps( relumask1, _mm512_setzero_ps(), cur_out_1 );
                LIBXSMM_INTRINSICS_MM512_STORE_MASK16( &LIBXSMM_VLA_ACCESS(4, relubitmask, mb1, ofm1, mb2, ofm2/16, nBlocksOFm, handle->bn, handle->bk/16), relumask0 );
                LIBXSMM_INTRINSICS_MM512_STORE_MASK16( &LIBXSMM_VLA_ACCESS(4, relubitmask, mb1, ofm1, mb2, ofm2/16+1, nBlocksOFm, handle->bn, handle->bk/16), relumask1 );
#endif
#ifdef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID
                /* sigmoid(x) = (tanh(x/2) + 1) / 2, using the Pade 7/8 rational approximation of tanh */
                cur_out_0 = _mm512_mul_ps(_mm512_add_ps(LIBXSMM_INTRINSICS_MM512_TANH_PS_RATIONAL_78(_mm512_mul_ps(cur_out_0, halves)), ones), halves);
                cur_out_1 = _mm512_mul_ps(_mm512_add_ps(LIBXSMM_INTRINSICS_MM512_TANH_PS_RATIONAL_78(_mm512_mul_ps(cur_out_1, halves)), ones), halves);
#endif
                _mm512_storeu_ps(&LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk), cur_out_0);
                _mm512_storeu_ps(&LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, mb2, ofm2+16, nBlocksOFm, handle->bn, handle->bk), cur_out_1);
              }
            }
          } else {
            for ( mb2 = 0; mb2 < handle->bn; ++mb2 ) {
              for ( ofm2 = 0; ofm2 < handle->bk; ++ofm2 ) {
                float l_cur_out = LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk);
#ifdef LIBXSMM_DNN_FC_FWD_FUSE_RELU
                LIBXSMM_VLA_ACCESS(4, relumask, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk) = (unsigned char)(( l_cur_out > 0.0f ) ? 1 : 0);
                l_cur_out = (l_cur_out > 0.0f) ? l_cur_out : 0.0f;
#endif
#ifdef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID
                /* sigmoid(x) = (tanh(x/2) + 1) / 2, using the Pade 7/8 rational approximation of tanh */
                l_cur_out = (libxsmm_stanh_pade78( l_cur_out / 2.0f ) + 1.0f) / 2.0f;
#endif
                LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk) = l_cur_out;
              }
            }
          }
#endif
          LIBXSMM_DNN_CONVERT_BUFFER_F32_BF16(&LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, 0, 0, nBlocksOFm, handle->bn, handle->bk), &LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, 0, 0, nBlocksOFm, handle->bn, handle->bk), handle->bn*handle->bk);
        }
      }
    }
  } else {
    for ( mb1ofm1 = thr_begin; mb1ofm1 < thr_end; ++mb1ofm1 ) {
      mb1  = mb1ofm1%nBlocksMB;
      ofm1 = mb1ofm1/nBlocksMB;
#ifdef LIBXSMM_DNN_FC_FWD_FUSE_BIAS
      for ( mb2 = 0; mb2 < handle->bn; ++mb2 ) {
        for ( ofm2 = 0; ofm2 < handle->bk; ++ofm2 ) {
          LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk) = LIBXSMM_VLA_ACCESS(2, bias, ofm1, ofm2, handle->bk);
        }
      }
      batchreduce_kernel_beta( &LIBXSMM_VLA_ACCESS(5, filter, ofm1, 0, 0, 0, 0, nBlocksIFm, bc_lp, handle->bk, lpb),
          &LIBXSMM_VLA_ACCESS(4, input,  mb1, 0,  0, 0, nBlocksIFm, handle->bn, handle->bc),
          &LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, 0, 0, nBlocksOFm, handle->bn, handle->bk), &blocks);
#else
      batchreduce_kernel_zerobeta( &LIBXSMM_VLA_ACCESS(5, filter, ofm1, 0, 0, 0, 0, nBlocksIFm, bc_lp, handle->bk, lpb),
          &LIBXSMM_VLA_ACCESS(4, input,  mb1, 0,  0, 0, nBlocksIFm, handle->bn, handle->bc),
          &LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, 0, 0, nBlocksOFm, handle->bn, handle->bk), &blocks);
#endif
#ifndef LIBXSMM_DNN_FC_FWD_FUSE_NONE
      if (handle->bk % 32 == 0) {
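        /* vectorized eltwise directly on the bf16 output: upconvert 2x16 lanes to f32,
           apply the fused op, and write back 32 bf16 values per 512-bit store */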
        __m512 cur_out_0 = _mm512_setzero_ps();
        __m512 cur_out_1 = _mm512_setzero_ps();
#ifdef LIBXSMM_DNN_FC_FWD_FUSE_RELU
        __mmask16 relumask0;
        __mmask16 relumask1;
#endif
#ifdef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID
        __m512 ones = _mm512_set1_ps(1.0f);
        __m512 halves = _mm512_set1_ps(0.5f);
#endif
        for ( mb2 = 0; mb2 < handle->bn; ++mb2 ) {
          for ( ofm2 = 0; ofm2 < handle->bk; ofm2 += 32 ) {
            cur_out_0 = LIBXSMM_INTRINSICS_MM512_CVTPBH_PS(_mm256_loadu_si256((__m256i*)&LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk)));
            cur_out_1 = LIBXSMM_INTRINSICS_MM512_CVTPBH_PS(_mm256_loadu_si256((__m256i*)&LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, mb2, ofm2+16, nBlocksOFm, handle->bn, handle->bk)));
#ifdef LIBXSMM_DNN_FC_FWD_FUSE_RELU
            relumask0 = _mm512_cmp_ps_mask( cur_out_0, _mm512_setzero_ps(), _CMP_GT_OQ );
            relumask1 = _mm512_cmp_ps_mask( cur_out_1, _mm512_setzero_ps(), _CMP_GT_OQ );
            cur_out_0 = _mm512_mask_blend_ps( relumask0, _mm512_setzero_ps(), cur_out_0 );
            cur_out_1 = _mm512_mask_blend_ps( relumask1, _mm512_setzero_ps(), cur_out_1 );
            LIBXSMM_INTRINSICS_MM512_STORE_MASK16( &LIBXSMM_VLA_ACCESS(4, relubitmask, mb1, ofm1, mb2, ofm2/16, nBlocksOFm, handle->bn, handle->bk/16), relumask0 );
            LIBXSMM_INTRINSICS_MM512_STORE_MASK16( &LIBXSMM_VLA_ACCESS(4, relubitmask, mb1, ofm1, mb2, ofm2/16+1, nBlocksOFm, handle->bn, handle->bk/16), relumask1 );
#endif
#ifdef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID
            /* sigmoid(x) = (tanh(x/2) + 1) / 2, using the Pade 7/8 rational approximation of tanh */
            cur_out_0 = _mm512_mul_ps(_mm512_add_ps(LIBXSMM_INTRINSICS_MM512_TANH_PS_RATIONAL_78(_mm512_mul_ps(cur_out_0, halves)), ones), halves);
            cur_out_1 = _mm512_mul_ps(_mm512_add_ps(LIBXSMM_INTRINSICS_MM512_TANH_PS_RATIONAL_78(_mm512_mul_ps(cur_out_1, halves)), ones), halves);
#endif
            _mm512_storeu_si512(&LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk), LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH( cur_out_1, cur_out_0 ));
          }
        }
      } else {
        for ( mb2 = 0; mb2 < handle->bn; ++mb2 ) {
          for ( ofm2 = 0; ofm2 < handle->bk; ++ofm2 ) {
#ifdef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID
            libxsmm_bfloat16_hp t;
#endif
            libxsmm_bfloat16 l_cur_out = LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk);
#ifdef LIBXSMM_DNN_FC_FWD_FUSE_RELU
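            /* ReLU on raw bf16 bits: testing bit 0x8000 (the sign bit) zeroes all
               negative values without an upconvert */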
            LIBXSMM_VLA_ACCESS(4, relumask, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk) = (unsigned char)(( (l_cur_out & 0x8000) > 0 ) ? 0 : 1);
            l_cur_out = (libxsmm_bfloat16)(( (l_cur_out & 0x8000) > 0 ) ? 0 : l_cur_out);
#endif
#ifdef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID
            /* upconvert bf16 to f32 by placing its bits in the high half of the union,
               then apply sigmoid(x) = (tanh(x/2) + 1) / 2 via the Pade 7/8 tanh approximation */
            t.i[1] = l_cur_out;
            t.i[0] = 0;
            t.f = (libxsmm_stanh_pade78( t.f / 2.0f ) + 1.0f) / 2.0f;
            l_cur_out = t.i[1];
#endif
            LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk) = l_cur_out;
          }
        }
      }
#endif
    }
  }
}

libxsmm_barrier_wait(handle->barrier, ltid);