1 /******************************************************************************
2 * Copyright (c) Intel Corporation - All rights reserved.                      *
3 * This file is part of the LIBXSMM library.                                   *
4 *                                                                             *
5 * For information on the license, see the LICENSE file.                       *
6 * Further information: https://github.com/hfp/libxsmm/                        *
7 * SPDX-License-Identifier: BSD-3-Clause                                       *
8 ******************************************************************************/
9 /* Alexander Heinecke, Sasikanth Avancha (Intel Corp.)
10 ******************************************************************************/
11 
/*
 * Activation load/store helpers. The same names are defined for both
 * precisions so that the kernel body below is written only once:
 *   _mm512_load_act(A)      load 16 activations from A as 16 x FP32
 *   _mm512_stream_act(A,B)  non-temporal store of the 16 x FP32 vector B to A
 *   _mm512_store_act(A,B)   regular (unaligned) store of B to A
 */
#if defined(LIBXSMM_DNN_POOLING_BWD_BF16)
/* Load 16 x 16-bit values, widen every lane to 32 bits and shift it into the
 * upper half of the lane; the resulting bit pattern is reinterpreted as FP32.
 * The pointer is cast to a const type since the load never writes through it
 * (callers pass pointers to const data). */
# define _mm512_load_act(A)     _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepi16_epi32(_mm256_loadu_si256((const __m256i*)(A))),16))
#if 1
/* Default: apply LIBXSMM's BF16 rounding helper before keeping only the upper
 * 16 bits of each 32-bit lane. */
# define _mm512_roundbf16rne(A) LIBXSMM_INTRINSICS_MM512_ROUNDNE_BF16(A)
/* NOTE: _mm256_stream_si256 requires a 32-byte aligned destination. */
# define _mm512_stream_act(A,B) _mm256_stream_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_roundbf16rne((B)),16)))
# define _mm512_store_act(A,B)  _mm256_storeu_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_roundbf16rne((B)),16)))
#else
/* Disabled alternative: plain truncation, i.e. the lower 16 bits of every
 * FP32 lane are dropped without rounding. */
# define _mm512_stream_act(A,B) _mm256_stream_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_castps_si512((B)),16)))
# define _mm512_store_act(A,B)  _mm256_storeu_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_castps_si512((B)),16)))
#endif
#else
/* Non-BF16 build: activations are accessed directly as packed single precision. */
# define _mm512_load_act(A)     _mm512_loadu_ps(A)
# define _mm512_stream_act(A,B) LIBXSMM_INTRINSICS_MM512_STREAM_PS(A,B)
# define _mm512_store_act(A,B)  _mm512_storeu_ps(A,B)
#endif
27 
28 /* size variables, all const */
29 const int nImg = handle->desc.N;
30 const int ifh = handle->desc.H;
31 const int ifw = handle->desc.W;
32 #if defined(LIBXSMM_DNN_POOLING_BWD_AVG)
33 const int sh = handle->desc.u;
34 const int sw = handle->desc.v;
35 #endif
36 const int ofh = handle->ofh;
37 const int ofw = handle->ofw;
38 const int iph = handle->desc.pad_h_in;
39 const int ipw = handle->desc.pad_w_in;
40 const int oph = handle->desc.pad_h_out;
41 const int opw = handle->desc.pad_w_out;
42 const int ofhp = ofh + 2*oph;
43 const int ofwp = ofw + 2*opw;
44 const int ifhp = ifh + 2*iph;
45 const int ifwp = ifw + 2*ipw;
46 /* here we assume that input and output blocking is similar */
47 const int nBlocksFm = handle->blocksifm;
48 
49 /* computing first logical thread */
50 const int ltid = tid - start_thread;
51 /* number of tasks that could be run in parallel */
52 const int work = nImg * nBlocksFm;
53 /* compute chunk size */
54 const int chunksize = (work % handle->desc.threads == 0) ? (work / handle->desc.threads) : ((work / handle->desc.threads) + 1);
55 /* compute thr_begin and thr_end */
56 const int thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work;
57 const int thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work;
58 
/* loop counters */
int imgfm = 0, img = 0, fm = 0; /* flat task index and its (image, block) split */
int ho = 0, wo = 0;             /* output pixel */
int hi = 0, wi = 0;             /* input pixel */
int v = 0;                      /* flat index used to clear the local buffer */
#if defined(LIBXSMM_DNN_POOLING_BWD_AVG)
int kh = 0, kw = 0;             /* position inside the R x S pooling window */
/* reciprocal of the pooling window size (1 / (R * S)) */
# if defined(LIBXSMM_DNN_POOLING_BWD_BF16)
float recp_pool_size = 1.0f/((float)handle->desc.R*(float)handle->desc.S);
# else
element_input_type recp_pool_size = 1.0f/((element_input_type)handle->desc.R*(element_input_type)handle->desc.S);
# endif
#endif
77 
78 /* multi-dim arrays declaration */
79 #if defined(LIBXSMM_DNN_POOLING_BWD_BF16)
80 float* lcl_buffer_ptr = ((float*)handle->scratch)+((size_t)ifh*(size_t)ifw*(size_t)16*(size_t)ltid);
81 LIBXSMM_VLA_DECL(3,       float, lcl_dinput, lcl_buffer_ptr,                                                    ifw, 16);
82 #else
83 element_output_type* lcl_buffer_ptr = ((element_input_type*)handle->scratch)+((size_t)ifh*(size_t)ifw*(size_t)16*(size_t)ltid);
84 LIBXSMM_VLA_DECL(3,       element_input_type, lcl_dinput, lcl_buffer_ptr,                                                    ifw, 16);
85 #endif
86 LIBXSMM_VLA_DECL(5,       element_input_type,     dinput, (element_input_type* )handle->grad_input->data,  nBlocksFm, ifhp, ifwp, 16);
87 LIBXSMM_VLA_DECL(5, const element_output_type,   doutput, (element_output_type*)handle->grad_output->data, nBlocksFm, ofhp, ofwp, 16);
88 #if defined(LIBXSMM_DNN_POOLING_BWD_MAX)
89 LIBXSMM_VLA_DECL(5, const  element_mask_type,        mask, (element_mask_type* )handle->mask->data,        nBlocksFm,  ofh,  ofw, 16);
90 #endif
91 
92 /* lazy barrier init */
93 libxsmm_barrier_init(handle->barrier, ltid);
94 
95 for (imgfm = thr_begin; imgfm < thr_end; ++imgfm) {
96   img = imgfm / nBlocksFm;
97   fm = imgfm % nBlocksFm;
98 
99   for ( v = 0; v < ifh*ifw*16; v += 16 ) {
100     _mm512_storeu_ps( &(lcl_buffer_ptr[v]), _mm512_setzero_ps() );
101   }
102 
103 #if defined(LIBXSMM_DNN_POOLING_BWD_MAX)
104   for ( ho = oph; ho < (ofh+oph); ho++ ) {
105     for ( wo = opw; wo < (ofw+opw); wo++ ) {
106       const element_output_type* doutput_ptr = &LIBXSMM_VLA_ACCESS(5, doutput, img, fm,     ho,     wo, 0, nBlocksFm, ofhp, ofwp, 16);
107       const element_mask_type*      mask_ptr = &LIBXSMM_VLA_ACCESS(5, mask,    img, fm, ho-oph, wo-opw, 0, nBlocksFm,  ofh,  ofw, 16);
108 
109       __m512 lcl_vdinput = _mm512_i32gather_ps( _mm512_loadu_si512( mask_ptr ), lcl_buffer_ptr, 4 );
110       lcl_vdinput = _mm512_add_ps( lcl_vdinput, _mm512_load_act( doutput_ptr ) );
111       _mm512_i32scatter_ps( lcl_buffer_ptr, _mm512_loadu_si512( mask_ptr ), lcl_vdinput, 4 );
112     }
113   }
114 #endif
115 #if defined(LIBXSMM_DNN_POOLING_BWD_AVG)
116   for ( ho = oph; ho < (ofh+oph); ho++ ) {
117     hi = ((ho-oph) * sh) - handle->desc.pad_h;
118     for ( wo = opw; wo < (ofw+opw); wo++ ) {
119       wi = ((wo-opw) * sw) - handle->desc.pad_w;
120       for ( kh = 0; kh < handle->desc.R; kh++ ) {
121         if (hi+kh < 0 || hi+kh >= ifh) continue;
122         for ( kw = 0; kw < handle->desc.S; kw++ ) {
123           if (wi+kw < 0 || wi+kw >= ifw) {
124             continue;
125           } else {
126             const element_output_type*   doutput_ptr = &LIBXSMM_VLA_ACCESS(5, doutput,    img, fm,    ho,    wo, 0, nBlocksFm, ofhp, ofwp, 16);
127                   float*              lcl_dinput_ptr = &LIBXSMM_VLA_ACCESS(3, lcl_dinput,          hi+kh, wi+kw, 0,                   ifw, 16);
128             const __m512 recp_pool_size_ps = _mm512_set1_ps( recp_pool_size );
129             const __m512 lcl_dinput_ps = _mm512_loadu_ps( lcl_dinput_ptr );
130             _mm512_storeu_ps( lcl_dinput_ptr, _mm512_fmadd_ps( _mm512_load_act( doutput_ptr ), recp_pool_size_ps, lcl_dinput_ps ) );
131           }
132         }
133       }
134     }
135   }
136 #endif
137 
138   /* copy the local buffer into dinput activations */
139   for ( hi = iph; hi < (ifh+iph); hi++ ) {
140     for ( wi = ipw; wi < (ifw+ipw); wi++ ) {
141       element_input_type*     dinput_ptr = &LIBXSMM_VLA_ACCESS(5, dinput,     img, fm,        hi,        wi, 0, nBlocksFm, ifhp, ifwp, 16);
142       float*              lcl_dinput_ptr = &LIBXSMM_VLA_ACCESS(3, lcl_dinput,             hi-iph,    wi-ipw, 0,                   ifw, 16);
143       _mm512_stream_act( dinput_ptr, _mm512_loadu_ps( lcl_dinput_ptr ) );
144     }
145   }
146 }
147 
148 libxsmm_barrier_wait(handle->barrier, ltid);
149 
/* Remove every helper macro defined by this template so that nothing leaks
 * into the including translation unit. _mm512_roundbf16rne only exists in the
 * BF16 build; undefining a name that is not defined is harmless. */
# undef _mm512_load_act
# undef _mm512_stream_act
# undef _mm512_store_act
# undef _mm512_roundbf16rne
153 
154