/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved.                      *
* This file is part of the LIBXSMM library.                                   *
*                                                                             *
* For information on the license, see the LICENSE file.                      *
* Further information: https://github.com/hfp/libxsmm/                        *
* SPDX-License-Identifier: BSD-3-Clause                                       *
******************************************************************************/
/* Alexander Heinecke, Sasikanth Avancha (Intel Corp.)
******************************************************************************/
#include "libxsmm_dnn_pooling_backward.h"
#include "libxsmm_dnn_pooling_forward.h"
#include "libxsmm_main.h"
14
15
libxsmm_dnn_create_pooling(libxsmm_dnn_pooling_desc pooling_desc,libxsmm_dnn_err_t * status)16 LIBXSMM_API libxsmm_dnn_pooling* libxsmm_dnn_create_pooling(libxsmm_dnn_pooling_desc pooling_desc, libxsmm_dnn_err_t* status) {
17 libxsmm_dnn_pooling* handle = 0;
18 int lpb;
19
20 /* init libxsmm */
21 LIBXSMM_INIT
22
23 if ( ((pooling_desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (pooling_desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16)) ||
24 ((pooling_desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (pooling_desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32)) ) {
25 handle = (libxsmm_dnn_pooling*)malloc(sizeof(libxsmm_dnn_pooling));
26
27 if (0 != handle) {
28 *status = LIBXSMM_DNN_SUCCESS;
29 /* zero entire content; not only safer but also sets data and code pointers to NULL */
30 memset(handle, 0, sizeof(*handle));
31 /* let's make the description persistent */
32 handle->desc = pooling_desc;
33 /* we need to compute the memory layout given the */
34 *status = libxsmm_dnn_get_feature_map_blocks( handle->desc.C, handle->desc.C,
35 &(handle->ifmblock), &(handle->ofmblock), &lpb,
36 handle->desc.datatype_in, handle->desc.datatype_out );
37 /* compute the outer blocks */
38 handle->blocksifm = handle->desc.C / handle->ifmblock;
39 handle->blocksofm = handle->desc.C / handle->ofmblock;
40 /* setting ofh and ofw */
41 handle->ofh = (handle->desc.H + 2 * handle->desc.pad_h - handle->desc.R) / handle->desc.u + 1;
42 handle->ofw = (handle->desc.W + 2 * handle->desc.pad_w - handle->desc.S) / handle->desc.v + 1;
43 /* create barrier */
44 handle->barrier = libxsmm_barrier_create(handle->desc.threads, 1);
45 /* calculate scratch size for local pooling copies of one feature map block per thread */
46 handle->scratch_size = (sizeof(float) * ( (size_t)handle->desc.H + (size_t)LIBXSMM_MAX(handle->desc.pad_h_in, handle->desc.pad_h_out)*2 )
47 * ( (size_t)handle->desc.W + (size_t)LIBXSMM_MAX(handle->desc.pad_w_in, handle->desc.pad_w_out)*2 )
48 * LIBXSMM_MAX( handle->ofmblock, handle->ifmblock )
49 * handle->desc.threads );
50 } else {
51 *status = LIBXSMM_DNN_ERR_CREATE_HANDLE;
52 }
53 } else {
54 *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
55 }
56
57 return handle;
58 }
59
60
libxsmm_dnn_destroy_pooling(const libxsmm_dnn_pooling * handle)61 LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_destroy_pooling(const libxsmm_dnn_pooling* handle) {
62 libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
63
64 if (0 != handle) {
65 /* Deallocate barrier */
66 if (handle->barrier != 0 ) { libxsmm_barrier_release((const libxsmm_barrier*)handle->barrier); }
67 /* deallocate handle structure */
68 free(/*remove constness*/(libxsmm_dnn_pooling*)handle);
69 } else {
70 status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
71 }
72
73 return status;
74 }
75
76
libxsmm_dnn_pooling_create_tensor_datalayout(const libxsmm_dnn_pooling * handle,const libxsmm_dnn_tensor_type type,libxsmm_dnn_err_t * status)77 LIBXSMM_API libxsmm_dnn_tensor_datalayout* libxsmm_dnn_pooling_create_tensor_datalayout(const libxsmm_dnn_pooling* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status) {
78 libxsmm_dnn_tensor_datalayout* layout;
79
80 *status = LIBXSMM_DNN_SUCCESS;
81 layout = 0;
82
83 if (handle != 0) {
84 layout = (libxsmm_dnn_tensor_datalayout*) malloc(sizeof(libxsmm_dnn_tensor_datalayout));
85
86 if (layout != 0) {
87 memset(layout, 0, sizeof(libxsmm_dnn_tensor_datalayout));
88 layout->format = handle->desc.buffer_format;
89
90 if ( (type == LIBXSMM_DNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_GRADIENT_INPUT) || (type == LIBXSMM_DNN_INPUT) ||
91 (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_GRADIENT_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) ||
92 (type == LIBXSMM_DNN_POOLING_MASK) ) {
93 if ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0) {
94 if ( ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) ) {
95 if ( type == LIBXSMM_DNN_POOLING_MASK ) {
96 layout->datatype = handle->desc.datatype_mask;
97 } else {
98 layout->datatype = LIBXSMM_DNN_DATATYPE_F32;
99 }
100 layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(5*sizeof(libxsmm_dnn_tensor_dimtype));
101 layout->dim_size = (unsigned int*) malloc(5*sizeof(unsigned int));
102
103 if (0 != layout->dim_type && 0 != layout->dim_size) {
104 layout->num_dims = 5;
105 layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
106 layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W;
107 layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H;
108 layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
109 layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
110 if ( (type == LIBXSMM_DNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_GRADIENT_INPUT) || (type == LIBXSMM_DNN_INPUT) ) {
111 layout->dim_size[0] = handle->ifmblock;
112 layout->dim_size[1] = handle->desc.W + (2*handle->desc.pad_w_in);
113 layout->dim_size[2] = handle->desc.H + (2*handle->desc.pad_h_in);
114 layout->dim_size[3] = handle->blocksifm;
115 layout->dim_size[4] = handle->desc.N;
116 } else if ( (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_GRADIENT_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) ) {
117 layout->dim_size[0] = handle->ofmblock;
118 layout->dim_size[1] = (handle->ofw) + (2*handle->desc.pad_w_out);
119 layout->dim_size[2] = (handle->ofh) + (2*handle->desc.pad_h_out);
120 layout->dim_size[3] = handle->blocksofm;
121 layout->dim_size[4] = handle->desc.N;
122 } else if ( (type == LIBXSMM_DNN_POOLING_MASK) ) {
123 layout->dim_size[0] = handle->ofmblock;
124 layout->dim_size[1] = handle->ofw;
125 layout->dim_size[2] = handle->ofh;
126 layout->dim_size[3] = handle->blocksofm;
127 layout->dim_size[4] = handle->desc.N;
128 } else { /* coverity[dead_error_begin] */
129 free(layout->dim_type);
130 free(layout->dim_size);
131 free(layout);
132 layout = 0; /* make sure a NULL is returned */
133 *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
134 }
135 } else {
136 free(layout);
137 layout = 0; /* make sure a NULL is returned */
138 *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS;
139 }
140 } else if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ) {
141 if ( type == LIBXSMM_DNN_POOLING_MASK ) {
142 layout->datatype = handle->desc.datatype_mask;
143 } else {
144 layout->datatype = LIBXSMM_DNN_DATATYPE_BF16;
145 }
146
147 layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(5*sizeof(libxsmm_dnn_tensor_dimtype));
148 layout->dim_size = (unsigned int*) malloc(5*sizeof(unsigned int));
149 if (0 != layout->dim_type && 0 != layout->dim_size) {
150 layout->num_dims = 5;
151 layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
152 layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W;
153 layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H;
154 layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
155 layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
156 if ( (type == LIBXSMM_DNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_GRADIENT_INPUT) || (type == LIBXSMM_DNN_INPUT) ) {
157 layout->dim_size[0] = handle->ifmblock;
158 layout->dim_size[1] = handle->desc.W + (2*handle->desc.pad_w_in);
159 layout->dim_size[2] = handle->desc.H + (2*handle->desc.pad_h_in);
160 layout->dim_size[3] = handle->blocksifm;
161 layout->dim_size[4] = handle->desc.N;
162 } else if ( (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_GRADIENT_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) ) {
163 layout->dim_size[0] = handle->ofmblock;
164 layout->dim_size[1] = (handle->ofw) + (2*handle->desc.pad_w_out);
165 layout->dim_size[2] = (handle->ofh) + (2*handle->desc.pad_h_out);
166 layout->dim_size[3] = handle->blocksofm;
167 layout->dim_size[4] = handle->desc.N;
168 } else if ( (type == LIBXSMM_DNN_POOLING_MASK) ) {
169 layout->dim_size[0] = handle->ofmblock;
170 layout->dim_size[1] = handle->ofw;
171 layout->dim_size[2] = handle->ofh;
172 layout->dim_size[3] = handle->blocksofm;
173 layout->dim_size[4] = handle->desc.N;
174 } else {
175 free(layout->dim_type);
176 free(layout->dim_size);
177 free(layout);
178 layout = 0; /* make sure a NULL is returned */
179 *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
180 }
181 } else {
182 free(layout);
183 layout = 0; /* make sure a NULL is returned */
184 *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS;
185 }
186 } else {
187 free(layout);
188 layout = 0; /* make sure a NULL is returned */
189 *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
190 }
191 } else if ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NHWC) > 0) {
192 if ( ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32)) ||
193 ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16)) ) {
194 if ( type == LIBXSMM_DNN_POOLING_MASK ) {
195 layout->datatype = handle->desc.datatype_mask;
196 } else {
197 layout->datatype = handle->desc.datatype_in;
198 }
199 layout->datatype = handle->desc.datatype_in;
200 layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(4*sizeof(libxsmm_dnn_tensor_dimtype));
201 layout->dim_size = (unsigned int*) malloc(4*sizeof(unsigned int));
202 if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
203 layout->num_dims = 4;
204 layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
205 layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W;
206 layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H;
207 layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
208 if ( (type == LIBXSMM_DNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_GRADIENT_INPUT) || (type == LIBXSMM_DNN_INPUT) ) {
209 layout->dim_size[0] = handle->desc.C;
210 layout->dim_size[1] = handle->desc.W + (2*handle->desc.pad_w_in);
211 layout->dim_size[2] = handle->desc.H + (2*handle->desc.pad_h_in);
212 layout->dim_size[3] = handle->desc.N;
213 } else if ( (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_GRADIENT_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) ) {
214 layout->dim_size[0] = handle->desc.C;
215 layout->dim_size[1] = (handle->ofw) + (2*handle->desc.pad_w_out);
216 layout->dim_size[2] = (handle->ofh) + (2*handle->desc.pad_h_out);
217 layout->dim_size[3] = handle->desc.N;
218 } else {
219 free(layout->dim_type);
220 free(layout->dim_size);
221 free(layout);
222 layout = 0; /* make sure a NULL is returned */
223 *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
224 }
225 }
226 } else {
227 free(layout);
228 layout = 0; /* make sure a NULL is returned */
229 *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
230 }
231 } else {
232 free(layout);
233 layout = 0; /* make sure a NULL is returned */
234 *status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL;
235 }
236 } else {
237 free(layout);
238 layout = 0; /* make sure a NULL is returned */
239 *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
240 }
241 } else {
242 *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT;
243 }
244 }
245 else {
246 *status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
247 }
248
249 return layout;
250 }
251
libxsmm_dnn_pooling_get_scratch_size(const libxsmm_dnn_pooling * handle,libxsmm_dnn_err_t * status)252 LIBXSMM_API size_t libxsmm_dnn_pooling_get_scratch_size(const libxsmm_dnn_pooling* handle, libxsmm_dnn_err_t* status) {
253 size_t l_scratch_size = 0;
254 *status = LIBXSMM_DNN_SUCCESS;
255
256 if (0 != handle) {
257 l_scratch_size = handle->scratch_size + 64; /* 64 byte extra in case the user code does not care about alignment */
258 } else {
259 *status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
260 }
261
262 return l_scratch_size;
263 }
264
265
libxsmm_dnn_pooling_bind_scratch(libxsmm_dnn_pooling * handle,const void * scratch)266 LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_pooling_bind_scratch(libxsmm_dnn_pooling* handle, const void* scratch) {
267 libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
268 uintptr_t address = (uintptr_t)scratch;
269 size_t offset = 0;
270
271 if (scratch == 0) {
272 status = LIBXSMM_DNN_ERR_SCRATCH_NOT_ALLOCED;
273 return status;
274 }
275
276 if (0 != handle) {
277 /* align the internal scratch buffer if needed */
278 if (address % 64 == 0) {
279 handle->scratch = (void*)address;
280 } else {
281 offset = (64 - address % 64);
282 handle->scratch = (void*)(address+offset);
283 }
284 } else {
285 status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
286 }
287
288 return status;
289 }
290
291
libxsmm_dnn_pooling_release_scratch(libxsmm_dnn_pooling * handle)292 LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_pooling_release_scratch(libxsmm_dnn_pooling* handle) {
293 libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
294
295 if (0 != handle) {
296 handle->scratch = 0;
297 } else {
298 status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
299 }
300
301 return status;
302 }
303
304
libxsmm_dnn_pooling_bind_tensor(libxsmm_dnn_pooling * handle,const libxsmm_dnn_tensor * tensor,const libxsmm_dnn_tensor_type type)305 LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_pooling_bind_tensor(libxsmm_dnn_pooling* handle, const libxsmm_dnn_tensor* tensor, const libxsmm_dnn_tensor_type type) {
306 libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
307
308 /* check for tensor type */
309 if ( (type != LIBXSMM_DNN_REGULAR_INPUT) && (type != LIBXSMM_DNN_GRADIENT_INPUT) &&
310 (type != LIBXSMM_DNN_REGULAR_OUTPUT) && (type != LIBXSMM_DNN_GRADIENT_OUTPUT) &&
311 (type != LIBXSMM_DNN_POOLING_MASK) ) {
312 status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
313 return status;
314 }
315
316 if (handle != 0 && tensor != 0) {
317 libxsmm_dnn_tensor_datalayout* handle_layout = libxsmm_dnn_pooling_create_tensor_datalayout(handle, type, &status);
318
319 if ( libxsmm_dnn_compare_tensor_datalayout(handle_layout, tensor->layout, &status) == 0 ) {
320 if ( type == LIBXSMM_DNN_REGULAR_INPUT ) {
321 handle->reg_input = (libxsmm_dnn_tensor*)tensor;
322 } else if ( type == LIBXSMM_DNN_GRADIENT_INPUT ) {
323 handle->grad_input = (libxsmm_dnn_tensor*)tensor;
324 } else if ( type == LIBXSMM_DNN_REGULAR_OUTPUT ) {
325 handle->reg_output = (libxsmm_dnn_tensor*)tensor;
326 } else if ( type == LIBXSMM_DNN_GRADIENT_OUTPUT ) {
327 handle->grad_output = (libxsmm_dnn_tensor*)tensor;
328 } else if ( type == LIBXSMM_DNN_POOLING_MASK ) {
329 handle->mask = (libxsmm_dnn_tensor*)tensor;
330 } else {
331 /* cannot happen */
332 }
333 } else {
334 status = LIBXSMM_DNN_ERR_MISMATCH_TENSOR;
335 }
336
337 libxsmm_dnn_destroy_tensor_datalayout( handle_layout );
338 }
339 else {
340 status = LIBXSMM_DNN_ERR_INVALID_HANDLE_TENSOR;
341 }
342
343 return status;
344 }
345
346
libxsmm_dnn_pooling_get_tensor(libxsmm_dnn_pooling * handle,const libxsmm_dnn_tensor_type type,libxsmm_dnn_err_t * status)347 LIBXSMM_API libxsmm_dnn_tensor* libxsmm_dnn_pooling_get_tensor(libxsmm_dnn_pooling* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status) {
348 libxsmm_dnn_tensor* return_tensor = 0;
349
350 *status = LIBXSMM_DNN_SUCCESS;
351
352 /* check for tensor type */
353 if ( (type != LIBXSMM_DNN_REGULAR_INPUT) && (type != LIBXSMM_DNN_GRADIENT_INPUT) &&
354 (type != LIBXSMM_DNN_REGULAR_OUTPUT) && (type != LIBXSMM_DNN_GRADIENT_OUTPUT) &&
355 (type != LIBXSMM_DNN_POOLING_MASK) ) {
356 *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
357 return return_tensor;
358 }
359
360 if (handle != 0) {
361 if ( type == LIBXSMM_DNN_REGULAR_INPUT ) {
362 return_tensor = handle->reg_input;
363 } else if ( type == LIBXSMM_DNN_GRADIENT_INPUT ) {
364 return_tensor = handle->grad_input;
365 } else if ( type == LIBXSMM_DNN_REGULAR_OUTPUT ) {
366 return_tensor = handle->reg_output;
367 } else if ( type == LIBXSMM_DNN_GRADIENT_OUTPUT ) {
368 return_tensor = handle->grad_output;
369 } else if ( type == LIBXSMM_DNN_POOLING_MASK ) {
370 return_tensor = handle->mask;
371 } else {
372 /* cannot happen */
373 }
374 } else {
375 *status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
376 }
377
378 return return_tensor;
379 }
380
381
libxsmm_dnn_pooling_release_tensor(libxsmm_dnn_pooling * handle,const libxsmm_dnn_tensor_type type)382 LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_pooling_release_tensor(libxsmm_dnn_pooling* handle, const libxsmm_dnn_tensor_type type) {
383 libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
384
385 /* check for tensor type */
386 if ( (type != LIBXSMM_DNN_REGULAR_INPUT) && (type != LIBXSMM_DNN_GRADIENT_INPUT) &&
387 (type != LIBXSMM_DNN_REGULAR_OUTPUT) && (type != LIBXSMM_DNN_GRADIENT_OUTPUT) &&
388 (type != LIBXSMM_DNN_POOLING_MASK) ) {
389 status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
390 return status;
391 }
392
393 if (handle != 0) {
394 if ( type == LIBXSMM_DNN_REGULAR_INPUT ) {
395 handle->reg_input = 0;
396 } else if ( type == LIBXSMM_DNN_GRADIENT_INPUT ) {
397 handle->grad_input = 0;
398 } else if ( type == LIBXSMM_DNN_REGULAR_OUTPUT ) {
399 handle->reg_output = 0;
400 } else if ( type == LIBXSMM_DNN_GRADIENT_OUTPUT ) {
401 handle->grad_output = 0;
402 } else if ( type == LIBXSMM_DNN_POOLING_MASK ) {
403 handle->mask = 0;
404 } else {
405 /* cannot happen */
406 }
407 } else {
408 status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
409 }
410
411 return status;
412 }
413
414
libxsmm_dnn_pooling_execute_st(libxsmm_dnn_pooling * handle,libxsmm_dnn_compute_kind kind,int start_thread,int tid)415 LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_pooling_execute_st(libxsmm_dnn_pooling* handle, libxsmm_dnn_compute_kind kind,
416 /*unsigned*/int start_thread, /*unsigned*/int tid) {
417 libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
418
419 if (0 != handle) {
420 switch (kind) {
421 case LIBXSMM_DNN_COMPUTE_KIND_FWD: {
422 switch (handle->desc.buffer_format) {
423 case LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM: {
424 status = libxsmm_dnn_pooling_st_fwd_custom( handle, start_thread, tid );
425 } break;
426 default: {
427 status = LIBXSMM_DNN_ERR_INVALID_FORMAT_FUSEDBN;
428 }
429 }
430 } break;
431 case LIBXSMM_DNN_COMPUTE_KIND_BWD: {
432 switch (handle->desc.buffer_format) {
433 case LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM: {
434 status = libxsmm_dnn_pooling_st_bwd_custom( handle, start_thread, tid );
435 } break;
436 default: {
437 status = LIBXSMM_DNN_ERR_INVALID_FORMAT_FUSEDBN;
438 }
439 }
440 } break;
441 default: {
442 status = LIBXSMM_DNN_ERR_INVALID_KIND;
443 }
444 }
445 }
446 else {
447 status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
448 }
449
450 return status;
451 }
452
453