/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved.                      *
* This file is part of the LIBXSMM library.                                   *
*                                                                             *
* For information on the license, see the LICENSE file.                      *
* Further information: https://github.com/hfp/libxsmm/                        *
* SPDX-License-Identifier: BSD-3-Clause                                       *
******************************************************************************/
/* Evangelos Georganas, Alexander Heinecke (Intel Corp.)
******************************************************************************/

/* Forward pass of a bf16 fully-connected layer, executed by one thread of a
 * thread team.
 *
 * NOTE(review): this file is a template fragment — `handle`, `tid`,
 * `start_thread`, the `element_output_type`/`element_input_type`/
 * `element_filter_type` typedefs, the `batchreduce_kernel*` kernels, and the
 * LIBXSMM_DNN_FC_FWD_FUSE_* macros are all provided by the including
 * translation unit; none of them is declared here.
 *
 * High-level flow (per thread):
 *   1. derive this thread's share of the (nBlocksMB x nBlocksOFm) block grid,
 *      either as a 2D team tile or as a flat chunk of fused mb1*ofm1 indices;
 *   2. run a batch-reduce GEMM over input-feature blocks, accumulating either
 *      into an f32 scratch tensor (when the IFM dimension is split, BF > 1)
 *      or directly into the bf16 output (BF == 1);
 *   3. apply the optionally fused bias/ReLU/sigmoid epilogue and downconvert
 *      f32 -> bf16 where needed.
 */

/* size variables, all const */
/* here we assume that input and output blocking is similar */
const int nBlocksIFm = handle->desc.C / handle->bc;
const int nBlocksOFm = handle->desc.K / handle->bk;
const int nBlocksMB = handle->desc.N / handle->bn;
/* bf16 filters are stored with pairs of rows interleaved (VNNI layout):
 * lpb is the "low precision blocking" factor, i.e. 2 bf16 values per dword */
int lpb = 2;
const int bc_lp = handle->bc/lpb;
/* const int bc = handle->bc;*/
int use_2d_blocking = handle->fwd_2d_blocking;

/* computing first logical thread */
const int ltid = tid - start_thread;
/* number of tasks that could be run in parallel */
const int work = nBlocksOFm * nBlocksMB;
/* compute chunk size (ceiling division so all work is covered) */
const int chunksize = (work % handle->desc.threads == 0) ? (work / handle->desc.threads) : ((work / handle->desc.threads) + 1);
/* compute thr_begin and thr_end (clamped to `work` so surplus threads idle) */
const int thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work;
const int thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work;

/* loop variables */
int mb1ofm1 = 0, mb1 = 0, ofm1 = 0, ifm1 = 0;
int im_tasks_per_thread = 0, in_tasks_per_thread = 0, my_in_start = 0, my_in_end = 0, my_im_start = 0, my_im_end = 0, my_row_id = 0, my_col_id = 0, row_teams = 0, column_teams = 0;
#ifndef LIBXSMM_DNN_FC_FWD_FUSE_NONE
int mb2 = 0, ofm2 = 0;
#endif
/* blocked tensor views:
 *   output [nBlocksMB][nBlocksOFm][bn][bk]  (bf16)
 *   input  [nBlocksMB][nBlocksIFm][bn][bc]  (bf16)
 *   filter [nBlocksOFm][nBlocksIFm][bc/2][bk][2]  (bf16, VNNI)            */
LIBXSMM_VLA_DECL(4, element_output_type, output, (element_output_type*)handle->reg_output->data, nBlocksOFm, handle->bn, handle->bk);
LIBXSMM_VLA_DECL(4, const element_input_type, input, (element_input_type* )handle->reg_input->data, nBlocksIFm, handle->bn, handle->bc);
LIBXSMM_VLA_DECL(5, const element_filter_type, filter, (element_filter_type*)handle->reg_filter->data, nBlocksIFm, bc_lp, handle->bk, lpb);
/* f32 accumulation scratch, same blocked shape as the output tensor */
float* temp_output = (float*)handle->scratch;
LIBXSMM_VLA_DECL(4, float, output_f32, (float*) temp_output, nBlocksOFm,handle->bn,handle->bk);
#ifndef LIBXSMM_DNN_FC_FWD_FUSE_NONE
#ifdef LIBXSMM_DNN_FC_FWD_FUSE_BIAS
LIBXSMM_VLA_DECL(2, const element_input_type, bias, (element_input_type*) handle->reg_bias->data, handle->bk);
#endif
#ifdef LIBXSMM_DNN_FC_FWD_FUSE_RELU
/* NOTE(review): relumask (byte per element) and relubitmask (bit per element)
 * are two views of the SAME buffer (handle->relumask->data); which one is
 * written depends on the bk%32 dispatch below. */
LIBXSMM_VLA_DECL(4, unsigned char, relumask, (unsigned char*)handle->relumask->data, nBlocksOFm, handle->bn, handle->bk);
LIBXSMM_VLA_DECL(4, __mmask16, relubitmask, (__mmask16*)handle->relumask->data, nBlocksOFm, handle->bn, handle->bk/16);
#endif
#endif
unsigned long long blocks = nBlocksIFm;
int CB_BLOCKS = nBlocksIFm, BF = 1;

/* BF: blocking factor over the IFM dimension; each BRGEMM call reduces
 * CB_BLOCKS = nBlocksIFm/BF input-feature blocks */
BF = handle->fwd_bf;
CB_BLOCKS = nBlocksIFm/BF;
blocks = CB_BLOCKS;

if (use_2d_blocking == 1) {
  /* 2D work decomposition: the thread grid is row_teams x column_teams,
   * rows tile the minibatch blocks, columns tile the OFM blocks */
  row_teams = handle->fwd_row_teams;
  column_teams = handle->fwd_column_teams;
  my_col_id = ltid % column_teams;
  my_row_id = ltid / column_teams;
  im_tasks_per_thread = LIBXSMM_UPDIV(nBlocksMB, row_teams);
  in_tasks_per_thread = LIBXSMM_UPDIV(nBlocksOFm, column_teams);
  my_im_start = LIBXSMM_MIN(my_row_id * im_tasks_per_thread, nBlocksMB);
  my_im_end = LIBXSMM_MIN((my_row_id+1) * im_tasks_per_thread, nBlocksMB);
  my_in_start = LIBXSMM_MIN(my_col_id * in_tasks_per_thread, nBlocksOFm);
  my_in_end = LIBXSMM_MIN((my_col_id+1) * in_tasks_per_thread, nBlocksOFm);
}

/* lazy barrier init */
libxsmm_barrier_init(handle->barrier, ltid);

if (use_2d_blocking == 1) {
  if (BF > 1) {
    /* 2D blocking, split IFM reduction: accumulate into f32 scratch over BF
     * steps, fuse the epilogue and downconvert on the last step */
    for ( ifm1 = 0; ifm1 < BF; ++ifm1 ) {
      for (ofm1 = my_in_start; ofm1 < my_in_end; ++ofm1) {
        for (mb1 = my_im_start; mb1 < my_im_end; ++mb1) {
          /* Initialize intermediate f32 tensor (bias row broadcast, or zero) */
          if ( ifm1 == 0 ) {
#ifdef LIBXSMM_DNN_FC_FWD_FUSE_BIAS
            for ( mb2 = 0; mb2 <handle->bn; ++mb2 ) {
              LIBXSMM_DNN_CONVERT_BUFFER_BF16_F32( &LIBXSMM_VLA_ACCESS(2, bias, ofm1, 0,handle->bk), &LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, mb2, 0, nBlocksOFm,handle->bn,handle->bk), handle->bk );
            }
#else
            memset(&LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, 0, 0, nBlocksOFm, handle->bn, handle->bk), 0, handle->bn*handle->bk*sizeof(float));
#endif
          }
          /* batch-reduce GEMM over CB_BLOCKS IFM blocks, beta=1 into scratch */
          batchreduce_kernel( &LIBXSMM_VLA_ACCESS(5, filter, ofm1, ifm1*CB_BLOCKS, 0, 0, 0, nBlocksIFm, bc_lp, handle->bk, lpb),
                              &LIBXSMM_VLA_ACCESS(4, input, mb1, ifm1*CB_BLOCKS, 0, 0, nBlocksIFm, handle->bn, handle->bc),
                              &LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, 0, 0, nBlocksOFm, handle->bn, handle->bk), &blocks);
          /* downconvert intermediate f32 tensor to bf16 and store to final C */
          if ( ifm1 == BF-1 ) {
#ifndef LIBXSMM_DNN_FC_FWD_FUSE_NONE
            if (handle->bk % 32 == 0) {
              /* AVX-512 epilogue: 32 f32 outputs (two zmm registers) per step */
              __m512 cur_out_0 = _mm512_setzero_ps();
              __m512 cur_out_1 = _mm512_setzero_ps();
#ifdef LIBXSMM_DNN_FC_FWD_FUSE_RELU
              __mmask16 relumask0;
              __mmask16 relumask1;
#endif
#ifdef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID
              __m512 ones = _mm512_set1_ps(1.0);
              __m512 halves = _mm512_set1_ps(0.5);
#endif
              for ( mb2 = 0; mb2 < handle->bn; ++mb2 ) {
                for ( ofm2 = 0; ofm2 < handle->bk; ofm2 += 32 ) {
                  cur_out_0 = LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk));
                  cur_out_1 = LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, mb2, ofm2+16, nBlocksOFm, handle->bn, handle->bk));
#ifdef LIBXSMM_DNN_FC_FWD_FUSE_RELU
                  /* mask = (x > 0); zero masked-off lanes and persist the bitmask */
                  relumask0 = _mm512_cmp_ps_mask( cur_out_0, _mm512_setzero_ps(), _CMP_GT_OQ );
                  relumask1 = _mm512_cmp_ps_mask( cur_out_1, _mm512_setzero_ps(), _CMP_GT_OQ );
                  cur_out_0 = _mm512_mask_blend_ps( relumask0, _mm512_setzero_ps(), cur_out_0 );
                  cur_out_1 = _mm512_mask_blend_ps( relumask1, _mm512_setzero_ps(), cur_out_1 );
                  LIBXSMM_INTRINSICS_MM512_STORE_MASK16( &LIBXSMM_VLA_ACCESS(4, relubitmask, mb1, ofm1, mb2, ofm2/16, nBlocksOFm, handle->bn, handle->bk/16), relumask0 );
                  LIBXSMM_INTRINSICS_MM512_STORE_MASK16( &LIBXSMM_VLA_ACCESS(4, relubitmask, mb1, ofm1, mb2, ofm2/16+1, nBlocksOFm, handle->bn, handle->bk/16), relumask1 );
#endif
#ifdef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID
                  /* sigmoid(x) = (tanh(x/2) + 1) / 2, using a Pade 7/8 tanh approximation */
                  cur_out_0 = _mm512_mul_ps(_mm512_add_ps(LIBXSMM_INTRINSICS_MM512_TANH_PS_RATIONAL_78(_mm512_mul_ps(cur_out_0, halves)), ones), halves);
                  cur_out_1 = _mm512_mul_ps(_mm512_add_ps(LIBXSMM_INTRINSICS_MM512_TANH_PS_RATIONAL_78(_mm512_mul_ps(cur_out_1, halves)), ones), halves);
#endif
                  /* write the activated values back to the f32 scratch; the
                   * bf16 conversion happens in one pass below */
                  _mm512_storeu_ps(&LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk), cur_out_0);
                  _mm512_storeu_ps(&LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, mb2, ofm2+16, nBlocksOFm, handle->bn, handle->bk), cur_out_1);
                }
              }
            } else {
              /* scalar fallback for bk not divisible by 32 (byte-mask variant) */
              for ( mb2 = 0; mb2 < handle->bn; ++mb2 ) {
                for ( ofm2 = 0; ofm2 < handle->bk; ++ofm2 ) {
                  float l_cur_out = LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk);
#ifdef LIBXSMM_DNN_FC_FWD_FUSE_RELU
                  LIBXSMM_VLA_ACCESS(4, relumask, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk) = (unsigned char)(( l_cur_out > (float)0 ) ? 1 : 0);
                  l_cur_out = (l_cur_out > (float)0) ? l_cur_out : (float)0;
#endif
#ifdef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID
                  /* sigmoid(x) = (tanh(x/2) + 1) / 2, Pade 7/8 tanh approximation */
                  l_cur_out = (libxsmm_stanh_pade78( l_cur_out / 2.0f ) + 1.0f) / 2.0f;
#endif
                  LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk) = l_cur_out;
                }
              }
            }
#endif
            /* single-pass f32 -> bf16 downconvert of the finished block */
            LIBXSMM_DNN_CONVERT_BUFFER_F32_BF16(&LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, 0, 0, nBlocksOFm,handle->bn,handle->bk), &LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, 0, 0, nBlocksOFm,handle->bn,handle->bk),handle->bn*handle->bk);
          }
        }
      }
    }
  } else {
    /* 2D blocking, single-shot IFM reduction: GEMM writes bf16 output
     * directly (beta=1 after bias init, or beta=0), epilogue runs on bf16 */
    for (ofm1 = my_in_start; ofm1 < my_in_end; ++ofm1) {
      for (mb1 = my_im_start; mb1 < my_im_end; ++mb1) {
#ifdef LIBXSMM_DNN_FC_FWD_FUSE_BIAS
        /* seed the output block with the bias row, then accumulate (beta=1) */
        for ( mb2 = 0; mb2 < handle->bn; ++mb2 ) {
          for ( ofm2 = 0; ofm2 < handle->bk; ++ofm2 ) {
            LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk) = LIBXSMM_VLA_ACCESS(2, bias, ofm1, ofm2, handle->bk);
          }
        }
        batchreduce_kernel_beta( &LIBXSMM_VLA_ACCESS(5, filter, ofm1, 0, 0, 0, 0, nBlocksIFm, bc_lp, handle->bk, lpb),
                                 &LIBXSMM_VLA_ACCESS(4, input, mb1, 0, 0, 0, nBlocksIFm, handle->bn, handle->bc),
                                 &LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, 0, 0, nBlocksOFm, handle->bn, handle->bk), &blocks);
#else
        batchreduce_kernel_zerobeta( &LIBXSMM_VLA_ACCESS(5, filter, ofm1, 0, 0, 0, 0, nBlocksIFm, bc_lp, handle->bk, lpb),
                                     &LIBXSMM_VLA_ACCESS(4, input, mb1, 0, 0, 0, nBlocksIFm, handle->bn, handle->bc),
                                     &LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, 0, 0, nBlocksOFm, handle->bn, handle->bk), &blocks);
#endif
#ifndef LIBXSMM_DNN_FC_FWD_FUSE_NONE
        if (handle->bk % 32 == 0) {
          /* AVX-512 epilogue on the bf16 output: upconvert 2x16 bf16 -> f32,
           * apply the fused op(s), downconvert back in one 512-bit store */
          __m512 cur_out_0 = _mm512_setzero_ps();
          __m512 cur_out_1 = _mm512_setzero_ps();
#ifdef LIBXSMM_DNN_FC_FWD_FUSE_RELU
          __mmask16 relumask0;
          __mmask16 relumask1;
#endif
#ifdef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID
          __m512 ones = _mm512_set1_ps(1.0);
          __m512 halves = _mm512_set1_ps(0.5);
#endif
          for ( mb2 = 0; mb2 < handle->bn; ++mb2 ) {
            for ( ofm2 = 0; ofm2 < handle->bk; ofm2 += 32 ) {
              cur_out_0 = LIBXSMM_INTRINSICS_MM512_CVTPBH_PS(_mm256_loadu_si256((__m256i*)&LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk)));
              cur_out_1 = LIBXSMM_INTRINSICS_MM512_CVTPBH_PS(_mm256_loadu_si256((__m256i*)&LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, mb2, ofm2+16, nBlocksOFm, handle->bn, handle->bk)));
#ifdef LIBXSMM_DNN_FC_FWD_FUSE_RELU
              relumask0 = _mm512_cmp_ps_mask( cur_out_0, _mm512_setzero_ps(), _CMP_GT_OQ );
              relumask1 = _mm512_cmp_ps_mask( cur_out_1, _mm512_setzero_ps(), _CMP_GT_OQ );
              cur_out_0 = _mm512_mask_blend_ps( relumask0, _mm512_setzero_ps(), cur_out_0 );
              cur_out_1 = _mm512_mask_blend_ps( relumask1, _mm512_setzero_ps(), cur_out_1 );
              LIBXSMM_INTRINSICS_MM512_STORE_MASK16( &LIBXSMM_VLA_ACCESS(4, relubitmask, mb1, ofm1, mb2, ofm2/16, nBlocksOFm, handle->bn, handle->bk/16), relumask0 );
              LIBXSMM_INTRINSICS_MM512_STORE_MASK16( &LIBXSMM_VLA_ACCESS(4, relubitmask, mb1, ofm1, mb2, ofm2/16+1, nBlocksOFm, handle->bn, handle->bk/16), relumask1 );
#endif
#ifdef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID
              /* sigmoid(x) = (tanh(x/2) + 1) / 2, Pade 7/8 tanh approximation */
              cur_out_0 = _mm512_mul_ps(_mm512_add_ps(LIBXSMM_INTRINSICS_MM512_TANH_PS_RATIONAL_78(_mm512_mul_ps(cur_out_0, halves)), ones), halves);
              cur_out_1 = _mm512_mul_ps(_mm512_add_ps(LIBXSMM_INTRINSICS_MM512_TANH_PS_RATIONAL_78(_mm512_mul_ps(cur_out_1, halves)), ones), halves);
#endif
              /* pack both f32 vectors back to 32 bf16 values (macro name is
               * the project's spelling, sic) */
              _mm512_storeu_si512(&LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk), LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH( cur_out_1, cur_out_0 ));
            }
          }
        } else {
          /* scalar bf16 fallback: ReLU via sign-bit test, sigmoid via
           * f32 round-trip through the libxsmm_bfloat16_hp union */
          for ( mb2 = 0; mb2 < handle->bn; ++mb2 ) {
            for ( ofm2 = 0; ofm2 < handle->bk; ++ofm2 ) {
#ifdef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID
              libxsmm_bfloat16_hp t;
#endif
              libxsmm_bfloat16 l_cur_out = LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk);
#ifdef LIBXSMM_DNN_FC_FWD_FUSE_RELU
              /* NOTE(review): this path keeps values with a clear sign bit,
               * so +0 gets mask=1 here while the f32 paths use strict > 0
               * (mask=0 for zero) — confirm this asymmetry is intended */
              LIBXSMM_VLA_ACCESS(4, relumask, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk) = (unsigned char)(( (l_cur_out & 0x8000) > 0 ) ? 0 : 1);
              l_cur_out = (libxsmm_bfloat16)(( (l_cur_out & 0x8000) > 0 ) ? 0 : l_cur_out);
#endif
#ifdef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID
              /* sigmoid(x) = (tanh(x/2) + 1) / 2, Pade 7/8 tanh approximation;
               * bf16 occupies the high 16 bits of the f32 representation */
              t.i[1] = l_cur_out;
              t.i[0] = 0;
              t.f = (libxsmm_stanh_pade78( t.f / 2.0f ) + 1.0f) / 2.0f;
              l_cur_out = t.i[1];
#endif
              LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk) = l_cur_out;
            }
          }
        }
#endif
      }
    }
  }
} else {
  if (BF > 1) {
    /* flat 1D decomposition over fused (mb1, ofm1) index, split IFM
     * reduction through the f32 scratch tensor */
    for ( ifm1 = 0; ifm1 < BF; ++ifm1 ) {
      for ( mb1ofm1 = thr_begin; mb1ofm1 < thr_end; ++mb1ofm1 ) {
        mb1 = mb1ofm1%nBlocksMB;
        ofm1 = mb1ofm1/nBlocksMB;
        /* Initialize intermediate f32 tensor (bias row broadcast, or zero) */
        if ( ifm1 == 0 ) {
#ifdef LIBXSMM_DNN_FC_FWD_FUSE_BIAS
          for ( mb2 = 0; mb2 <handle->bn; ++mb2 ) {
            LIBXSMM_DNN_CONVERT_BUFFER_BF16_F32( &LIBXSMM_VLA_ACCESS(2, bias, ofm1, 0,handle->bk), &LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, mb2, 0, nBlocksOFm, handle->bn, handle->bk), handle->bk );
          }
#else
          memset(&LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, 0, 0, nBlocksOFm, handle->bn, handle->bk), 0, handle->bn*handle->bk*sizeof(float));
#endif
        }
        /* batch-reduce GEMM over CB_BLOCKS IFM blocks, beta=1 into scratch */
        batchreduce_kernel( &LIBXSMM_VLA_ACCESS(5, filter, ofm1, ifm1*CB_BLOCKS, 0, 0, 0, nBlocksIFm, bc_lp, handle->bk, lpb),
                            &LIBXSMM_VLA_ACCESS(4, input, mb1, ifm1*CB_BLOCKS, 0, 0, nBlocksIFm, handle->bn, handle->bc),
                            &LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, 0, 0, nBlocksOFm, handle->bn, handle->bk), &blocks);
        /* downconvert intermediate f32 tensor to bf16 and store to final C */
        if ( ifm1 == BF-1 ) {
#ifndef LIBXSMM_DNN_FC_FWD_FUSE_NONE
          if (handle->bk % 32 == 0) {
            /* AVX-512 epilogue: 32 f32 outputs (two zmm registers) per step */
            __m512 cur_out_0 = _mm512_setzero_ps();
            __m512 cur_out_1 = _mm512_setzero_ps();
#ifdef LIBXSMM_DNN_FC_FWD_FUSE_RELU
            __mmask16 relumask0;
            __mmask16 relumask1;
#endif
#ifdef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID
            __m512 ones = _mm512_set1_ps(1.0);
            __m512 halves = _mm512_set1_ps(0.5);
#endif
            for ( mb2 = 0; mb2 < handle->bn; ++mb2 ) {
              for ( ofm2 = 0; ofm2 < handle->bk; ofm2 += 32 ) {
                cur_out_0 = LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk));
                cur_out_1 = LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, mb2, ofm2+16, nBlocksOFm, handle->bn, handle->bk));
#ifdef LIBXSMM_DNN_FC_FWD_FUSE_RELU
                relumask0 = _mm512_cmp_ps_mask( cur_out_0, _mm512_setzero_ps(), _CMP_GT_OQ );
                relumask1 = _mm512_cmp_ps_mask( cur_out_1, _mm512_setzero_ps(), _CMP_GT_OQ );
                cur_out_0 = _mm512_mask_blend_ps( relumask0, _mm512_setzero_ps(), cur_out_0 );
                cur_out_1 = _mm512_mask_blend_ps( relumask1, _mm512_setzero_ps(), cur_out_1 );
                LIBXSMM_INTRINSICS_MM512_STORE_MASK16( &LIBXSMM_VLA_ACCESS(4, relubitmask, mb1, ofm1, mb2, ofm2/16, nBlocksOFm, handle->bn, handle->bk/16), relumask0 );
                LIBXSMM_INTRINSICS_MM512_STORE_MASK16( &LIBXSMM_VLA_ACCESS(4, relubitmask, mb1, ofm1, mb2, ofm2/16+1, nBlocksOFm, handle->bn, handle->bk/16), relumask1 );
#endif
#ifdef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID
                /* sigmoid(x) = (tanh(x/2) + 1) / 2, Pade 7/8 tanh approximation */
                cur_out_0 = _mm512_mul_ps(_mm512_add_ps(LIBXSMM_INTRINSICS_MM512_TANH_PS_RATIONAL_78(_mm512_mul_ps(cur_out_0, halves)), ones), halves);
                cur_out_1 = _mm512_mul_ps(_mm512_add_ps(LIBXSMM_INTRINSICS_MM512_TANH_PS_RATIONAL_78(_mm512_mul_ps(cur_out_1, halves)), ones), halves);
#endif
                _mm512_storeu_ps(&LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk), cur_out_0);
                _mm512_storeu_ps(&LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, mb2, ofm2+16, nBlocksOFm, handle->bn, handle->bk), cur_out_1);
              }
            }
          } else {
            /* scalar fallback for bk not divisible by 32 (byte-mask variant) */
            for ( mb2 = 0; mb2 < handle->bn; ++mb2 ) {
              for ( ofm2 = 0; ofm2 < handle->bk; ++ofm2 ) {
                float l_cur_out = LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk);
#ifdef LIBXSMM_DNN_FC_FWD_FUSE_RELU
                LIBXSMM_VLA_ACCESS(4, relumask, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk) = (unsigned char)(( l_cur_out > 0.0 ) ? 1 : 0);
                l_cur_out = (l_cur_out > (float)0) ? l_cur_out : (float)0;
#endif
#ifdef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID
                /* sigmoid(x) = (tanh(x/2) + 1) / 2, Pade 7/8 tanh approximation */
                l_cur_out = (libxsmm_stanh_pade78( l_cur_out / 2.0f ) + 1.0f) / 2.0f;
#endif
                LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk) = l_cur_out;
              }
            }
          }
#endif
          /* single-pass f32 -> bf16 downconvert of the finished block */
          LIBXSMM_DNN_CONVERT_BUFFER_F32_BF16(&LIBXSMM_VLA_ACCESS(4, output_f32, mb1, ofm1, 0, 0, nBlocksOFm, handle->bn, handle->bk), &LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, 0, 0, nBlocksOFm, handle->bn, handle->bk), handle->bn*handle->bk);
        }
      }
    }
  } else {
    /* flat 1D decomposition, single-shot IFM reduction: GEMM writes bf16
     * output directly, epilogue runs on bf16 */
    for ( mb1ofm1 = thr_begin; mb1ofm1 < thr_end; ++mb1ofm1 ) {
      mb1 = mb1ofm1%nBlocksMB;
      ofm1 = mb1ofm1/nBlocksMB;
#ifdef LIBXSMM_DNN_FC_FWD_FUSE_BIAS
      /* seed the output block with the bias row, then accumulate (beta=1) */
      for ( mb2 = 0; mb2 <handle->bn; ++mb2 ) {
        for ( ofm2 = 0; ofm2 <handle->bk; ++ofm2 ) {
          LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk) = LIBXSMM_VLA_ACCESS(2, bias, ofm1, ofm2, handle->bk);
        }
      }
      batchreduce_kernel_beta( &LIBXSMM_VLA_ACCESS(5, filter, ofm1, 0, 0, 0, 0, nBlocksIFm, bc_lp, handle->bk, lpb),
                               &LIBXSMM_VLA_ACCESS(4, input, mb1, 0, 0, 0, nBlocksIFm, handle->bn, handle->bc),
                               &LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, 0, 0, nBlocksOFm, handle->bn, handle->bk), &blocks);
#else
      batchreduce_kernel_zerobeta( &LIBXSMM_VLA_ACCESS(5, filter, ofm1, 0, 0, 0, 0, nBlocksIFm, bc_lp, handle->bk, lpb),
                                   &LIBXSMM_VLA_ACCESS(4, input, mb1, 0, 0, 0, nBlocksIFm, handle->bn, handle->bc),
                                   &LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, 0, 0, nBlocksOFm, handle->bn, handle->bk), &blocks);
#endif
#ifndef LIBXSMM_DNN_FC_FWD_FUSE_NONE
      if (handle->bk % 32 == 0) {
        /* AVX-512 epilogue on the bf16 output: upconvert 2x16 bf16 -> f32,
         * apply the fused op(s), downconvert back in one 512-bit store */
        __m512 cur_out_0 = _mm512_setzero_ps();
        __m512 cur_out_1 = _mm512_setzero_ps();
#ifdef LIBXSMM_DNN_FC_FWD_FUSE_RELU
        __mmask16 relumask0;
        __mmask16 relumask1;
#endif
#ifdef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID
        __m512 ones = _mm512_set1_ps(1.0);
        __m512 halves = _mm512_set1_ps(0.5);
#endif
        for ( mb2 = 0; mb2 < handle->bn; ++mb2 ) {
          for ( ofm2 = 0; ofm2 < handle->bk; ofm2 += 32 ) {
            cur_out_0 = LIBXSMM_INTRINSICS_MM512_CVTPBH_PS(_mm256_loadu_si256((__m256i*)&LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk)));
            cur_out_1 = LIBXSMM_INTRINSICS_MM512_CVTPBH_PS(_mm256_loadu_si256((__m256i*)&LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, mb2, ofm2+16, nBlocksOFm, handle->bn, handle->bk)));
#ifdef LIBXSMM_DNN_FC_FWD_FUSE_RELU
            relumask0 = _mm512_cmp_ps_mask( cur_out_0, _mm512_setzero_ps(), _CMP_GT_OQ );
            relumask1 = _mm512_cmp_ps_mask( cur_out_1, _mm512_setzero_ps(), _CMP_GT_OQ );
            cur_out_0 = _mm512_mask_blend_ps( relumask0, _mm512_setzero_ps(), cur_out_0 );
            cur_out_1 = _mm512_mask_blend_ps( relumask1, _mm512_setzero_ps(), cur_out_1 );
            LIBXSMM_INTRINSICS_MM512_STORE_MASK16( &LIBXSMM_VLA_ACCESS(4, relubitmask, mb1, ofm1, mb2, ofm2/16, nBlocksOFm, handle->bn, handle->bk/16), relumask0 );
            LIBXSMM_INTRINSICS_MM512_STORE_MASK16( &LIBXSMM_VLA_ACCESS(4, relubitmask, mb1, ofm1, mb2, ofm2/16+1, nBlocksOFm, handle->bn, handle->bk/16), relumask1 );
#endif
#ifdef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID
            /* sigmoid(x) = (tanh(x/2) + 1) / 2, Pade 7/8 tanh approximation */
            cur_out_0 = _mm512_mul_ps(_mm512_add_ps(LIBXSMM_INTRINSICS_MM512_TANH_PS_RATIONAL_78(_mm512_mul_ps(cur_out_0, halves)), ones), halves);
            cur_out_1 = _mm512_mul_ps(_mm512_add_ps(LIBXSMM_INTRINSICS_MM512_TANH_PS_RATIONAL_78(_mm512_mul_ps(cur_out_1, halves)), ones), halves);
#endif
            /* pack both f32 vectors back to 32 bf16 values (macro name is
             * the project's spelling, sic) */
            _mm512_storeu_si512(&LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk), LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH( cur_out_1, cur_out_0 ));
          }
        }
      } else {
        /* scalar bf16 fallback: ReLU via sign-bit test, sigmoid via
         * f32 round-trip through the libxsmm_bfloat16_hp union */
        for ( mb2 = 0; mb2 < handle->bn; ++mb2 ) {
          for ( ofm2 = 0; ofm2 < handle->bk; ++ofm2 ) {
#ifdef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID
            libxsmm_bfloat16_hp t;
#endif
            libxsmm_bfloat16 l_cur_out = LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk);
#ifdef LIBXSMM_DNN_FC_FWD_FUSE_RELU
            /* NOTE(review): sign-bit test — +0 yields mask=1 here, unlike the
             * strict > 0 comparison in the f32 paths; confirm intended */
            LIBXSMM_VLA_ACCESS(4, relumask, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk) = (unsigned char)(( (l_cur_out & 0x8000) > 0 ) ? 0 : 1);
            l_cur_out = (libxsmm_bfloat16)(( (l_cur_out & 0x8000) > 0 ) ? 0 : l_cur_out);
#endif
#ifdef LIBXSMM_DNN_FC_FWD_FUSE_SIGMOID
            /* sigmoid(x) = (tanh(x/2) + 1) / 2, Pade 7/8 tanh approximation;
             * bf16 occupies the high 16 bits of the f32 representation */
            t.i[1] = l_cur_out;
            t.i[0] = 0;
            t.f = (libxsmm_stanh_pade78( t.f / 2.0f ) + 1.0f) / 2.0f;
            l_cur_out = t.i[1];
#endif
            LIBXSMM_VLA_ACCESS(4, output, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk) = l_cur_out;
          }
        }
      }

#endif
    }
  }
}

/* all threads synchronize before the layer is considered complete */
libxsmm_barrier_wait(handle->barrier, ltid);