1 /******************************************************************************
2 * Copyright (c) Intel Corporation - All rights reserved.                      *
3 * This file is part of the LIBXSMM library.                                   *
4 *                                                                             *
5 * For information on the license, see the LICENSE file.                       *
6 * Further information: https://github.com/hfp/libxsmm/                        *
7 * SPDX-License-Identifier: BSD-3-Clause                                       *
8 ******************************************************************************/
9 /* Sasikanth Avancha, Dhiraj Kalamkar (Intel Corp.)
10 ******************************************************************************/
11 
12 
13 #include <stdio.h>
14 #include <omp.h>
15 #include <math.h>
16 #include "PoolingXSMM.hpp"
17 
18 #define VLEN 16
19 
PoolXSMM(PoolImplParams * gp,int engine)20 PoolXSMM::PoolXSMM(PoolImplParams *gp, int engine) : PoolImpl(gp, engine)
21 {
22   pooling_desc.N = gp->batch_size/NUM_NUMA_NODES;
23   pooling_desc.C = gp->nInput;
24   pooling_desc.H = gp->iHeight;
25   pooling_desc.W = gp->iWidth;
26   pooling_desc.u = gp->stride_h;
27   pooling_desc.v = gp->stride_w;
28   pooling_desc.R = gp->kh;
29   pooling_desc.S = gp->kw;
30   pooling_desc.pad_h = gp->pad_h;
31   pooling_desc.pad_w = gp->pad_w;
32   pooling_desc.pad_h_in = gp->ipad_h;
33   pooling_desc.pad_w_in = gp->ipad_w;
34   pooling_desc.pad_h_out = gp->opad_h;
35   pooling_desc.pad_w_out = gp->opad_w;
36   pooling_desc.threads = gp->num_threads/NUM_NUMA_NODES;
37 
38   if(gp->in_data_type == DT_FLOAT && gp->out_data_type == DT_FLOAT)
39   {
40     pooling_desc.datatype_in = LIBXSMM_DNN_DATATYPE_F32;
41     pooling_desc.datatype_out = LIBXSMM_DNN_DATATYPE_F32;
42   }
43   else if(gp->in_data_type == DT_BF16 && gp->out_data_type == DT_BF16)
44   {
45     pooling_desc.datatype_in = LIBXSMM_DNN_DATATYPE_BF16;
46     pooling_desc.datatype_out = LIBXSMM_DNN_DATATYPE_BF16;
47   }
48 
49   pooling_desc.datatype_mask = LIBXSMM_DNN_DATATYPE_I32;
50   pooling_desc.buffer_format = LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM;
51 
52   if(gp->pool_mode == MAX)
53     pooling_desc.pooling_type = LIBXSMM_DNN_POOLING_MAX;
54   else if(gp->pool_mode == AVE)
55     pooling_desc.pooling_type = LIBXSMM_DNN_POOLING_AVG;
56 
57   for(int n=0; n<NUM_NUMA_NODES; n++)
58   {
59     libxsmm_handle[n] = libxsmm_dnn_create_pooling( pooling_desc, &status );
60     CHKERR_LIBXSMM_DNN( status );
61   }
62 }
63 
forwardPropagate(TensorBuf * inpb,TensorBuf * outpb,int * mask,int tid)64 void PoolXSMM::forwardPropagate(TensorBuf *inpb, TensorBuf *outpb, int *mask, int tid)
65 {
66   int ifh = gp->iHeight;
67   int ifw = gp->iWidth;
68   int iph = gp->ipad_h;
69   int ipw = gp->ipad_w;
70   int ifhp = ifh + 2*iph;
71   int ifwp = ifw + 2*ipw;
72   int ofh = gp->oHeight;
73   int ofw = gp->oWidth;
74   int oph = gp->opad_h;
75   int opw = gp->opad_w;
76   int ofhp = ofh + 2*oph;
77   int ofwp = ofw + 2*opw;
78 
79   void *input[NUM_NUMA_NODES];
80   void *output[NUM_NUMA_NODES];
81   int *pool_mask[NUM_NUMA_NODES];
82 
83   int imoff = pooling_desc.N * pooling_desc.C * ifhp * ifwp;
84   if(gp->in_data_type == DT_FLOAT)
85     imoff *= sizeof(float);
86   else if(gp->in_data_type == DT_BF16)
87     imoff *= sizeof(libxsmm_bfloat16);
88   input[0] = inpb->getBuffer();
89   for(int n=1; n<NUM_NUMA_NODES; n++)
90     input[n] = input[n-1] + imoff;
91 
92   imoff = pooling_desc.N * pooling_desc.C * ofhp * ofwp;
93   if(gp->in_data_type == DT_FLOAT)
94     imoff *= sizeof(float);
95   else if(gp->in_data_type == DT_BF16)
96     imoff *= sizeof(libxsmm_bfloat16);
97   output[0] = outpb->getBuffer();
98   for(int n=1; n<NUM_NUMA_NODES; n++)
99     output[n] = output[n-1] + imoff;
100 
101   imoff = pooling_desc.N * pooling_desc.C * ofhp * ofwp;
102   pool_mask[0] = mask;
103   for(int n=1; n<NUM_NUMA_NODES; n++)
104     pool_mask[n] = pool_mask[n-1] + imoff;
105 
106   void **sptrptr = scratchp->getBufferPtr();
107 
108   for(int n=0; n<NUM_NUMA_NODES; n++)
109   {
110     if(libxsmm_input[n] == NULL && libxsmm_mask[n] == NULL && libxsmm_output[n] == NULL)
111     {
112       libxsmm_layout = libxsmm_dnn_pooling_create_tensor_datalayout( libxsmm_handle[n], LIBXSMM_DNN_REGULAR_INPUT, &status );
113       CHKERR_LIBXSMM_DNN( status );
114       libxsmm_input[n]  = libxsmm_dnn_link_tensor( libxsmm_layout, input[n], &status ); CHKERR_LIBXSMM_DNN( status );
115       libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
116       CHKERR_LIBXSMM_DNN(libxsmm_dnn_pooling_bind_tensor( libxsmm_handle[n], libxsmm_input[n], LIBXSMM_DNN_REGULAR_INPUT));
117 
118       libxsmm_layout = libxsmm_dnn_pooling_create_tensor_datalayout( libxsmm_handle[n], LIBXSMM_DNN_REGULAR_OUTPUT, &status );
119       CHKERR_LIBXSMM_DNN( status );
120       libxsmm_output[n]  = libxsmm_dnn_link_tensor( libxsmm_layout, output[n], &status ); CHKERR_LIBXSMM_DNN( status );
121       libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
122       CHKERR_LIBXSMM_DNN(libxsmm_dnn_pooling_bind_tensor(libxsmm_handle[n], libxsmm_output[n], LIBXSMM_DNN_REGULAR_OUTPUT));
123 
124       libxsmm_layout = libxsmm_dnn_pooling_create_tensor_datalayout( libxsmm_handle[n], LIBXSMM_DNN_POOLING_MASK, &status );
125       CHKERR_LIBXSMM_DNN( status );
126       libxsmm_mask[n]  = libxsmm_dnn_link_tensor( libxsmm_layout, (void*)pool_mask[n], &status );
127       CHKERR_LIBXSMM_DNN( status );
128       libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
129       CHKERR_LIBXSMM_DNN( libxsmm_dnn_pooling_bind_tensor(libxsmm_handle[n], libxsmm_mask[n], LIBXSMM_DNN_POOLING_MASK ) );
130     }
131   }
132 
133   if(sptrptr == NULL)
134   {
135     sptrptr = (void**)libxsmm_aligned_malloc(NUM_NUMA_NODES*sizeof(void*), 2097152);
136     scratchp->setBufferPtr(sptrptr);
137   }
138 
139   if(prev_scratch_size == 0)
140     prev_scratch_size = scratchp->getBufferSize();
141 
142   if(!updated_scratch_fwd || prev_scratch_size != scratchp->getBufferSize())
143   {
144     int max_size=0;
145 
146     for(int n=0; n<NUM_NUMA_NODES; n++)
147     {
148       if(sptrptr[n] == NULL)
149       {
150         long long mysize = libxsmm_dnn_pooling_get_scratch_size( libxsmm_handle[n], &status );
151         CHKERR_LIBXSMM_DNN( status );
152         sptrptr[n] = libxsmm_aligned_scratch( mysize, 2097152 );
153         max_size = mysize;
154 
155 #ifdef USE_MLSL
156         if(MLSL::Environment::GetEnv().GetProcessIdx() == 0)
157 #endif
158           printf("%s allocated %lld bytes for scratch @ %p\n",nname.c_str(), mysize, sptrptr[n]);
159       }
160       else
161       {
162         long long int ssize = scratchp->getBufferSize();
163         long long int mysize = libxsmm_dnn_pooling_get_scratch_size( libxsmm_handle[n], &status );
164 
165         CHKERR_LIBXSMM_DNN( status );
166 
167         if(ssize < mysize)
168         {
169           libxsmm_free(sptrptr[n]);
170           sptrptr[n] = (void*)libxsmm_aligned_malloc(mysize, 2097152);
171           max_size = mysize;
172 
173 #ifdef USE_MLSL
174           if(MLSL::Environment::GetEnv().GetProcessIdx() == 0)
175 #endif
176             printf("%s allocated %lld bytes for scratch @ %p, prev size was %lld bytes\n",nname.c_str(), mysize, sptrptr[n], ssize);
177         }
178         else
179           max_size = ssize;
180       }
181     }
182     scratchp->setBufferSize(max_size);
183 
184     for(int n=0; n<NUM_NUMA_NODES; n++)
185       CHKERR_LIBXSMM_DNN( libxsmm_dnn_pooling_bind_scratch( libxsmm_handle[n], sptrptr[n] ) );
186     updated_scratch_fwd = true;
187     prev_scratch_size = scratchp->getBufferSize();
188   }
189 
190 #if defined(_OPENMP)
191 #pragma omp parallel
192 #endif
193   {
194 #if defined(_OPENMP)
195     const int tid = omp_get_thread_num();
196 #else
197     const int tid = 0;
198 #endif
199     int ntps = gp->num_threads/NUM_NUMA_NODES;
200     int n = tid/ntps;
201     CHKERR_LIBXSMM_DNN( libxsmm_dnn_pooling_execute_st( libxsmm_handle[n], LIBXSMM_DNN_COMPUTE_KIND_FWD, n*ntps, tid ) );
202   }
203 }
204 
backPropagate(TensorBuf * deloutpb,int * mask,TensorBuf * delinpb,int tid)205 void PoolXSMM::backPropagate(TensorBuf *deloutpb, int *mask, TensorBuf *delinpb, int tid)
206 {
207   int ifh = gp->iHeight;
208   int ifw = gp->iWidth;
209   int iph = gp->ipad_h;
210   int ipw = gp->ipad_w;
211   int ifhp = ifh + 2*iph;
212   int ifwp = ifw + 2*ipw;
213   int ofh = gp->oHeight;
214   int ofw = gp->oWidth;
215   int oph = gp->opad_h;
216   int opw = gp->opad_w;
217   int ofhp = ofh + 2*oph;
218   int ofwp = ofw + 2*opw;
219 
220   void *deloutput[NUM_NUMA_NODES];
221   void *delinput[NUM_NUMA_NODES];
222   int* pool_mask[NUM_NUMA_NODES];
223 
224   int imoff = pooling_desc.N * pooling_desc.C * ifhp * ifwp;
225   if(gp->in_data_type == DT_FLOAT)
226     imoff *= sizeof(float);
227   else if(gp->in_data_type == DT_BF16)
228     imoff *= sizeof(libxsmm_bfloat16);
229   delinput[0] = delinpb->getBuffer();
230   for(int n=1; n<NUM_NUMA_NODES; n++)
231     delinput[n] = delinput[n-1] + imoff;
232 
233   imoff = pooling_desc.N * pooling_desc.C * ofhp * ofwp;
234   if(gp->in_data_type == DT_FLOAT)
235     imoff *= sizeof(float);
236   else if(gp->in_data_type == DT_BF16)
237     imoff *= sizeof(libxsmm_bfloat16);
238   deloutput[0] = deloutpb->getBuffer();
239   for(int n=1; n<NUM_NUMA_NODES; n++)
240     deloutput[n] = deloutput[n-1] + imoff;
241 
242   imoff = pooling_desc.N * pooling_desc.C * ofhp * ofwp;
243   pool_mask[0] = mask;
244   for(int n=1; n<NUM_NUMA_NODES; n++)
245     pool_mask[n] = pool_mask[n-1] + imoff;
246 
247   void **sptrptr = scratchp->getBufferPtr();
248   if(!updated_scratch_bwd)
249   {
250     for(int n=0; n<NUM_NUMA_NODES; n++)
251       CHKERR_LIBXSMM_DNN( libxsmm_dnn_pooling_bind_scratch( libxsmm_handle[n], sptrptr[n] ) );
252     updated_scratch_bwd = true;
253   }
254 
255   for(int n=0; n<NUM_NUMA_NODES; n++)
256   {
257     if(libxsmm_deloutput[n] == NULL && libxsmm_delinput[n] == NULL)
258     {
259       libxsmm_layout = libxsmm_dnn_pooling_create_tensor_datalayout(libxsmm_handle[n], LIBXSMM_DNN_GRADIENT_OUTPUT, &status);
260       CHKERR_LIBXSMM_DNN( status );
261       libxsmm_deloutput[n] = libxsmm_dnn_link_tensor( libxsmm_layout, deloutput[n], &status );
262       CHKERR_LIBXSMM_DNN( status );
263       libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
264       CHKERR_LIBXSMM_DNN(libxsmm_dnn_pooling_bind_tensor(libxsmm_handle[n], libxsmm_deloutput[n], LIBXSMM_DNN_GRADIENT_OUTPUT));
265 
266       libxsmm_layout = libxsmm_dnn_pooling_create_tensor_datalayout(libxsmm_handle[n], LIBXSMM_DNN_GRADIENT_INPUT, &status);
267       CHKERR_LIBXSMM_DNN( status );
268       libxsmm_delinput[n]  = libxsmm_dnn_link_tensor( libxsmm_layout, delinput[n], &status );
269       CHKERR_LIBXSMM_DNN( status );
270       libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
271       CHKERR_LIBXSMM_DNN(libxsmm_dnn_pooling_bind_tensor(libxsmm_handle[n], libxsmm_delinput[n], LIBXSMM_DNN_GRADIENT_INPUT));
272     }
273   }
274 
275 #if defined(_OPENMP)
276 #pragma omp parallel
277 #endif
278   {
279 #if defined(_OPENMP)
280     const int tid = omp_get_thread_num();
281 #else
282     const int tid = 0;
283 #endif
284     int ntps = gp->num_threads/NUM_NUMA_NODES;
285     int n = tid/ntps;
286     CHKERR_LIBXSMM_DNN(libxsmm_dnn_pooling_execute_st(libxsmm_handle[n], LIBXSMM_DNN_COMPUTE_KIND_BWD, n*ntps, tid ) );
287   }
288   delinpb->setLayoutType(LIBXSMM_CUSTOM_LAYOUT);
289 }
290