1 /******************************************************************************
2 * Copyright (c) Intel Corporation - All rights reserved.                      *
3 * This file is part of the LIBXSMM library.                                   *
4 *                                                                             *
5 * For information on the license, see the LICENSE file.                       *
6 * Further information: https://github.com/hfp/libxsmm/                        *
7 * SPDX-License-Identifier: BSD-3-Clause                                       *
8 ******************************************************************************/
9 /* Alexander Heinecke, Sasikanth Avancha (Intel Corp.)
10 ******************************************************************************/
11 #ifndef LIBXSMM_DNN_POOLING_H
12 #define LIBXSMM_DNN_POOLING_H
13 
14 #include "libxsmm_dnn.h"
15 #include "libxsmm_dnn_tensor.h"
16 
/** Opaque handle that represents a LIBXSMM pooling layer; its contents are not visible in this header. */
LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_dnn_pooling libxsmm_dnn_pooling;
19 
/**
 * Reduction applied over each pooling window; selected through
 * libxsmm_dnn_pooling_desc::pooling_type. Enumerators carry explicit
 * numeric values (MAX is 1, AVG is 2).
 */
typedef enum libxsmm_dnn_pooling_type {
  LIBXSMM_DNN_POOLING_MAX = 1,                          /* max pooling */
  LIBXSMM_DNN_POOLING_AVG = (LIBXSMM_DNN_POOLING_MAX + 1) /* average pooling */
} libxsmm_dnn_pooling_type;
24 
/**
 * Configuration of a pooling layer. It is passed by value to
 * libxsmm_dnn_create_pooling. Field order and types are part of the public
 * interface, so do not reorder or retype them.
 */
LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_dnn_pooling_desc {
  int N;                                     /* number of images in the mini-batch */
  int C;                                     /* number of input feature maps */
  int H;                                     /* height of the input image */
  int W;                                     /* width of the input image */
  int R;                                     /* kernel (pooling window) height */
  int S;                                     /* kernel (pooling window) width */
  int u;                                     /* vertical stride */
  int v;                                     /* horizontal stride */
  int pad_h;                                 /* height of the logical padding of the input buffer */
  int pad_w;                                 /* width of the logical padding of the input buffer */
  int pad_h_in;                              /* height of the physical zero-padding stored in the input buffer */
  int pad_w_in;                              /* width of the physical zero-padding stored in the input buffer */
  int pad_h_out;                             /* height of the physical zero-padding stored in the output buffer */
  int pad_w_out;                             /* width of the physical zero-padding stored in the output buffer */
  int threads;                               /* number of threads used */
  libxsmm_dnn_datatype datatype_in;          /* datatype used for all input-related buffers */
  libxsmm_dnn_datatype datatype_out;         /* datatype used for all output-related buffers */
  libxsmm_dnn_datatype datatype_mask;        /* datatype used for the mask buffer(s) */
  libxsmm_dnn_tensor_format buffer_format;   /* tensor format of the activation buffers */
  libxsmm_dnn_pooling_type pooling_type;     /* type of pooling operation (max or average) */
} libxsmm_dnn_pooling_desc;
47 
/**
 * Creates a pooling handle from a descriptor.
 * @param pooling_desc  layer configuration (taken by value).
 * @param status        out-parameter for the error code of this call.
 *                      NOTE(review): confirm whether NULL is accepted.
 * @return the new handle; presumably NULL on failure, so check *status.
 */
LIBXSMM_API libxsmm_dnn_pooling* libxsmm_dnn_create_pooling(libxsmm_dnn_pooling_desc pooling_desc, libxsmm_dnn_err_t* status);
/**
 * Destroys a handle obtained from libxsmm_dnn_create_pooling.
 * NOTE(review): the handle is declared const although it is destroyed;
 * presumably the implementation casts the qualifier away.
 * @return error code of the operation.
 */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_destroy_pooling(const libxsmm_dnn_pooling* handle);
50 
/**
 * Creates a datalayout describing the tensor of role `type` for this handle.
 * @param handle  pooling handle whose configuration defines the layout.
 * @param type    tensor role to describe.
 * @param status  out-parameter for the error code of this call.
 * @return the datalayout. Presumably it is owned by the caller; confirm which
 *         routine in libxsmm_dnn_tensor.h frees it.
 */
LIBXSMM_API libxsmm_dnn_tensor_datalayout* libxsmm_dnn_pooling_create_tensor_datalayout(const libxsmm_dnn_pooling* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status);
52 
/**
 * Queries the amount of scratch memory the handle needs (presumably in bytes).
 * @param status  out-parameter for the error code of this call.
 */
LIBXSMM_API size_t libxsmm_dnn_pooling_get_scratch_size(const libxsmm_dnn_pooling* handle, libxsmm_dnn_err_t* status);
/**
 * Attaches caller-provided scratch memory to the handle. The buffer is
 * presumably expected to be at least libxsmm_dnn_pooling_get_scratch_size()
 * large.
 * NOTE(review): declared `const void*` although scratch is presumably written
 * by the library.
 */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_pooling_bind_scratch(libxsmm_dnn_pooling* handle, const void* scratch);
/** Detaches the bound scratch memory from the handle; presumably does not free it. Confirm. */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_pooling_release_scratch(libxsmm_dnn_pooling* handle);
56 
/** Attaches `tensor` to the handle in the role given by `type`. */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_pooling_bind_tensor(libxsmm_dnn_pooling* handle, const libxsmm_dnn_tensor* tensor, const libxsmm_dnn_tensor_type type);
/**
 * Returns the tensor currently bound in role `type`.
 * @param status  out-parameter for the error code of this call.
 */
LIBXSMM_API libxsmm_dnn_tensor* libxsmm_dnn_pooling_get_tensor(libxsmm_dnn_pooling* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status);
/** Detaches the tensor bound in role `type`; presumably the tensor itself is not destroyed. Confirm. */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_pooling_release_tensor(libxsmm_dnn_pooling* handle, const libxsmm_dnn_tensor_type type);
60 
/**
 * Runs the pooling computation selected by `kind` on the bound tensors.
 * @param start_thread  presumably the id of the first thread of the team.
 *                      Declared int, but treated as unsigned per the inline
 *                      marker.
 * @param tid           presumably the calling thread's id (also
 *                      "unsigned" per the inline marker).
 * NOTE(review): the `_st` suffix suggests it is called once per thread
 * (e.g. inside a parallel region), with the work partitioned by `tid`.
 * Confirm against the implementation.
 * @return error code of the operation.
 */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_pooling_execute_st(libxsmm_dnn_pooling* handle, libxsmm_dnn_compute_kind kind,
  /*unsigned*/int start_thread, /*unsigned*/int tid);
63 
64 #endif /*LIBXSMM_DNN_POOLING_H*/
65 
66