1 /******************************************************************************
2 * Copyright (c) Intel Corporation - All rights reserved.                      *
3 * This file is part of the LIBXSMM library.                                   *
4 *                                                                             *
5 * For information on the license, see the LICENSE file.                       *
6 * Further information: https://github.com/hfp/libxsmm/                        *
7 * SPDX-License-Identifier: BSD-3-Clause                                       *
8 ******************************************************************************/
9 /* Alexander Heinecke (Intel Corp.)
10 ******************************************************************************/
11 #include "libxsmm_dnn_pooling_forward.h"
12 #include "libxsmm_main.h"
13 
14 
/* AVX-512 forward-pooling kernels, one per (input/output datatype, channel block size)
 * combination. libxsmm_dnn_pooling_st_fwd_custom below picks one of them based on
 * handle->ofmblock (16/32/64) and the descriptor's input/output datatypes. */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_custom_f32_f32_c16(libxsmm_dnn_pooling* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_custom_f32_f32_c32(libxsmm_dnn_pooling* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_custom_f32_f32_c64(libxsmm_dnn_pooling* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_custom_bf16_bf16_c16(libxsmm_dnn_pooling* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_custom_bf16_bf16_c32(libxsmm_dnn_pooling* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_custom_bf16_bf16_c64(libxsmm_dnn_pooling* handle, int start_thread, int tid);
21 
22 
/* AVX-512 forward-pooling kernel for FP32 input and FP32 output with a channel block
 * of 16 (the dispatcher selects it when handle->ofmblock == 16).
 *
 * The actual loop nest lives in the included template. The template is specialized
 * through the element_*_type typedefs and the LIBXSMM_DNN_POOLING_FWD_MAX /
 * LIBXSMM_DNN_POOLING_FWD_AVG macros defined around each #include. handle,
 * start_thread and tid are consumed by that template.
 *
 * Returns LIBXSMM_DNN_SUCCESS, LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING for pooling types
 * other than max/avg, or LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH when built without AVX-512
 * intrinsics. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_custom_f32_f32_c16(libxsmm_dnn_pooling* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  /* element types the included template is instantiated with */
  typedef float element_input_type;
  typedef float element_output_type;

  if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX ) {
# define LIBXSMM_DNN_POOLING_FWD_MAX
    /* mask element type; declared only for the max-pooling instantiation */
    typedef int element_mask_type;
# include "template/libxsmm_dnn_pooling_st_fwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_POOLING_FWD_MAX
  } else if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_AVG ) {
# define LIBXSMM_DNN_POOLING_FWD_AVG
# include "template/libxsmm_dnn_pooling_st_fwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_POOLING_FWD_AVG
  } else {
    status = LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING;
  }
#else /* should not happen: the dispatcher only calls this when LIBXSMM_INTRINSICS_AVX512 is defined */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
49 
50 
/* AVX-512 forward-pooling kernel for FP32 input and FP32 output with a channel block
 * of 32 (the dispatcher selects it when handle->ofmblock == 32).
 *
 * The actual loop nest lives in the included template. The template is specialized
 * through the element_*_type typedefs and the LIBXSMM_DNN_POOLING_FWD_MAX /
 * LIBXSMM_DNN_POOLING_FWD_AVG macros defined around each #include. handle,
 * start_thread and tid are consumed by that template.
 *
 * Returns LIBXSMM_DNN_SUCCESS, LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING for pooling types
 * other than max/avg, or LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH when built without AVX-512
 * intrinsics. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_custom_f32_f32_c32(libxsmm_dnn_pooling* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  /* element types the included template is instantiated with */
  typedef float element_input_type;
  typedef float element_output_type;

  if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX ) {
# define LIBXSMM_DNN_POOLING_FWD_MAX
    /* mask element type; declared only for the max-pooling instantiation */
    typedef int element_mask_type;
# include "template/libxsmm_dnn_pooling_st_fwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_POOLING_FWD_MAX
  } else if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_AVG ) {
# define LIBXSMM_DNN_POOLING_FWD_AVG
# include "template/libxsmm_dnn_pooling_st_fwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_POOLING_FWD_AVG
  } else {
    status = LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING;
  }
#else /* should not happen: the dispatcher only calls this when LIBXSMM_INTRINSICS_AVX512 is defined */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
77 
78 
/* AVX-512 forward-pooling kernel for FP32 input and FP32 output with a channel block
 * of 64 (the dispatcher selects it when handle->ofmblock == 64).
 *
 * The actual loop nest lives in the included template. The template is specialized
 * through the element_*_type typedefs and the LIBXSMM_DNN_POOLING_FWD_MAX /
 * LIBXSMM_DNN_POOLING_FWD_AVG macros defined around each #include. handle,
 * start_thread and tid are consumed by that template.
 *
 * Returns LIBXSMM_DNN_SUCCESS, LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING for pooling types
 * other than max/avg, or LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH when built without AVX-512
 * intrinsics. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_custom_f32_f32_c64(libxsmm_dnn_pooling* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  /* element types the included template is instantiated with */
  typedef float element_input_type;
  typedef float element_output_type;

  if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX ) {
# define LIBXSMM_DNN_POOLING_FWD_MAX
    /* mask element type; declared only for the max-pooling instantiation */
    typedef int element_mask_type;
# include "template/libxsmm_dnn_pooling_st_fwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_POOLING_FWD_MAX
  } else if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_AVG ) {
# define LIBXSMM_DNN_POOLING_FWD_AVG
# include "template/libxsmm_dnn_pooling_st_fwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_POOLING_FWD_AVG
  } else {
    status = LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING;
  }
#else /* should not happen: the dispatcher only calls this when LIBXSMM_INTRINSICS_AVX512 is defined */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
105 
106 
/* AVX-512 forward-pooling kernel for BF16 input and BF16 output with a channel block
 * of 16 (the dispatcher selects it when handle->ofmblock == 16).
 *
 * This function uses the same template as the FP32 kernel. The template is
 * specialized through the element_*_type typedefs and the
 * LIBXSMM_DNN_POOLING_FWD_BF16 / _MAX / _AVG macros. handle, start_thread and tid
 * are consumed by that template.
 *
 * Returns LIBXSMM_DNN_SUCCESS, LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING for pooling types
 * other than max/avg, or LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH when built without AVX-512
 * intrinsics. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_custom_bf16_bf16_c16(libxsmm_dnn_pooling* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  /* element types the included template is instantiated with */
  typedef libxsmm_bfloat16 element_input_type;
  typedef libxsmm_bfloat16 element_output_type;

  /* presumably selects the BF16 code paths of the shared f32/bf16 template; undefined again below */
# define LIBXSMM_DNN_POOLING_FWD_BF16
  if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX ) {
# define LIBXSMM_DNN_POOLING_FWD_MAX
    /* mask element type; declared only for the max-pooling instantiation */
    typedef int element_mask_type;
# include "template/libxsmm_dnn_pooling_st_fwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_POOLING_FWD_MAX
  } else if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_AVG ) {
# define LIBXSMM_DNN_POOLING_FWD_AVG
# include "template/libxsmm_dnn_pooling_st_fwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_POOLING_FWD_AVG
  } else {
    status = LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING;
  }
# undef LIBXSMM_DNN_POOLING_FWD_BF16
#else /* should not happen: the dispatcher only calls this when LIBXSMM_INTRINSICS_AVX512 is defined */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
135 
136 
/* AVX-512 forward-pooling kernel for BF16 input and BF16 output with a channel block
 * of 32 (the dispatcher selects it when handle->ofmblock == 32).
 *
 * This function uses the same template as the FP32 kernel. The template is
 * specialized through the element_*_type typedefs and the
 * LIBXSMM_DNN_POOLING_FWD_BF16 / _MAX / _AVG macros. handle, start_thread and tid
 * are consumed by that template.
 *
 * Returns LIBXSMM_DNN_SUCCESS, LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING for pooling types
 * other than max/avg, or LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH when built without AVX-512
 * intrinsics. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_custom_bf16_bf16_c32(libxsmm_dnn_pooling* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  /* element types the included template is instantiated with */
  typedef libxsmm_bfloat16 element_input_type;
  typedef libxsmm_bfloat16 element_output_type;

  /* presumably selects the BF16 code paths of the shared f32/bf16 template; undefined again below */
# define LIBXSMM_DNN_POOLING_FWD_BF16
  if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX ) {
# define LIBXSMM_DNN_POOLING_FWD_MAX
    /* mask element type; declared only for the max-pooling instantiation */
    typedef int element_mask_type;
# include "template/libxsmm_dnn_pooling_st_fwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_POOLING_FWD_MAX
  } else if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_AVG ) {
# define LIBXSMM_DNN_POOLING_FWD_AVG
# include "template/libxsmm_dnn_pooling_st_fwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_POOLING_FWD_AVG
  } else {
    status = LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING;
  }
# undef LIBXSMM_DNN_POOLING_FWD_BF16
#else /* should not happen: the dispatcher only calls this when LIBXSMM_INTRINSICS_AVX512 is defined */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
165 
166 
/* AVX-512 forward-pooling kernel for BF16 input and BF16 output with a channel block
 * of 64 (the dispatcher selects it when handle->ofmblock == 64).
 *
 * This function uses the same template as the FP32 kernel. The template is
 * specialized through the element_*_type typedefs and the
 * LIBXSMM_DNN_POOLING_FWD_BF16 / _MAX / _AVG macros. handle, start_thread and tid
 * are consumed by that template.
 *
 * Returns LIBXSMM_DNN_SUCCESS, LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING for pooling types
 * other than max/avg, or LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH when built without AVX-512
 * intrinsics. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_custom_bf16_bf16_c64(libxsmm_dnn_pooling* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  /* element types the included template is instantiated with */
  typedef libxsmm_bfloat16 element_input_type;
  typedef libxsmm_bfloat16 element_output_type;

  /* presumably selects the BF16 code paths of the shared f32/bf16 template; undefined again below */
# define LIBXSMM_DNN_POOLING_FWD_BF16
  if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX ) {
# define LIBXSMM_DNN_POOLING_FWD_MAX
    /* mask element type; declared only for the max-pooling instantiation */
    typedef int element_mask_type;
# include "template/libxsmm_dnn_pooling_st_fwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_POOLING_FWD_MAX
  } else if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_AVG ) {
# define LIBXSMM_DNN_POOLING_FWD_AVG
# include "template/libxsmm_dnn_pooling_st_fwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_POOLING_FWD_AVG
  } else {
    status = LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING;
  }
# undef LIBXSMM_DNN_POOLING_FWD_BF16
#else /* should not happen: the dispatcher only calls this when LIBXSMM_INTRINSICS_AVX512 is defined */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
195 
196 
libxsmm_dnn_pooling_st_fwd_custom(libxsmm_dnn_pooling * handle,int start_thread,int tid)197 LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_custom(libxsmm_dnn_pooling* handle, int start_thread, int tid)
198 {
199   libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
200 
201   /* check if we have input, output and mask */
202   if ( handle->reg_input == 0 || handle->reg_output == 0 ||
203        ( (handle->mask == 0) && (handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX) ) ) {
204     status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
205     return status;
206   }
207 
208   /* check if we are on an AVX512 platform */
209 #if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
210   if ( ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) &&
211        (handle->ofmblock == 16) ) {
212     if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
213       LIBXSMM_ASSERT(NULL != handle->mask);
214       status = libxsmm_dnn_pooling_st_fwd_custom_f32_f32_c16( handle, start_thread, tid);
215     } else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) {
216       LIBXSMM_ASSERT(NULL != handle->mask);
217       status = libxsmm_dnn_pooling_st_fwd_custom_bf16_bf16_c16( handle, start_thread, tid);
218     } else {
219       status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
220       return status;
221     }
222   } else if ( ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) &&
223        (handle->ofmblock == 32) ) {
224     if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
225       LIBXSMM_ASSERT(NULL != handle->mask);
226       status = libxsmm_dnn_pooling_st_fwd_custom_f32_f32_c32( handle, start_thread, tid);
227     } else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) {
228       LIBXSMM_ASSERT(NULL != handle->mask);
229       status = libxsmm_dnn_pooling_st_fwd_custom_bf16_bf16_c32( handle, start_thread, tid);
230     } else {
231       status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
232       return status;
233     }
234   } else if ( ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) &&
235        (handle->ofmblock == 64) ) {
236     if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
237       LIBXSMM_ASSERT(NULL != handle->mask);
238       status = libxsmm_dnn_pooling_st_fwd_custom_f32_f32_c64( handle, start_thread, tid);
239     } else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) {
240       LIBXSMM_ASSERT(NULL != handle->mask);
241       status = libxsmm_dnn_pooling_st_fwd_custom_bf16_bf16_c64( handle, start_thread, tid);
242     } else {
243       status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
244       return status;
245     }
246   } else
247 #endif
248   {
249     if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
250       typedef float element_input_type;
251       typedef float element_output_type;
252 
253       if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX ) {
254 # define LIBXSMM_DNN_POOLING_FWD_MAX
255         typedef int element_mask_type;
256 # include "template/libxsmm_dnn_pooling_st_fwd_custom_generic.tpl.c"
257 # undef LIBXSMM_DNN_POOLING_FWD_MAX
258       } else if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_AVG ) {
259 # define LIBXSMM_DNN_POOLING_FWD_AVG
260 # include "template/libxsmm_dnn_pooling_st_fwd_custom_generic.tpl.c"
261 # undef LIBXSMM_DNN_POOLING_FWD_AVG
262       } else {
263         status = LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING;
264       }
265     } else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) {
266       typedef libxsmm_bfloat16 element_input_type;
267       typedef libxsmm_bfloat16 element_output_type;
268 
269 # define LIBXSMM_DNN_POOLING_FWD_BF16
270       if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX ) {
271 # define LIBXSMM_DNN_POOLING_FWD_MAX
272         typedef int element_mask_type;
273 # include "template/libxsmm_dnn_pooling_st_fwd_custom_generic.tpl.c"
274 # undef LIBXSMM_DNN_POOLING_FWD_MAX
275       } else if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_AVG ) {
276 # define LIBXSMM_DNN_POOLING_FWD_AVG
277 # include "template/libxsmm_dnn_pooling_st_fwd_custom_generic.tpl.c"
278 # undef LIBXSMM_DNN_POOLING_FWD_AVG
279       } else {
280         status = LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING;
281       }
282 # undef LIBXSMM_DNN_POOLING_FWD_BF16
283     } else {
284       status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
285       return status;
286     }
287   }
288 
289   return status;
290 }
291 
292 
libxsmm_dnn_pooling_st_fwd_nhwc(libxsmm_dnn_pooling * handle,int start_thread,int tid)293 LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_nhwc(libxsmm_dnn_pooling* handle, int start_thread, int tid)
294 {
295   libxsmm_dnn_err_t status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
296   LIBXSMM_UNUSED( handle );
297   LIBXSMM_UNUSED( start_thread );
298   LIBXSMM_UNUSED( tid );
299   return status;
300 }
301 
302