1 /****************************************************************************** 2 * Copyright (c) Intel Corporation - All rights reserved. * 3 * This file is part of the LIBXSMM library. * 4 * * 5 * For information on the license, see the LICENSE file. * 6 * Further information: https://github.com/hfp/libxsmm/ * 7 * SPDX-License-Identifier: BSD-3-Clause * 8 ******************************************************************************/ 9 /* Alexander Heinecke, Sasikanth Avancha (Intel Corp.) 10 ******************************************************************************/ 11 12 #if defined(LIBXSMM_DNN_POOLING_BWD_BF16) 13 # define _mm512_load_act(A) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepi16_epi32(_mm256_loadu_si256((__m256i*)(A))),16)) 14 #if 1 15 # define _mm512_roundbf16rne(A) LIBXSMM_INTRINSICS_MM512_ROUNDNE_BF16(A) 16 # define _mm512_stream_act(A,B) _mm256_stream_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_roundbf16rne((B)),16))) 17 # define _mm512_store_act(A,B) _mm256_storeu_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_roundbf16rne((B)),16))) 18 #else 19 # define _mm512_stream_act(A,B) _mm256_stream_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_castps_si512((B)),16))) 20 # define _mm512_store_act(A,B) _mm256_storeu_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_castps_si512((B)),16))) 21 #endif 22 #else 23 # define _mm512_load_act(A) _mm512_loadu_ps(A) 24 # define _mm512_stream_act(A,B) LIBXSMM_INTRINSICS_MM512_STREAM_PS(A,B) 25 # define _mm512_store_act(A,B) _mm512_storeu_ps(A,B) 26 #endif 27 28 /* size variables, all const */ 29 const int nImg = handle->desc.N; 30 const int ifh = handle->desc.H; 31 const int ifw = handle->desc.W; 32 #if defined(LIBXSMM_DNN_POOLING_BWD_AVG) 33 const int sh = handle->desc.u; 34 const int sw = handle->desc.v; 35 #endif 36 const int ofh = handle->ofh; 37 const int ofw = handle->ofw; 38 const int iph = handle->desc.pad_h_in; 39 const int ipw = handle->desc.pad_w_in; 40 const int oph = handle->desc.pad_h_out; 41 const int opw = handle->desc.pad_w_out; 42 const int ofhp = ofh + 2*oph; 43 const int ofwp = ofw + 2*opw; 44 const int ifhp = ifh + 2*iph; 45 const int ifwp = ifw + 2*ipw; 46 /* here we assume that input and output blocking is similar */ 47 const int nBlocksFm = handle->blocksifm; 48 49 /* computing first logical thread */ 50 const int ltid = tid - start_thread; 51 /* number of tasks that could be run in parallel */ 52 const int work = nImg * nBlocksFm; 53 /* compute chunk size */ 54 const int chunksize = (work % handle->desc.threads == 0) ? (work / handle->desc.threads) : ((work / handle->desc.threads) + 1); 55 /* compute thr_begin and thr_end */ 56 const int thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work; 57 const int thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work; 58 59 /* loop variables */ 60 int img = 0; 61 int fm = 0; 62 int imgfm = 0; 63 int ho = 0; 64 int wo = 0; 65 int hi = 0; 66 int wi = 0; 67 int v = 0; 68 #if defined(LIBXSMM_DNN_POOLING_BWD_AVG) 69 int kh = 0; 70 int kw = 0; 71 #if defined(LIBXSMM_DNN_POOLING_BWD_BF16) 72 float recp_pool_size = 1.0f/((float)handle->desc.R*(float)handle->desc.S); 73 #else 74 element_input_type recp_pool_size = 1.0f/((element_input_type)handle->desc.R*(element_input_type)handle->desc.S); 75 #endif 76 #endif 77 78 /* multi-dim arrays declaration */ 79 #if defined(LIBXSMM_DNN_POOLING_BWD_BF16) 80 float* lcl_buffer_ptr = ((float*)handle->scratch)+((size_t)ifh*(size_t)ifw*(size_t)16*(size_t)ltid); 81 LIBXSMM_VLA_DECL(3, float, lcl_dinput, lcl_buffer_ptr, ifw, 16); 82 #else 83 element_output_type* lcl_buffer_ptr = ((element_input_type*)handle->scratch)+((size_t)ifh*(size_t)ifw*(size_t)16*(size_t)ltid); 84 LIBXSMM_VLA_DECL(3, element_input_type, lcl_dinput, lcl_buffer_ptr, ifw, 16); 85 #endif 86 LIBXSMM_VLA_DECL(5, element_input_type, dinput, (element_input_type* )handle->grad_input->data, nBlocksFm, ifhp, ifwp, 16); 87 LIBXSMM_VLA_DECL(5, const element_output_type, doutput, (element_output_type*)handle->grad_output->data, nBlocksFm, ofhp, ofwp, 16); 88 #if defined(LIBXSMM_DNN_POOLING_BWD_MAX) 89 LIBXSMM_VLA_DECL(5, const element_mask_type, mask, (element_mask_type* )handle->mask->data, nBlocksFm, ofh, ofw, 16); 90 #endif 91 92 /* lazy barrier init */ 93 libxsmm_barrier_init(handle->barrier, ltid); 94 95 for (imgfm = thr_begin; imgfm < thr_end; ++imgfm) { 96 img = imgfm / nBlocksFm; 97 fm = imgfm % nBlocksFm; 98 99 for ( v = 0; v < ifh*ifw*16; v += 16 ) { 100 _mm512_storeu_ps( &(lcl_buffer_ptr[v]), _mm512_setzero_ps() ); 101 } 102 103 #if defined(LIBXSMM_DNN_POOLING_BWD_MAX) 104 for ( ho = oph; ho < (ofh+oph); ho++ ) { 105 for ( wo = opw; wo < (ofw+opw); wo++ ) { 106 const element_output_type* doutput_ptr = &LIBXSMM_VLA_ACCESS(5, doutput, img, fm, ho, wo, 0, nBlocksFm, ofhp, ofwp, 16); 107 const element_mask_type* mask_ptr = &LIBXSMM_VLA_ACCESS(5, mask, img, fm, ho-oph, wo-opw, 0, nBlocksFm, ofh, ofw, 16); 108 109 __m512 lcl_vdinput = _mm512_i32gather_ps( _mm512_loadu_si512( mask_ptr ), lcl_buffer_ptr, 4 ); 110 lcl_vdinput = _mm512_add_ps( lcl_vdinput, _mm512_load_act( doutput_ptr ) ); 111 _mm512_i32scatter_ps( lcl_buffer_ptr, _mm512_loadu_si512( mask_ptr ), lcl_vdinput, 4 ); 112 } 113 } 114 #endif 115 #if defined(LIBXSMM_DNN_POOLING_BWD_AVG) 116 for ( ho = oph; ho < (ofh+oph); ho++ ) { 117 hi = ((ho-oph) * sh) - handle->desc.pad_h; 118 for ( wo = opw; wo < (ofw+opw); wo++ ) { 119 wi = ((wo-opw) * sw) - handle->desc.pad_w; 120 for ( kh = 0; kh < handle->desc.R; kh++ ) { 121 if (hi+kh < 0 || hi+kh >= ifh) continue; 122 for ( kw = 0; kw < handle->desc.S; kw++ ) { 123 if (wi+kw < 0 || wi+kw >= ifw) { 124 continue; 125 } else { 126 const element_output_type* doutput_ptr = &LIBXSMM_VLA_ACCESS(5, doutput, img, fm, ho, wo, 0, nBlocksFm, ofhp, ofwp, 16); 127 float* lcl_dinput_ptr = &LIBXSMM_VLA_ACCESS(3, lcl_dinput, hi+kh, wi+kw, 0, ifw, 16); 128 const __m512 recp_pool_size_ps = _mm512_set1_ps( recp_pool_size ); 129 const __m512 lcl_dinput_ps = _mm512_loadu_ps( lcl_dinput_ptr ); 130 _mm512_storeu_ps( lcl_dinput_ptr, _mm512_fmadd_ps( _mm512_load_act( doutput_ptr ), recp_pool_size_ps, lcl_dinput_ps ) ); 131 } 132 } 133 } 134 } 135 } 136 #endif 137 138 /* copy the local buffer into dinput activations */ 139 for ( hi = iph; hi < (ifh+iph); hi++ ) { 140 for ( wi = ipw; wi < (ifw+ipw); wi++ ) { 141 element_input_type* dinput_ptr = &LIBXSMM_VLA_ACCESS(5, dinput, img, fm, hi, wi, 0, nBlocksFm, ifhp, ifwp, 16); 142 float* lcl_dinput_ptr = &LIBXSMM_VLA_ACCESS(3, lcl_dinput, hi-iph, wi-ipw, 0, ifw, 16); 143 _mm512_stream_act( dinput_ptr, _mm512_loadu_ps( lcl_dinput_ptr ) ); 144 } 145 } 146 } 147 148 libxsmm_barrier_wait(handle->barrier, ltid); 149 150 # undef _mm512_load_act 151 # undef _mm512_stream_act 152 # undef _mm512_store_act 153 154