1 /******************************************************************************
2 * Copyright (c) Intel Corporation - All rights reserved. *
3 * This file is part of the LIBXSMM library. *
4 * *
5 * For information on the license, see the LICENSE file. *
6 * Further information: https://github.com/hfp/libxsmm/ *
7 * SPDX-License-Identifier: BSD-3-Clause *
8 ******************************************************************************/
9 /* Alexander Heinecke (Intel Corp.)
10 ******************************************************************************/
11 #include "libxsmm_dnn_pooling_forward.h"
12 #include "libxsmm_main.h"
13
14
15 LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_custom_f32_f32_c16(libxsmm_dnn_pooling* handle, int start_thread, int tid);
16 LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_custom_f32_f32_c32(libxsmm_dnn_pooling* handle, int start_thread, int tid);
17 LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_custom_f32_f32_c64(libxsmm_dnn_pooling* handle, int start_thread, int tid);
18 LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_custom_bf16_bf16_c16(libxsmm_dnn_pooling* handle, int start_thread, int tid);
19 LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_custom_bf16_bf16_c32(libxsmm_dnn_pooling* handle, int start_thread, int tid);
20 LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_custom_bf16_bf16_c64(libxsmm_dnn_pooling* handle, int start_thread, int tid);
21
22
LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)23 LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
24 libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_custom_f32_f32_c16(libxsmm_dnn_pooling* handle, int start_thread, int tid)
25 {
26 libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
27 #if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
28 typedef float element_input_type;
29 typedef float element_output_type;
30
31 if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX ) {
32 # define LIBXSMM_DNN_POOLING_FWD_MAX
33 typedef int element_mask_type;
34 # include "template/libxsmm_dnn_pooling_st_fwd_custom_f32_bf16_c16_avx512.tpl.c"
35 # undef LIBXSMM_DNN_POOLING_FWD_MAX
36 } else if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_AVG ) {
37 # define LIBXSMM_DNN_POOLING_FWD_AVG
38 # include "template/libxsmm_dnn_pooling_st_fwd_custom_f32_bf16_c16_avx512.tpl.c"
39 # undef LIBXSMM_DNN_POOLING_FWD_AVG
40 } else {
41 status = LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING;
42 }
43 #else /* should not happen */
44 LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
45 status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
46 #endif
47 return status;
48 }
49
50
LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)51 LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
52 libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_custom_f32_f32_c32(libxsmm_dnn_pooling* handle, int start_thread, int tid)
53 {
54 libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
55 #if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
56 typedef float element_input_type;
57 typedef float element_output_type;
58
59 if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX ) {
60 # define LIBXSMM_DNN_POOLING_FWD_MAX
61 typedef int element_mask_type;
62 # include "template/libxsmm_dnn_pooling_st_fwd_custom_f32_bf16_c32_avx512.tpl.c"
63 # undef LIBXSMM_DNN_POOLING_FWD_MAX
64 } else if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_AVG ) {
65 # define LIBXSMM_DNN_POOLING_FWD_AVG
66 # include "template/libxsmm_dnn_pooling_st_fwd_custom_f32_bf16_c32_avx512.tpl.c"
67 # undef LIBXSMM_DNN_POOLING_FWD_AVG
68 } else {
69 status = LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING;
70 }
71 #else /* should not happen */
72 LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
73 status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
74 #endif
75 return status;
76 }
77
78
LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)79 LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
80 libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_custom_f32_f32_c64(libxsmm_dnn_pooling* handle, int start_thread, int tid)
81 {
82 libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
83 #if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
84 typedef float element_input_type;
85 typedef float element_output_type;
86
87 if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX ) {
88 # define LIBXSMM_DNN_POOLING_FWD_MAX
89 typedef int element_mask_type;
90 # include "template/libxsmm_dnn_pooling_st_fwd_custom_f32_bf16_c64_avx512.tpl.c"
91 # undef LIBXSMM_DNN_POOLING_FWD_MAX
92 } else if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_AVG ) {
93 # define LIBXSMM_DNN_POOLING_FWD_AVG
94 # include "template/libxsmm_dnn_pooling_st_fwd_custom_f32_bf16_c64_avx512.tpl.c"
95 # undef LIBXSMM_DNN_POOLING_FWD_AVG
96 } else {
97 status = LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING;
98 }
99 #else /* should not happen */
100 LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
101 status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
102 #endif
103 return status;
104 }
105
106
LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)107 LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
108 libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_custom_bf16_bf16_c16(libxsmm_dnn_pooling* handle, int start_thread, int tid)
109 {
110 libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
111 #if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
112 typedef libxsmm_bfloat16 element_input_type;
113 typedef libxsmm_bfloat16 element_output_type;
114
115 # define LIBXSMM_DNN_POOLING_FWD_BF16
116 if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX ) {
117 # define LIBXSMM_DNN_POOLING_FWD_MAX
118 typedef int element_mask_type;
119 # include "template/libxsmm_dnn_pooling_st_fwd_custom_f32_bf16_c16_avx512.tpl.c"
120 # undef LIBXSMM_DNN_POOLING_FWD_MAX
121 } else if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_AVG ) {
122 # define LIBXSMM_DNN_POOLING_FWD_AVG
123 # include "template/libxsmm_dnn_pooling_st_fwd_custom_f32_bf16_c16_avx512.tpl.c"
124 # undef LIBXSMM_DNN_POOLING_FWD_AVG
125 } else {
126 status = LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING;
127 }
128 # undef LIBXSMM_DNN_POOLING_FWD_BF16
129 #else /* should not happen */
130 LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
131 status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
132 #endif
133 return status;
134 }
135
136
LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)137 LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
138 libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_custom_bf16_bf16_c32(libxsmm_dnn_pooling* handle, int start_thread, int tid)
139 {
140 libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
141 #if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
142 typedef libxsmm_bfloat16 element_input_type;
143 typedef libxsmm_bfloat16 element_output_type;
144
145 # define LIBXSMM_DNN_POOLING_FWD_BF16
146 if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX ) {
147 # define LIBXSMM_DNN_POOLING_FWD_MAX
148 typedef int element_mask_type;
149 # include "template/libxsmm_dnn_pooling_st_fwd_custom_f32_bf16_c32_avx512.tpl.c"
150 # undef LIBXSMM_DNN_POOLING_FWD_MAX
151 } else if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_AVG ) {
152 # define LIBXSMM_DNN_POOLING_FWD_AVG
153 # include "template/libxsmm_dnn_pooling_st_fwd_custom_f32_bf16_c32_avx512.tpl.c"
154 # undef LIBXSMM_DNN_POOLING_FWD_AVG
155 } else {
156 status = LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING;
157 }
158 # undef LIBXSMM_DNN_POOLING_FWD_BF16
159 #else /* should not happen */
160 LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
161 status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
162 #endif
163 return status;
164 }
165
166
LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)167 LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
168 libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_custom_bf16_bf16_c64(libxsmm_dnn_pooling* handle, int start_thread, int tid)
169 {
170 libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
171 #if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
172 typedef libxsmm_bfloat16 element_input_type;
173 typedef libxsmm_bfloat16 element_output_type;
174
175 # define LIBXSMM_DNN_POOLING_FWD_BF16
176 if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX ) {
177 # define LIBXSMM_DNN_POOLING_FWD_MAX
178 typedef int element_mask_type;
179 # include "template/libxsmm_dnn_pooling_st_fwd_custom_f32_bf16_c64_avx512.tpl.c"
180 # undef LIBXSMM_DNN_POOLING_FWD_MAX
181 } else if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_AVG ) {
182 # define LIBXSMM_DNN_POOLING_FWD_AVG
183 # include "template/libxsmm_dnn_pooling_st_fwd_custom_f32_bf16_c64_avx512.tpl.c"
184 # undef LIBXSMM_DNN_POOLING_FWD_AVG
185 } else {
186 status = LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING;
187 }
188 # undef LIBXSMM_DNN_POOLING_FWD_BF16
189 #else /* should not happen */
190 LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
191 status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
192 #endif
193 return status;
194 }
195
196
libxsmm_dnn_pooling_st_fwd_custom(libxsmm_dnn_pooling * handle,int start_thread,int tid)197 LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_custom(libxsmm_dnn_pooling* handle, int start_thread, int tid)
198 {
199 libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
200
201 /* check if we have input, output and mask */
202 if ( handle->reg_input == 0 || handle->reg_output == 0 ||
203 ( (handle->mask == 0) && (handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX) ) ) {
204 status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
205 return status;
206 }
207
208 /* check if we are on an AVX512 platform */
209 #if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
210 if ( ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) &&
211 (handle->ofmblock == 16) ) {
212 if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
213 LIBXSMM_ASSERT(NULL != handle->mask);
214 status = libxsmm_dnn_pooling_st_fwd_custom_f32_f32_c16( handle, start_thread, tid);
215 } else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) {
216 LIBXSMM_ASSERT(NULL != handle->mask);
217 status = libxsmm_dnn_pooling_st_fwd_custom_bf16_bf16_c16( handle, start_thread, tid);
218 } else {
219 status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
220 return status;
221 }
222 } else if ( ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) &&
223 (handle->ofmblock == 32) ) {
224 if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
225 LIBXSMM_ASSERT(NULL != handle->mask);
226 status = libxsmm_dnn_pooling_st_fwd_custom_f32_f32_c32( handle, start_thread, tid);
227 } else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) {
228 LIBXSMM_ASSERT(NULL != handle->mask);
229 status = libxsmm_dnn_pooling_st_fwd_custom_bf16_bf16_c32( handle, start_thread, tid);
230 } else {
231 status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
232 return status;
233 }
234 } else if ( ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) &&
235 (handle->ofmblock == 64) ) {
236 if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
237 LIBXSMM_ASSERT(NULL != handle->mask);
238 status = libxsmm_dnn_pooling_st_fwd_custom_f32_f32_c64( handle, start_thread, tid);
239 } else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) {
240 LIBXSMM_ASSERT(NULL != handle->mask);
241 status = libxsmm_dnn_pooling_st_fwd_custom_bf16_bf16_c64( handle, start_thread, tid);
242 } else {
243 status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
244 return status;
245 }
246 } else
247 #endif
248 {
249 if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
250 typedef float element_input_type;
251 typedef float element_output_type;
252
253 if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX ) {
254 # define LIBXSMM_DNN_POOLING_FWD_MAX
255 typedef int element_mask_type;
256 # include "template/libxsmm_dnn_pooling_st_fwd_custom_generic.tpl.c"
257 # undef LIBXSMM_DNN_POOLING_FWD_MAX
258 } else if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_AVG ) {
259 # define LIBXSMM_DNN_POOLING_FWD_AVG
260 # include "template/libxsmm_dnn_pooling_st_fwd_custom_generic.tpl.c"
261 # undef LIBXSMM_DNN_POOLING_FWD_AVG
262 } else {
263 status = LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING;
264 }
265 } else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) {
266 typedef libxsmm_bfloat16 element_input_type;
267 typedef libxsmm_bfloat16 element_output_type;
268
269 # define LIBXSMM_DNN_POOLING_FWD_BF16
270 if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_MAX ) {
271 # define LIBXSMM_DNN_POOLING_FWD_MAX
272 typedef int element_mask_type;
273 # include "template/libxsmm_dnn_pooling_st_fwd_custom_generic.tpl.c"
274 # undef LIBXSMM_DNN_POOLING_FWD_MAX
275 } else if ( handle->desc.pooling_type == LIBXSMM_DNN_POOLING_AVG ) {
276 # define LIBXSMM_DNN_POOLING_FWD_AVG
277 # include "template/libxsmm_dnn_pooling_st_fwd_custom_generic.tpl.c"
278 # undef LIBXSMM_DNN_POOLING_FWD_AVG
279 } else {
280 status = LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING;
281 }
282 # undef LIBXSMM_DNN_POOLING_FWD_BF16
283 } else {
284 status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
285 return status;
286 }
287 }
288
289 return status;
290 }
291
292
libxsmm_dnn_pooling_st_fwd_nhwc(libxsmm_dnn_pooling * handle,int start_thread,int tid)293 LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_pooling_st_fwd_nhwc(libxsmm_dnn_pooling* handle, int start_thread, int tid)
294 {
295 libxsmm_dnn_err_t status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
296 LIBXSMM_UNUSED( handle );
297 LIBXSMM_UNUSED( start_thread );
298 LIBXSMM_UNUSED( tid );
299 return status;
300 }
301
302