/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved.                      *
* This file is part of the LIBXSMM library.                                   *
*                                                                             *
* For information on the license, see the LICENSE file.                       *
* Further information: https://github.com/hfp/libxsmm/                        *
* SPDX-License-Identifier: BSD-3-Clause                                       *
******************************************************************************/
/* Alexander Heinecke, Sasikanth Avancha (Intel Corp.)
******************************************************************************/
#include "libxsmm_dnn_pooling_backward.h"
#include "libxsmm_dnn_pooling_forward.h"
#include "libxsmm_main.h"


LIBXSMM_API libxsmm_dnn_pooling* libxsmm_dnn_create_pooling(libxsmm_dnn_pooling_desc pooling_desc, libxsmm_dnn_err_t* status) {
  libxsmm_dnn_pooling* handle = 0;
  int lpb;

  /* init libxsmm */
  LIBXSMM_INIT

  if ( ((pooling_desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (pooling_desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16)) ||
       ((pooling_desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (pooling_desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32))    ) {
    handle = (libxsmm_dnn_pooling*)malloc(sizeof(libxsmm_dnn_pooling));

    if (0 != handle) {
      *status = LIBXSMM_DNN_SUCCESS;
      /* zero entire content; not only safer but also sets data and code pointers to NULL */
      memset(handle, 0, sizeof(*handle));
      /* let's make the description persistent */
      handle->desc = pooling_desc;
      /* we need to compute the feature-map blocking for the given channel count and datatypes */
      *status = libxsmm_dnn_get_feature_map_blocks( handle->desc.C, handle->desc.C,
                                                    &(handle->ifmblock), &(handle->ofmblock), &lpb,
                                                    handle->desc.datatype_in, handle->desc.datatype_out );
      /* compute the outer blocks */
      handle->blocksifm = handle->desc.C / handle->ifmblock;
      handle->blocksofm = handle->desc.C / handle->ofmblock;
      /* setting ofh and ofw */
      handle->ofh = (handle->desc.H + 2 * handle->desc.pad_h - handle->desc.R) / handle->desc.u + 1;
      handle->ofw = (handle->desc.W + 2 * handle->desc.pad_w - handle->desc.S) / handle->desc.v + 1;
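      /* Example (hypothetical numbers): for H = W = 112, R = S = 2, u = v = 2 and no padding,
       * ofh = ofw = (112 + 0 - 2) / 2 + 1 = 56, i.e. a 2x2/stride-2 pooling halves the spatial dims. */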
      /* create barrier */
      handle->barrier = libxsmm_barrier_create(handle->desc.threads, 1);
      /* calculate scratch size for local pooling copies of one feature map block per thread */
      handle->scratch_size = (sizeof(float) * ( (size_t)handle->desc.H + (size_t)LIBXSMM_MAX(handle->desc.pad_h_in, handle->desc.pad_h_out)*2 )
                                            * ( (size_t)handle->desc.W + (size_t)LIBXSMM_MAX(handle->desc.pad_w_in, handle->desc.pad_w_out)*2 )
                                            * LIBXSMM_MAX( handle->ofmblock, handle->ifmblock )
                                            * handle->desc.threads );
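      /* Example (hypothetical numbers): H = W = 56, no physical padding, a feature-map block of 64
       * and 28 threads give 4 B * 56 * 56 * 64 * 28 = 22,478,848 B, i.e. roughly 22 MB of scratch. */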
    } else {
      *status = LIBXSMM_DNN_ERR_CREATE_HANDLE;
    }
  } else {
    *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
  }

  return handle;
}
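
/* A minimal usage sketch (disabled, not part of the library): it illustrates how a caller might
 * fill a libxsmm_dnn_pooling_desc and create a handle with libxsmm_dnn_create_pooling above.
 * The helper name and all sizes are hypothetical, only descriptor fields referenced in this file
 * are shown (e.g. the pooling kind is left out), and the I32 mask datatype is an assumption. */
#if 0
static libxsmm_dnn_pooling* example_create_pooling(void) {
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  libxsmm_dnn_pooling_desc desc;
  memset(&desc, 0, sizeof(desc));            /* zero all fields not set explicitly */
  desc.N = 32; desc.C = 64;                  /* minibatch and channels */
  desc.H = 112; desc.W = 112;                /* input spatial size */
  desc.R = 2; desc.S = 2;                    /* pooling window */
  desc.u = 2; desc.v = 2;                    /* strides */
  desc.pad_h = 0; desc.pad_w = 0;            /* logical padding */
  desc.pad_h_in = 0; desc.pad_w_in = 0;      /* physical input padding */
  desc.pad_h_out = 0; desc.pad_w_out = 0;    /* physical output padding */
  desc.threads = 28;                         /* threads later used by the execute call */
  desc.datatype_in = LIBXSMM_DNN_DATATYPE_F32;
  desc.datatype_out = LIBXSMM_DNN_DATATYPE_F32;
  desc.datatype_mask = LIBXSMM_DNN_DATATYPE_I32;
  desc.buffer_format = LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM;
  return libxsmm_dnn_create_pooling(desc, &status);
}
#endif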


LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_destroy_pooling(const libxsmm_dnn_pooling* handle) {
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;

  if (0 != handle) {
    /* Deallocate barrier */
    if (handle->barrier != 0 ) { libxsmm_barrier_release((const libxsmm_barrier*)handle->barrier); }
    /* deallocate handle structure */
    free(/*remove constness*/(libxsmm_dnn_pooling*)handle);
  } else {
    status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }

  return status;
}


LIBXSMM_API libxsmm_dnn_tensor_datalayout* libxsmm_dnn_pooling_create_tensor_datalayout(const libxsmm_dnn_pooling* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status) {
  libxsmm_dnn_tensor_datalayout* layout;

  *status = LIBXSMM_DNN_SUCCESS;
  layout = 0;

  if (handle != 0) {
    layout = (libxsmm_dnn_tensor_datalayout*) malloc(sizeof(libxsmm_dnn_tensor_datalayout));

    if (layout != 0) {
      memset(layout, 0, sizeof(libxsmm_dnn_tensor_datalayout));
      layout->format = handle->desc.buffer_format;

      if ( (type == LIBXSMM_DNN_REGULAR_INPUT)     || (type == LIBXSMM_DNN_GRADIENT_INPUT)  || (type == LIBXSMM_DNN_INPUT)  ||
           (type == LIBXSMM_DNN_REGULAR_OUTPUT)    || (type == LIBXSMM_DNN_GRADIENT_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) ||
           (type == LIBXSMM_DNN_POOLING_MASK)                                                                                  ) {
        if ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0) {
          if ( ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) ) {
            if ( type == LIBXSMM_DNN_POOLING_MASK ) {
              layout->datatype = handle->desc.datatype_mask;
            } else {
              layout->datatype = LIBXSMM_DNN_DATATYPE_F32;
            }
            layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(5*sizeof(libxsmm_dnn_tensor_dimtype));
            layout->dim_size = (unsigned int*) malloc(5*sizeof(unsigned int));

            if (0 != layout->dim_type && 0 != layout->dim_size) {
              layout->num_dims = 5;
              layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
              layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W;
              layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H;
              layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
              layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
              if ( (type == LIBXSMM_DNN_REGULAR_INPUT)     || (type == LIBXSMM_DNN_GRADIENT_INPUT)     || (type == LIBXSMM_DNN_INPUT)   ) {
                layout->dim_size[0] = handle->ifmblock;
                layout->dim_size[1] = handle->desc.W + (2*handle->desc.pad_w_in);
                layout->dim_size[2] = handle->desc.H + (2*handle->desc.pad_h_in);
                layout->dim_size[3] = handle->blocksifm;
                layout->dim_size[4] = handle->desc.N;
              } else if ( (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_GRADIENT_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) ) {
                layout->dim_size[0] = handle->ofmblock;
                layout->dim_size[1] = (handle->ofw) + (2*handle->desc.pad_w_out);
                layout->dim_size[2] = (handle->ofh) + (2*handle->desc.pad_h_out);
                layout->dim_size[3] = handle->blocksofm;
                layout->dim_size[4] = handle->desc.N;
              } else if ( (type == LIBXSMM_DNN_POOLING_MASK) ) {
                layout->dim_size[0] = handle->ofmblock;
                layout->dim_size[1] = handle->ofw;
                layout->dim_size[2] = handle->ofh;
                layout->dim_size[3] = handle->blocksofm;
                layout->dim_size[4] = handle->desc.N;
              } else { /* coverity[dead_error_begin] */
                free(layout->dim_type);
                free(layout->dim_size);
                free(layout);
                layout = 0; /* make sure a NULL is returned */
                *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
              }
            } else {
              free(layout);
              layout = 0; /* make sure a NULL is returned */
              *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS;
            }
          } else if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ) {
            if ( type == LIBXSMM_DNN_POOLING_MASK ) {
              layout->datatype = handle->desc.datatype_mask;
            } else {
              layout->datatype = LIBXSMM_DNN_DATATYPE_BF16;
            }

            layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(5*sizeof(libxsmm_dnn_tensor_dimtype));
            layout->dim_size = (unsigned int*) malloc(5*sizeof(unsigned int));
            if (0 != layout->dim_type && 0 != layout->dim_size) {
              layout->num_dims = 5;
              layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
              layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W;
              layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H;
              layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
              layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
              if ( (type == LIBXSMM_DNN_REGULAR_INPUT)     || (type == LIBXSMM_DNN_GRADIENT_INPUT)     || (type == LIBXSMM_DNN_INPUT)    ) {
                layout->dim_size[0] = handle->ifmblock;
                layout->dim_size[1] = handle->desc.W + (2*handle->desc.pad_w_in);
                layout->dim_size[2] = handle->desc.H + (2*handle->desc.pad_h_in);
                layout->dim_size[3] = handle->blocksifm;
                layout->dim_size[4] = handle->desc.N;
              } else if ( (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_GRADIENT_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) ) {
                layout->dim_size[0] = handle->ofmblock;
                layout->dim_size[1] = (handle->ofw) + (2*handle->desc.pad_w_out);
                layout->dim_size[2] = (handle->ofh) + (2*handle->desc.pad_h_out);
                layout->dim_size[3] = handle->blocksofm;
                layout->dim_size[4] = handle->desc.N;
              } else if ( (type == LIBXSMM_DNN_POOLING_MASK) ) {
                layout->dim_size[0] = handle->ofmblock;
                layout->dim_size[1] = handle->ofw;
                layout->dim_size[2] = handle->ofh;
                layout->dim_size[3] = handle->blocksofm;
                layout->dim_size[4] = handle->desc.N;
              } else {
                free(layout->dim_type);
                free(layout->dim_size);
                free(layout);
                layout = 0; /* make sure a NULL is returned */
                *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
              }
            } else {
              free(layout);
              layout = 0; /* make sure a NULL is returned */
              *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS;
            }
          } else {
            free(layout);
            layout = 0; /* make sure a NULL is returned */
            *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
          }
        } else if ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NHWC) > 0) {
          if ( ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32)) ||
               ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16))    ) {
            if ( type == LIBXSMM_DNN_POOLING_MASK ) {
              layout->datatype = handle->desc.datatype_mask;
            } else {
              layout->datatype = handle->desc.datatype_in;
            }
            layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(4*sizeof(libxsmm_dnn_tensor_dimtype));
            layout->dim_size = (unsigned int*) malloc(4*sizeof(unsigned int));
            if (0 != layout->dim_type && 0 != layout->dim_size) {
              layout->num_dims = 4;
              layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
              layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W;
              layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H;
              layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
              if ( (type == LIBXSMM_DNN_REGULAR_INPUT)     || (type == LIBXSMM_DNN_GRADIENT_INPUT)     || (type == LIBXSMM_DNN_INPUT)   )   {
                layout->dim_size[0] = handle->desc.C;
                layout->dim_size[1] = handle->desc.W + (2*handle->desc.pad_w_in);
                layout->dim_size[2] = handle->desc.H + (2*handle->desc.pad_h_in);
                layout->dim_size[3] = handle->desc.N;
              } else if ( (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_GRADIENT_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) )   {
                layout->dim_size[0] = handle->desc.C;
                layout->dim_size[1] = (handle->ofw) + (2*handle->desc.pad_w_out);
                layout->dim_size[2] = (handle->ofh) + (2*handle->desc.pad_h_out);
                layout->dim_size[3] = handle->desc.N;
              } else {
                free(layout->dim_type);
                free(layout->dim_size);
                free(layout);
                layout = 0; /* make sure a NULL is returned */
                *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
              }
            } else {
              free(layout);
              layout = 0; /* make sure a NULL is returned */
              *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS;
            }
          } else {
            free(layout);
            layout = 0; /* make sure a NULL is returned */
            *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
          }
        } else {
          free(layout);
          layout = 0; /* make sure a NULL is returned */
          *status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL;
        }
      } else {
        free(layout);
        layout = 0; /* make sure a NULL is returned */
        *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
      }
    } else {
      *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT;
    }
  }
  else {
    *status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }

  return layout;
}

LIBXSMM_API size_t libxsmm_dnn_pooling_get_scratch_size(const libxsmm_dnn_pooling* handle, libxsmm_dnn_err_t* status) {
  size_t l_scratch_size = 0;
  *status = LIBXSMM_DNN_SUCCESS;

  if (0 != handle) {
    l_scratch_size = handle->scratch_size + 64; /* 64 byte extra in case the user code does not care about alignment */
  } else {
    *status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }

  return l_scratch_size;
}


LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_pooling_bind_scratch(libxsmm_dnn_pooling* handle, const void* scratch) {
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  uintptr_t address = (uintptr_t)scratch;
  size_t offset = 0;

  if (scratch == 0) {
    status = LIBXSMM_DNN_ERR_SCRATCH_NOT_ALLOCED;
    return status;
  }

  if (0 != handle) {
    /* align the internal scratch buffer if needed */
    if (address % 64 == 0) {
      handle->scratch = (void*)address;
    } else {
      offset = (64 - address % 64);
      handle->scratch = (void*)(address+offset);
    }
  } else {
    status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }

  return status;
}
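
/* A minimal usage sketch (disabled, hypothetical helper name): query the scratch requirement and
 * bind a buffer to the handle. The 64 bytes of slack returned by the query let bind_scratch align
 * an arbitrary allocation; malloc() is used for simplicity and the buffer stays caller-owned. */
#if 0
static libxsmm_dnn_err_t example_setup_scratch(libxsmm_dnn_pooling* handle, void** scratch_out) {
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  const size_t scratch_size = libxsmm_dnn_pooling_get_scratch_size(handle, &status);
  if (LIBXSMM_DNN_SUCCESS == status) {
    void* const scratch = malloc(scratch_size);
    if (0 != scratch) {
      status = libxsmm_dnn_pooling_bind_scratch(handle, scratch);
      *scratch_out = scratch; /* caller frees after libxsmm_dnn_pooling_release_scratch() */
    } else {
      status = LIBXSMM_DNN_ERR_SCRATCH_NOT_ALLOCED;
    }
  }
  return status;
}
#endif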


LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_pooling_release_scratch(libxsmm_dnn_pooling* handle) {
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;

  if (0 != handle) {
    handle->scratch = 0;
  } else {
    status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }

  return status;
}


LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_pooling_bind_tensor(libxsmm_dnn_pooling* handle, const libxsmm_dnn_tensor* tensor, const libxsmm_dnn_tensor_type type) {
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;

  /* check for tensor type */
  if ( (type != LIBXSMM_DNN_REGULAR_INPUT)         && (type != LIBXSMM_DNN_GRADIENT_INPUT)         &&
       (type != LIBXSMM_DNN_REGULAR_OUTPUT)        && (type != LIBXSMM_DNN_GRADIENT_OUTPUT)        &&
       (type != LIBXSMM_DNN_POOLING_MASK)                                                             ) {
    status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
    return status;
  }

  if (handle != 0 && tensor != 0) {
    libxsmm_dnn_tensor_datalayout* handle_layout = libxsmm_dnn_pooling_create_tensor_datalayout(handle, type, &status);

    if ( libxsmm_dnn_compare_tensor_datalayout(handle_layout, tensor->layout, &status) == 0 ) {
      if ( type == LIBXSMM_DNN_REGULAR_INPUT ) {
        handle->reg_input = (libxsmm_dnn_tensor*)tensor;
      } else if ( type == LIBXSMM_DNN_GRADIENT_INPUT ) {
        handle->grad_input = (libxsmm_dnn_tensor*)tensor;
      } else if ( type == LIBXSMM_DNN_REGULAR_OUTPUT ) {
        handle->reg_output = (libxsmm_dnn_tensor*)tensor;
      } else if ( type == LIBXSMM_DNN_GRADIENT_OUTPUT ) {
        handle->grad_output = (libxsmm_dnn_tensor*)tensor;
      } else if ( type == LIBXSMM_DNN_POOLING_MASK ) {
        handle->mask = (libxsmm_dnn_tensor*)tensor;
      } else {
        /* cannot happen */
      }
    } else {
      status = LIBXSMM_DNN_ERR_MISMATCH_TENSOR;
    }

    libxsmm_dnn_destroy_tensor_datalayout( handle_layout );
  }
  else {
    status = LIBXSMM_DNN_ERR_INVALID_HANDLE_TENSOR;
  }

  return status;
}
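
/* A minimal usage sketch (disabled, hypothetical helper name): it combines the layout query above
 * with tensor binding. It assumes libxsmm_dnn_link_tensor() from the wider libxsmm DNN API, which
 * wraps a user buffer in a libxsmm_dnn_tensor matching a given layout; `data` is caller-owned. */
#if 0
static libxsmm_dnn_err_t example_bind_input(libxsmm_dnn_pooling* handle, void* data) {
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
  /* query the layout the handle expects for its regular input */
  libxsmm_dnn_tensor_datalayout* layout =
    libxsmm_dnn_pooling_create_tensor_datalayout(handle, LIBXSMM_DNN_REGULAR_INPUT, &status);
  if (LIBXSMM_DNN_SUCCESS == status) {
    /* wrap the user buffer and hand it to the pooling handle */
    libxsmm_dnn_tensor* tensor = libxsmm_dnn_link_tensor(layout, data, &status);
    if (LIBXSMM_DNN_SUCCESS == status) {
      status = libxsmm_dnn_pooling_bind_tensor(handle, tensor, LIBXSMM_DNN_REGULAR_INPUT);
    }
    libxsmm_dnn_destroy_tensor_datalayout(layout);
  }
  return status;
}
#endif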


LIBXSMM_API libxsmm_dnn_tensor* libxsmm_dnn_pooling_get_tensor(libxsmm_dnn_pooling* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status) {
  libxsmm_dnn_tensor* return_tensor = 0;

  *status = LIBXSMM_DNN_SUCCESS;

  /* check for tensor type */
  if ( (type != LIBXSMM_DNN_REGULAR_INPUT)         && (type != LIBXSMM_DNN_GRADIENT_INPUT)         &&
       (type != LIBXSMM_DNN_REGULAR_OUTPUT)        && (type != LIBXSMM_DNN_GRADIENT_OUTPUT)        &&
       (type != LIBXSMM_DNN_POOLING_MASK)                                                              ) {
    *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
    return return_tensor;
  }

  if (handle != 0) {
    if ( type == LIBXSMM_DNN_REGULAR_INPUT ) {
      return_tensor = handle->reg_input;
    } else if ( type == LIBXSMM_DNN_GRADIENT_INPUT ) {
      return_tensor = handle->grad_input;
    } else if ( type == LIBXSMM_DNN_REGULAR_OUTPUT ) {
      return_tensor = handle->reg_output;
    } else if ( type == LIBXSMM_DNN_GRADIENT_OUTPUT ) {
      return_tensor = handle->grad_output;
    } else if ( type == LIBXSMM_DNN_POOLING_MASK ) {
      return_tensor = handle->mask;
    } else {
      /* cannot happen */
    }
  } else {
    *status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }

  return return_tensor;
}


LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_pooling_release_tensor(libxsmm_dnn_pooling* handle, const libxsmm_dnn_tensor_type type) {
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;

  /* check for tensor type */
  if ( (type != LIBXSMM_DNN_REGULAR_INPUT)         && (type != LIBXSMM_DNN_GRADIENT_INPUT)         &&
       (type != LIBXSMM_DNN_REGULAR_OUTPUT)        && (type != LIBXSMM_DNN_GRADIENT_OUTPUT)        &&
       (type != LIBXSMM_DNN_POOLING_MASK)                                                              ) {
    status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
    return status;
  }

  if (handle != 0) {
    if ( type == LIBXSMM_DNN_REGULAR_INPUT ) {
      handle->reg_input = 0;
    } else if ( type == LIBXSMM_DNN_GRADIENT_INPUT ) {
      handle->grad_input = 0;
    } else if ( type == LIBXSMM_DNN_REGULAR_OUTPUT ) {
      handle->reg_output = 0;
    } else if ( type == LIBXSMM_DNN_GRADIENT_OUTPUT ) {
      handle->grad_output = 0;
    } else if ( type == LIBXSMM_DNN_POOLING_MASK ) {
      handle->mask = 0;
    } else {
      /* cannot happen */
    }
  } else {
    status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }

  return status;
}


LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_pooling_execute_st(libxsmm_dnn_pooling* handle, libxsmm_dnn_compute_kind kind,
  /*unsigned*/int start_thread, /*unsigned*/int tid) {
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;

  if (0 != handle) {
    switch (kind) {
      case LIBXSMM_DNN_COMPUTE_KIND_FWD: {
        switch (handle->desc.buffer_format) {
          case LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM: {
            status = libxsmm_dnn_pooling_st_fwd_custom( handle, start_thread, tid );
          } break;
          default: {
            status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL;
          }
        }
      } break;
      case LIBXSMM_DNN_COMPUTE_KIND_BWD: {
        switch (handle->desc.buffer_format) {
          case LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM: {
            status = libxsmm_dnn_pooling_st_bwd_custom( handle, start_thread, tid );
          } break;
          default: {
            status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL;
          }
        }
      } break;
      default: {
        status = LIBXSMM_DNN_ERR_INVALID_KIND;
      }
    }
  }
  else {
    status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
  }

  return status;
}
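
/* A minimal usage sketch (disabled, hypothetical helper name): the _st entry point above computes
 * one thread's share of the work, so it is typically driven from a threaded runtime. This sketch
 * assumes OpenMP with as many threads as desc.threads; error handling is reduced to the return. */
#if 0
#include <omp.h>
static libxsmm_dnn_err_t example_forward(libxsmm_dnn_pooling* handle) {
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
# pragma omp parallel
  {
    /* every thread calls into the same handle; thread 0 is the start thread */
    const int tid = omp_get_thread_num();
    const libxsmm_dnn_err_t thread_status =
      libxsmm_dnn_pooling_execute_st(handle, LIBXSMM_DNN_COMPUTE_KIND_FWD, 0, tid);
    if (LIBXSMM_DNN_SUCCESS != thread_status) {
#     pragma omp critical
      status = thread_status;
    }
  }
  return status;
}
#endif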