1 /******************************************************************************
2 * Copyright (c) Intel Corporation - All rights reserved. *
3 * This file is part of the LIBXSMM library. *
4 * *
5 * For information on the license, see the LICENSE file. *
6 * Further information: https://github.com/hfp/libxsmm/ *
7 * SPDX-License-Identifier: BSD-3-Clause *
8 ******************************************************************************/
9 /* Sasikanth Avancha, Dhiraj Kalamkar (Intel Corp.)
10 ******************************************************************************/
11
12
#include <stdio.h>
#include <stdlib.h>
#include <omp.h>
#include <math.h>
#include "PoolingXSMM.hpp"
17
18 #define VLEN 16
19
PoolXSMM(PoolImplParams * gp,int engine)20 PoolXSMM::PoolXSMM(PoolImplParams *gp, int engine) : PoolImpl(gp, engine)
21 {
22 pooling_desc.N = gp->batch_size/NUM_NUMA_NODES;
23 pooling_desc.C = gp->nInput;
24 pooling_desc.H = gp->iHeight;
25 pooling_desc.W = gp->iWidth;
26 pooling_desc.u = gp->stride_h;
27 pooling_desc.v = gp->stride_w;
28 pooling_desc.R = gp->kh;
29 pooling_desc.S = gp->kw;
30 pooling_desc.pad_h = gp->pad_h;
31 pooling_desc.pad_w = gp->pad_w;
32 pooling_desc.pad_h_in = gp->ipad_h;
33 pooling_desc.pad_w_in = gp->ipad_w;
34 pooling_desc.pad_h_out = gp->opad_h;
35 pooling_desc.pad_w_out = gp->opad_w;
36 pooling_desc.threads = gp->num_threads/NUM_NUMA_NODES;
37
38 if(gp->in_data_type == DT_FLOAT && gp->out_data_type == DT_FLOAT)
39 {
40 pooling_desc.datatype_in = LIBXSMM_DNN_DATATYPE_F32;
41 pooling_desc.datatype_out = LIBXSMM_DNN_DATATYPE_F32;
42 }
43 else if(gp->in_data_type == DT_BF16 && gp->out_data_type == DT_BF16)
44 {
45 pooling_desc.datatype_in = LIBXSMM_DNN_DATATYPE_BF16;
46 pooling_desc.datatype_out = LIBXSMM_DNN_DATATYPE_BF16;
47 }
48
49 pooling_desc.datatype_mask = LIBXSMM_DNN_DATATYPE_I32;
50 pooling_desc.buffer_format = LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM;
51
52 if(gp->pool_mode == MAX)
53 pooling_desc.pooling_type = LIBXSMM_DNN_POOLING_MAX;
54 else if(gp->pool_mode == AVE)
55 pooling_desc.pooling_type = LIBXSMM_DNN_POOLING_AVG;
56
57 for(int n=0; n<NUM_NUMA_NODES; n++)
58 {
59 libxsmm_handle[n] = libxsmm_dnn_create_pooling( pooling_desc, &status );
60 CHKERR_LIBXSMM_DNN( status );
61 }
62 }
63
forwardPropagate(TensorBuf * inpb,TensorBuf * outpb,int * mask,int tid)64 void PoolXSMM::forwardPropagate(TensorBuf *inpb, TensorBuf *outpb, int *mask, int tid)
65 {
66 int ifh = gp->iHeight;
67 int ifw = gp->iWidth;
68 int iph = gp->ipad_h;
69 int ipw = gp->ipad_w;
70 int ifhp = ifh + 2*iph;
71 int ifwp = ifw + 2*ipw;
72 int ofh = gp->oHeight;
73 int ofw = gp->oWidth;
74 int oph = gp->opad_h;
75 int opw = gp->opad_w;
76 int ofhp = ofh + 2*oph;
77 int ofwp = ofw + 2*opw;
78
79 void *input[NUM_NUMA_NODES];
80 void *output[NUM_NUMA_NODES];
81 int *pool_mask[NUM_NUMA_NODES];
82
83 int imoff = pooling_desc.N * pooling_desc.C * ifhp * ifwp;
84 if(gp->in_data_type == DT_FLOAT)
85 imoff *= sizeof(float);
86 else if(gp->in_data_type == DT_BF16)
87 imoff *= sizeof(libxsmm_bfloat16);
88 input[0] = inpb->getBuffer();
89 for(int n=1; n<NUM_NUMA_NODES; n++)
90 input[n] = input[n-1] + imoff;
91
92 imoff = pooling_desc.N * pooling_desc.C * ofhp * ofwp;
93 if(gp->in_data_type == DT_FLOAT)
94 imoff *= sizeof(float);
95 else if(gp->in_data_type == DT_BF16)
96 imoff *= sizeof(libxsmm_bfloat16);
97 output[0] = outpb->getBuffer();
98 for(int n=1; n<NUM_NUMA_NODES; n++)
99 output[n] = output[n-1] + imoff;
100
101 imoff = pooling_desc.N * pooling_desc.C * ofhp * ofwp;
102 pool_mask[0] = mask;
103 for(int n=1; n<NUM_NUMA_NODES; n++)
104 pool_mask[n] = pool_mask[n-1] + imoff;
105
106 void **sptrptr = scratchp->getBufferPtr();
107
108 for(int n=0; n<NUM_NUMA_NODES; n++)
109 {
110 if(libxsmm_input[n] == NULL && libxsmm_mask[n] == NULL && libxsmm_output[n] == NULL)
111 {
112 libxsmm_layout = libxsmm_dnn_pooling_create_tensor_datalayout( libxsmm_handle[n], LIBXSMM_DNN_REGULAR_INPUT, &status );
113 CHKERR_LIBXSMM_DNN( status );
114 libxsmm_input[n] = libxsmm_dnn_link_tensor( libxsmm_layout, input[n], &status ); CHKERR_LIBXSMM_DNN( status );
115 libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
116 CHKERR_LIBXSMM_DNN(libxsmm_dnn_pooling_bind_tensor( libxsmm_handle[n], libxsmm_input[n], LIBXSMM_DNN_REGULAR_INPUT));
117
118 libxsmm_layout = libxsmm_dnn_pooling_create_tensor_datalayout( libxsmm_handle[n], LIBXSMM_DNN_REGULAR_OUTPUT, &status );
119 CHKERR_LIBXSMM_DNN( status );
120 libxsmm_output[n] = libxsmm_dnn_link_tensor( libxsmm_layout, output[n], &status ); CHKERR_LIBXSMM_DNN( status );
121 libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
122 CHKERR_LIBXSMM_DNN(libxsmm_dnn_pooling_bind_tensor(libxsmm_handle[n], libxsmm_output[n], LIBXSMM_DNN_REGULAR_OUTPUT));
123
124 libxsmm_layout = libxsmm_dnn_pooling_create_tensor_datalayout( libxsmm_handle[n], LIBXSMM_DNN_POOLING_MASK, &status );
125 CHKERR_LIBXSMM_DNN( status );
126 libxsmm_mask[n] = libxsmm_dnn_link_tensor( libxsmm_layout, (void*)pool_mask[n], &status );
127 CHKERR_LIBXSMM_DNN( status );
128 libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
129 CHKERR_LIBXSMM_DNN( libxsmm_dnn_pooling_bind_tensor(libxsmm_handle[n], libxsmm_mask[n], LIBXSMM_DNN_POOLING_MASK ) );
130 }
131 }
132
133 if(sptrptr == NULL)
134 {
135 sptrptr = (void**)libxsmm_aligned_malloc(NUM_NUMA_NODES*sizeof(void*), 2097152);
136 scratchp->setBufferPtr(sptrptr);
137 }
138
139 if(prev_scratch_size == 0)
140 prev_scratch_size = scratchp->getBufferSize();
141
142 if(!updated_scratch_fwd || prev_scratch_size != scratchp->getBufferSize())
143 {
144 int max_size=0;
145
146 for(int n=0; n<NUM_NUMA_NODES; n++)
147 {
148 if(sptrptr[n] == NULL)
149 {
150 long long mysize = libxsmm_dnn_pooling_get_scratch_size( libxsmm_handle[n], &status );
151 CHKERR_LIBXSMM_DNN( status );
152 sptrptr[n] = libxsmm_aligned_scratch( mysize, 2097152 );
153 max_size = mysize;
154
155 #ifdef USE_MLSL
156 if(MLSL::Environment::GetEnv().GetProcessIdx() == 0)
157 #endif
158 printf("%s allocated %lld bytes for scratch @ %p\n",nname.c_str(), mysize, sptrptr[n]);
159 }
160 else
161 {
162 long long int ssize = scratchp->getBufferSize();
163 long long int mysize = libxsmm_dnn_pooling_get_scratch_size( libxsmm_handle[n], &status );
164
165 CHKERR_LIBXSMM_DNN( status );
166
167 if(ssize < mysize)
168 {
169 libxsmm_free(sptrptr[n]);
170 sptrptr[n] = (void*)libxsmm_aligned_malloc(mysize, 2097152);
171 max_size = mysize;
172
173 #ifdef USE_MLSL
174 if(MLSL::Environment::GetEnv().GetProcessIdx() == 0)
175 #endif
176 printf("%s allocated %lld bytes for scratch @ %p, prev size was %lld bytes\n",nname.c_str(), mysize, sptrptr[n], ssize);
177 }
178 else
179 max_size = ssize;
180 }
181 }
182 scratchp->setBufferSize(max_size);
183
184 for(int n=0; n<NUM_NUMA_NODES; n++)
185 CHKERR_LIBXSMM_DNN( libxsmm_dnn_pooling_bind_scratch( libxsmm_handle[n], sptrptr[n] ) );
186 updated_scratch_fwd = true;
187 prev_scratch_size = scratchp->getBufferSize();
188 }
189
190 #if defined(_OPENMP)
191 #pragma omp parallel
192 #endif
193 {
194 #if defined(_OPENMP)
195 const int tid = omp_get_thread_num();
196 #else
197 const int tid = 0;
198 #endif
199 int ntps = gp->num_threads/NUM_NUMA_NODES;
200 int n = tid/ntps;
201 CHKERR_LIBXSMM_DNN( libxsmm_dnn_pooling_execute_st( libxsmm_handle[n], LIBXSMM_DNN_COMPUTE_KIND_FWD, n*ntps, tid ) );
202 }
203 }
204
backPropagate(TensorBuf * deloutpb,int * mask,TensorBuf * delinpb,int tid)205 void PoolXSMM::backPropagate(TensorBuf *deloutpb, int *mask, TensorBuf *delinpb, int tid)
206 {
207 int ifh = gp->iHeight;
208 int ifw = gp->iWidth;
209 int iph = gp->ipad_h;
210 int ipw = gp->ipad_w;
211 int ifhp = ifh + 2*iph;
212 int ifwp = ifw + 2*ipw;
213 int ofh = gp->oHeight;
214 int ofw = gp->oWidth;
215 int oph = gp->opad_h;
216 int opw = gp->opad_w;
217 int ofhp = ofh + 2*oph;
218 int ofwp = ofw + 2*opw;
219
220 void *deloutput[NUM_NUMA_NODES];
221 void *delinput[NUM_NUMA_NODES];
222 int* pool_mask[NUM_NUMA_NODES];
223
224 int imoff = pooling_desc.N * pooling_desc.C * ifhp * ifwp;
225 if(gp->in_data_type == DT_FLOAT)
226 imoff *= sizeof(float);
227 else if(gp->in_data_type == DT_BF16)
228 imoff *= sizeof(libxsmm_bfloat16);
229 delinput[0] = delinpb->getBuffer();
230 for(int n=1; n<NUM_NUMA_NODES; n++)
231 delinput[n] = delinput[n-1] + imoff;
232
233 imoff = pooling_desc.N * pooling_desc.C * ofhp * ofwp;
234 if(gp->in_data_type == DT_FLOAT)
235 imoff *= sizeof(float);
236 else if(gp->in_data_type == DT_BF16)
237 imoff *= sizeof(libxsmm_bfloat16);
238 deloutput[0] = deloutpb->getBuffer();
239 for(int n=1; n<NUM_NUMA_NODES; n++)
240 deloutput[n] = deloutput[n-1] + imoff;
241
242 imoff = pooling_desc.N * pooling_desc.C * ofhp * ofwp;
243 pool_mask[0] = mask;
244 for(int n=1; n<NUM_NUMA_NODES; n++)
245 pool_mask[n] = pool_mask[n-1] + imoff;
246
247 void **sptrptr = scratchp->getBufferPtr();
248 if(!updated_scratch_bwd)
249 {
250 for(int n=0; n<NUM_NUMA_NODES; n++)
251 CHKERR_LIBXSMM_DNN( libxsmm_dnn_pooling_bind_scratch( libxsmm_handle[n], sptrptr[n] ) );
252 updated_scratch_bwd = true;
253 }
254
255 for(int n=0; n<NUM_NUMA_NODES; n++)
256 {
257 if(libxsmm_deloutput[n] == NULL && libxsmm_delinput[n] == NULL)
258 {
259 libxsmm_layout = libxsmm_dnn_pooling_create_tensor_datalayout(libxsmm_handle[n], LIBXSMM_DNN_GRADIENT_OUTPUT, &status);
260 CHKERR_LIBXSMM_DNN( status );
261 libxsmm_deloutput[n] = libxsmm_dnn_link_tensor( libxsmm_layout, deloutput[n], &status );
262 CHKERR_LIBXSMM_DNN( status );
263 libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
264 CHKERR_LIBXSMM_DNN(libxsmm_dnn_pooling_bind_tensor(libxsmm_handle[n], libxsmm_deloutput[n], LIBXSMM_DNN_GRADIENT_OUTPUT));
265
266 libxsmm_layout = libxsmm_dnn_pooling_create_tensor_datalayout(libxsmm_handle[n], LIBXSMM_DNN_GRADIENT_INPUT, &status);
267 CHKERR_LIBXSMM_DNN( status );
268 libxsmm_delinput[n] = libxsmm_dnn_link_tensor( libxsmm_layout, delinput[n], &status );
269 CHKERR_LIBXSMM_DNN( status );
270 libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
271 CHKERR_LIBXSMM_DNN(libxsmm_dnn_pooling_bind_tensor(libxsmm_handle[n], libxsmm_delinput[n], LIBXSMM_DNN_GRADIENT_INPUT));
272 }
273 }
274
275 #if defined(_OPENMP)
276 #pragma omp parallel
277 #endif
278 {
279 #if defined(_OPENMP)
280 const int tid = omp_get_thread_num();
281 #else
282 const int tid = 0;
283 #endif
284 int ntps = gp->num_threads/NUM_NUMA_NODES;
285 int n = tid/ntps;
286 CHKERR_LIBXSMM_DNN(libxsmm_dnn_pooling_execute_st(libxsmm_handle[n], LIBXSMM_DNN_COMPUTE_KIND_BWD, n*ntps, tid ) );
287 }
288 delinpb->setLayoutType(LIBXSMM_CUSTOM_LAYOUT);
289 }
290